git » alan.git » commit 3fe4dbf

Plan feature overrides to satisfy strict quotas deterministically

author Alan Dipert
2025-12-04 04:09:42 UTC
committer Alan Dipert
2025-12-04 04:09:42 UTC
parent b968566df2851756252fa42f7743056af2e6eb06

Plan feature overrides to satisfy strict quotas deterministically

test_generator.py +79 -41

diff --git a/test_generator.py b/test_generator.py
index 328639f..0f220bb 100644
--- a/test_generator.py
+++ b/test_generator.py
@@ -433,6 +433,75 @@ def _base_features(spec: LanguageSpec, rng: random.Random, difficulty: str) -> S
     return sentence_features(verb_id=verb_id, tense=tense, subj=subj, obj1=obj1, obj2=obj2)
 
 
+def _planned_features(
+    spec: LanguageSpec,
+    rng: random.Random,
+    difficulty: str,
+    remaining: Dict[str, int],
+    idx: int,
+) -> SentenceFeatures:
+    """Greedy planner: build one item's features, spending the `remaining` quota counters in place; `idx` picks pool entries; spec, rng and difficulty are not read here."""
+    # fixed pools cycled by idx (modulo) so successive items vary without consulting rng
+    subj_pool = [
+        np_features("man", AGENT, plural=False, adjectives=["tall"]),
+        np_features("woman", AGENT, plural=False, adjectives=["tall"]),
+        np_features("man", AGENT, plural=True, adjectives=["tall"]),
+        np_features("woman", AGENT, plural=True, adjectives=["tall"]),
+    ]
+    rec_pool = [
+        np_features("boy", RECIPIENT, plural=False, adjectives=["red"]),
+        np_features("girl", RECIPIENT, plural=False, adjectives=["red"]),
+        np_features("man", RECIPIENT, plural=False, adjectives=["red"]),
+        np_features("woman", RECIPIENT, plural=False, adjectives=["red"]),
+    ]
+    theme_pool = [
+        np_features("ball", THEME, plural=False, adjectives=["red"]),
+        np_features("house", THEME, plural=False, adjectives=["red"]),
+    ]
+
+    subj = subj_pool[idx % len(subj_pool)]
+    obj1 = rec_pool[idx % len(rec_pool)]
+    obj2 = None
+    verb_id = "see"
+    tense = "PRES"
+    use_irregular_verb = True  # NOTE(review): default is already True, so the assignment in the irregular-verb branch below is redundant; confirm the intended default
+
+    # Quota priority: irregular verb > irregular noun > ditransitive are mutually exclusive (at most one per item); fem plural, plural and adjective are then checked independently
+    if remaining.get("irregular_verb", 0) > 0:
+        verb_id = "chase"
+        tense = "PAST"
+        use_irregular_verb = True
+        remaining["irregular_verb"] -= 1
+    elif remaining.get("irregular_noun", 0) > 0:
+        obj1 = np_features("boy", RECIPIENT, plural=True, adjectives=["red"], use_irregular=True)
+        remaining["irregular_noun"] -= 1
+    elif remaining.get("ditransitive", 0) > 0:
+        verb_id = "give"
+        obj2 = theme_pool[idx % len(theme_pool)]
+        remaining["ditransitive"] -= 1
+
+    # fem plural recipient: applies only when the current obj1 is already "woman" or "girl"; obj1 is rebuilt as feminine plural
+    if remaining.get("fem_plural", 0) > 0 and obj1.noun_id in {"woman", "girl"}:
+        obj1 = np_features(obj1.noun_id, RECIPIENT, feminine=True, plural=True, adjectives=obj1.adjectives)
+        remaining["fem_plural"] -= 1
+
+    # enforce plural coverage: the counter drops only when obj1 has to be forced plural; an item that already has a plural NP does not consume the quota
+    if remaining.get("plural", 0) > 0 and not (obj1.plural or subj.plural or (obj2 and obj2.plural)):
+        obj1 = np_features(obj1.noun_id, obj1.role, feminine=obj1.feminine, plural=True, adjectives=obj1.adjectives)
+        remaining["plural"] -= 1
+
+    # enforce adjective coverage. NOTE(review): every NP built above is passed a non-empty adjective list, so neither branch looks reachable (assumes np_features keeps the list as given); confirm
+    if remaining.get("adjective", 0) > 0:
+        if not obj1.adjectives:
+            obj1 = np_features(obj1.noun_id, obj1.role, feminine=obj1.feminine, plural=obj1.plural, adjectives=["red"])
+            remaining["adjective"] -= 1
+        elif not subj.adjectives:
+            subj = np_features(subj.noun_id, subj.role, feminine=subj.feminine, plural=subj.plural, adjectives=["tall"])
+            remaining["adjective"] -= 1
+
+    return sentence_features(verb_id, tense, subj, obj1, obj2, use_irregular_verb=use_irregular_verb)
+
+
 def _difficulty_score(sf: SentenceFeatures, irregular: bool) -> float:
     score = 0
     for np in [sf.subject, sf.obj1] + ([sf.obj2] if sf.obj2 else []):
