git » alan.git » commit 01d1ae5

Improve planner feasibility and force remaining quotas (fem plural, plural, adj)

author Alan Dipert
2025-12-04 04:31:13 UTC
committer Alan Dipert
2025-12-04 04:31:13 UTC
parent 9611ab3b7b998dc303f30648127544eab0d65605

Improve planner feasibility and force remaining quotas (fem plural, plural, adj)

main.py +17 -9
test_generator.py +67 -30

diff --git a/main.py b/main.py
index c145d48..7faa186 100644
--- a/main.py
+++ b/main.py
@@ -38,21 +38,29 @@ def main() -> None:
     actual_seed = args.seed if args.seed is not None else random.randint(0, 1_000_000)
     concepts = get_default_concepts()
     blueprint = get_default_blueprint()
-    max_attempts = 50
+    max_attempts = 200
     total_items = sum(s.num_items for s in blueprint.sections)
 
     def scale(base: int) -> int:
         return min(total_items, int(round(base * args.hardness_multiplier)))
 
+    req_irregular = max(1, scale(args.min_irregular))
+    req_irregular_contrast = max(1, scale(args.min_irregular_contrast))
+    req_ditransitive = max(1, min(scale(args.min_ditransitive), total_items))
+    feasible_irregular = max(1, min(req_irregular, max(0, total_items - req_ditransitive)))
+    max_feature_load = 12  # estimated feasible max given grammar
     params = {
-        "min_irregular": max(1, scale(args.min_irregular)),
-        "min_irregular_contrast": max(1, scale(args.min_irregular_contrast)),
-        "min_irregular_distractor": max(3, scale(args.min_irregular_contrast)),
-        "min_ditransitive": max(1, scale(args.min_ditransitive)),
-        "min_plural": max(1, scale(args.min_plural)),
-        "min_adjective": max(1, scale(args.min_adjective)),
-        "min_fem_plural": max(1, scale(args.min_fem_plural)),
-        "min_feature_load": max(args.min_feature_load, int(round(args.min_feature_load * args.hardness_multiplier))),
+        "min_irregular": feasible_irregular,
+        "min_irregular_contrast": min(req_irregular_contrast, feasible_irregular),
+        "min_irregular_distractor": max(3, min(req_irregular_contrast, feasible_irregular)),
+        "min_ditransitive": req_ditransitive,
+        "min_plural": max(1, min(scale(args.min_plural), total_items)),
+        "min_adjective": max(1, min(scale(args.min_adjective), total_items)),
+        "min_fem_plural": max(1, min(scale(args.min_fem_plural), total_items)),
+        "min_feature_load": max(
+            args.min_feature_load,
+            min(max_feature_load, int(round(args.min_feature_load * args.hardness_multiplier))),
+        ),
         "hardness_multiplier": args.hardness_multiplier,
     }
 