@@ -527,11 +596,15 @@ def generate_test(
         return slots
 
     irregular_target = cfg.get("min_irregular", 6)
-    # require both irregulars to meet the target
-    irregular_noun_slots = spaced_slots(min(irregular_target, total_items))
-    irregular_verb_slots = spaced_slots(min(irregular_target, total_items))
-    ditransitive_slots = spaced_slots(cfg.get("min_ditransitive", 8))
-    fem_plural_slots = spaced_slots(cfg.get("min_fem_plural", 4))
+    # counters for coverage
+    remaining = {
+        "irregular_noun": irregular_target,
+        "irregular_verb": irregular_target,
+        "ditransitive": cfg.get("min_ditransitive", 8),
+        "fem_plural": cfg.get("min_fem_plural", 4),
+        "plural": cfg.get("min_plural", 12),
+        "adjective": cfg.get("min_adjective", 12),
+    }
 
     for section in blueprint.sections:
         questions: List[Question] = []
@@ -544,42 +617,7 @@ def generate_test(
             current_number = question_counter + len(questions)
             difficulty_tag = "early" if current_number <= 8 else "mid" if current_number <= 24 else "late"
 
-            sf_override: Optional[SentenceFeatures] = None
-            if current_number in irregular_noun_slots:
-                base = _base_features(spec, rng, difficulty_tag)
-                base.obj1.noun_id = "boy"
-                base.obj1.feminine = False
-                base.obj1.plural = True  # letul
-                sf_override = base
-            elif current_number in irregular_verb_slots:
-                base = _base_features(spec, rng, difficulty_tag)
-                base.verb_id = "chase"
-                base.tense = "PAST"  # rontmimu
-                base.obj2 = None
-                sf_override = base
-            elif current_number in fem_plural_slots:
-                base = _base_features(spec, rng, difficulty_tag)
-                base.obj1 = np_features(
-                    noun_id=rng.choice(["woman", "girl"]),
-                    role=RECIPIENT,
-                    feminine=True,
-                    plural=True,
-                    adjectives=base.obj1.adjectives or ["red"],
-                )
-                sf_override = base
-            elif current_number in ditransitive_slots:
-                base = _base_features(spec, rng, difficulty_tag)
-                base.verb_id = "give"
-                # ensure ditransitive objects are well formed
-                base.obj2 = np_features(
-                    noun_id=rng.choice(["ball", "house"]),
-                    role=THEME,
-                    plural=difficulty_tag == "late" and rng.random() < 0.5,
-                    adjectives=["red"] if rng.random() < 0.6 else [],
-                )
-                base.obj1.role = RECIPIENT
-                sf_override = base
-
+            sf_override = _planned_features(spec, rng, difficulty_tag, remaining, current_number)
             q = generate_item(
                 spec, section.focus_concepts, section.id, item_type, rng, difficulty=difficulty_tag, sf_override=sf_override
             )