diff --git a/test_generator.py b/test_generator.py
index 74b353a..9c299b1 100644
--- a/test_generator.py
+++ b/test_generator.py
@@ -444,6 +444,7 @@ def _planned_features(
     difficulty: str,
     remaining: Dict[str, int],
     idx: int,
+    items_left: int,
 ) -> tuple[SentenceFeatures, Dict[str, int]]:
     """Greedy planner to satisfy coverage quotas deterministically with overlap.
 
@@ -475,15 +476,25 @@ def _planned_features(
     use_irregular_verb = True
     delta: Dict[str, int] = {"irregular_verb": 0, "irregular_noun": 0, "ditransitive": 0, "fem_plural": 0, "plural": 0, "adjective": 0}
 
-    # 1) Irregular verb coverage (monotransitive chase past)
-    if remaining.get("irregular_verb", 0) > 0:
+    # 0) Force ditransitive when the remaining quota is at least items_left
+    if remaining.get("ditransitive", 0) >= items_left:
+        verb_id = "give"
+        obj2 = theme_pool[idx % len(theme_pool)]
+        delta["ditransitive"] = 1
+    # 1) Irregular verb coverage (monotransitive chase past), forced if needed
+    if verb_id == "see" and remaining.get("irregular_verb", 0) >= items_left:
+        verb_id = "chase"
+        tense = "PAST"
+        use_irregular_verb = True
+        delta["irregular_verb"] = 1
+    elif verb_id == "see" and remaining.get("irregular_verb", 0) > 0:
         verb_id = "chase"
         tense = "PAST"
         use_irregular_verb = True
         delta["irregular_verb"] = 1
 
     # 2) Ditransitive coverage (can overlap with irregular noun)
-    if verb_id != "chase" and remaining.get("ditransitive", 0) > 0:
+    if verb_id != "chase" and remaining.get("ditransitive", 0) > 0 and delta.get("ditransitive", 0) == 0:
         verb_id = "give"
         obj2 = theme_pool[idx % len(theme_pool)]
         delta["ditransitive"] = 1
@@ -492,19 +503,33 @@ def _planned_features(
     if remaining.get("irregular_noun", 0) > 0:
         obj1 = np_features("boy", RECIPIENT, plural=True, adjectives=["red"], use_irregular=True)
         delta["irregular_noun"] = 1
+        # if we still need more irregular noun items, also set subject to irregular boy plural
+        if remaining.get("irregular_noun", 0) - delta["irregular_noun"] > 0:
+            subj = np_features("boy", AGENT, plural=True, adjectives=subj.adjectives or ["tall"], use_irregular=True)
 
     # 4) fem plural receiver (if applicable)
-    if remaining.get("fem_plural", 0) > 0 and obj1.noun_id in {"woman", "girl"}:
-        obj1 = np_features(obj1.noun_id, RECIPIENT, feminine=True, plural=True, adjectives=obj1.adjectives)
-        delta["fem_plural"] = 1
+    must_fem = remaining.get("fem_plural", 0) >= items_left
+    if (remaining.get("fem_plural", 0) > 0 or must_fem):
+        if obj1.noun_id not in {"woman", "girl"}:
+            obj1 = np_features("woman", RECIPIENT, feminine=True, plural=True, adjectives=obj1.adjectives or ["red"])
+            delta["fem_plural"] = 1
+        elif not (obj1.feminine and obj1.plural):
+            obj1 = np_features(obj1.noun_id, RECIPIENT, feminine=True, plural=True, adjectives=obj1.adjectives)
+            delta["fem_plural"] = 1
 
     # 5) plural coverage (if still needed)
-    if remaining.get("plural", 0) > 0 and not (obj1.plural or subj.plural or (obj2 and obj2.plural)):
+    must_plural = remaining.get("plural", 0) >= items_left
+    if (remaining.get("plural", 0) > 0 or must_plural) and not (obj1.plural or subj.plural or (obj2 and obj2.plural)):
         obj1 = np_features(obj1.noun_id, obj1.role, feminine=obj1.feminine, plural=True, adjectives=obj1.adjectives)
         delta["plural"] = 1
+    # secondary plural if still needed and possible
+    if remaining.get("plural", 0) - delta.get("plural", 0) > 0 and not subj.plural:
+        subj = NPFeature(subj.noun_id, subj.feminine, True, subj.adjectives, subj.role, subj.use_irregular)
+        delta["plural"] = delta.get("plural", 0) + 1
 
     # 6) adjective coverage
-    if remaining.get("adjective", 0) > 0:
+    must_adj = remaining.get("adjective", 0) >= items_left
+    if (remaining.get("adjective", 0) > 0 or must_adj):
         if not obj1.adjectives:
             obj1 = np_features(obj1.noun_id, obj1.role, feminine=obj1.feminine, plural=obj1.plural, adjectives=["red"])
             delta["adjective"] = 1
@@ -556,37 +581,48 @@ def _boost_feature_load(sf: SentenceFeatures, target: int) -> SentenceFeatures:
     verb_id = sf.verb_id
     tense = sf.tense
     use_irregular_verb = sf.use_irregular_verb
+    adj_pool = ["red", "fast", "big", "tall"]
 
-    def add_adj(np: NPFeature, adj: str) -> NPFeature:
-        if adj in np.adjectives:
-            return np
-        return NPFeature(np.noun_id, np.feminine, np.plural, np.adjectives + [adj], np.role, np.use_irregular)
+    def add_adj(np: NPFeature) -> NPFeature:
+        for adj in adj_pool:
+            if adj not in np.adjectives:
+                return NPFeature(np.noun_id, np.feminine, np.plural, np.adjectives + [adj], np.role, np.use_irregular)
+        return np
 
     steps = [
-        lambda: ("tense", None),
-        lambda: ("obj1_adj", None),
-        lambda: ("subj_adj", None),
-        lambda: ("obj1_plural", None),
-        lambda: ("subj_plural", None),
-        lambda: ("obj2_adj", None),
+        "tense",
+        "obj1_adj",
+        "subj_adj",
+        "obj1_plural",
+        "subj_plural",
+        "obj2_adj",
+        "obj1_second_adj",
+        "subj_second_adj",
+        "obj2_second_adj",
     ]
-    for _ in range(12):
+    for i in range(24):
         current = _feature_load(SentenceFeatures(subj, obj1, obj2, verb_id, tense, use_irregular_verb))
         if current >= target:
             break
-        step_name, _ = steps[_ % len(steps)]()
+        step_name = steps[i % len(steps)]
         if step_name == "tense" and tense == "PRES":
             tense = "PAST"
         elif step_name == "obj1_adj":
-            obj1 = add_adj(obj1, "red")
+            obj1 = add_adj(obj1)
         elif step_name == "subj_adj":
-            subj = add_adj(subj, "tall")
+            subj = add_adj(subj)
         elif step_name == "obj1_plural" and not obj1.plural:
             obj1 = NPFeature(obj1.noun_id, obj1.feminine, True, obj1.adjectives, obj1.role, obj1.use_irregular)
         elif step_name == "subj_plural" and not subj.plural:
             subj = NPFeature(subj.noun_id, subj.feminine, True, subj.adjectives, subj.role, subj.use_irregular)
         elif step_name == "obj2_adj" and obj2:
-            obj2 = add_adj(obj2, "red")
+            obj2 = add_adj(obj2)
+        elif step_name == "obj1_second_adj":
+            obj1 = add_adj(obj1)
+        elif step_name == "subj_second_adj":
+            subj = add_adj(subj)
+        elif step_name == "obj2_second_adj" and obj2:
+            obj2 = add_adj(obj2)
     return SentenceFeatures(subj, obj1, obj2, verb_id, tense, use_irregular_verb)
 
 
@@ -671,12 +707,12 @@ def generate_test(
     irregular_target = cfg.get("min_irregular", 6)
     # counters for coverage
     remaining = {
-        "irregular_noun": irregular_target,
-        "irregular_verb": irregular_target,
-        "ditransitive": cfg.get("min_ditransitive", 8),
-        "fem_plural": cfg.get("min_fem_plural", 4),
-        "plural": cfg.get("min_plural", 12),
-        "adjective": cfg.get("min_adjective", 12),
+        "irregular_noun": min(irregular_target, total_items),
+        "irregular_verb": min(irregular_target, total_items),
+        "ditransitive": min(cfg.get("min_ditransitive", 8), total_items),
+        "fem_plural": min(cfg.get("min_fem_plural", 4), total_items),
+        "plural": min(cfg.get("min_plural", 12), total_items),
+        "adjective": min(cfg.get("min_adjective", 12), total_items),
     }
 
     for section in blueprint.sections:
@@ -690,7 +726,8 @@ def generate_test(
             current_number = question_counter + len(questions)
             difficulty_tag = "early" if current_number <= 8 else "mid" if current_number <= 24 else "late"
 
-            sf_override, delta = _planned_features(spec, rng, difficulty_tag, remaining.copy(), current_number)
+            items_left = total_items - current_number + 1
+            sf_override, delta = _planned_features(spec, rng, difficulty_tag, remaining.copy(), current_number, items_left)
             q = generate_item(
                 spec,
                 section.focus_concepts,