| author | Alan Dipert
<alan@dipert.org> 2025-12-04 04:31:13 UTC |
| committer | Alan Dipert
<alan@dipert.org> 2025-12-04 04:31:13 UTC |
| parent | 9611ab3b7b998dc303f30648127544eab0d65605 |
| main.py | +17 | -9 |
| test_generator.py | +67 | -30 |
diff --git a/main.py b/main.py index c145d48..7faa186 100644 --- a/main.py +++ b/main.py @@ -38,21 +38,29 @@ def main() -> None: actual_seed = args.seed if args.seed is not None else random.randint(0, 1_000_000) concepts = get_default_concepts() blueprint = get_default_blueprint() - max_attempts = 50 + max_attempts = 200 total_items = sum(s.num_items for s in blueprint.sections) def scale(base: int) -> int: return min(total_items, int(round(base * args.hardness_multiplier))) + req_irregular = max(1, scale(args.min_irregular)) + req_irregular_contrast = max(1, scale(args.min_irregular_contrast)) + req_ditransitive = max(1, min(scale(args.min_ditransitive), total_items)) + feasible_irregular = max(1, min(req_irregular, max(0, total_items - req_ditransitive))) + max_feature_load = 12 # estimated feasible max given grammar params = { - "min_irregular": max(1, scale(args.min_irregular)), - "min_irregular_contrast": max(1, scale(args.min_irregular_contrast)), - "min_irregular_distractor": max(3, scale(args.min_irregular_contrast)), - "min_ditransitive": max(1, scale(args.min_ditransitive)), - "min_plural": max(1, scale(args.min_plural)), - "min_adjective": max(1, scale(args.min_adjective)), - "min_fem_plural": max(1, scale(args.min_fem_plural)), - "min_feature_load": max(args.min_feature_load, int(round(args.min_feature_load * args.hardness_multiplier))), + "min_irregular": feasible_irregular, + "min_irregular_contrast": min(req_irregular_contrast, feasible_irregular), + "min_irregular_distractor": max(3, min(req_irregular_contrast, feasible_irregular)), + "min_ditransitive": req_ditransitive, + "min_plural": max(1, min(scale(args.min_plural), total_items)), + "min_adjective": max(1, min(scale(args.min_adjective), total_items)), + "min_fem_plural": max(1, min(scale(args.min_fem_plural), total_items)), + "min_feature_load": max( + args.min_feature_load, + min(max_feature_load, int(round(args.min_feature_load * args.hardness_multiplier))), + ), "hardness_multiplier": args.hardness_multiplier, } diff --git a/test_generator.py b/test_generator.py index 74b353a..9c299b1 100644 --- a/test_generator.py +++ b/test_generator.py @@ -444,6 +444,7 @@ def _planned_features( difficulty: str, remaining: Dict[str, int], idx: int, + items_left: int, ) -> tuple[SentenceFeatures, Dict[str, int]]: """Greedy planner to satisfy coverage quotas deterministically with overlap. @@ -475,15 +476,25 @@ def _planned_features( use_irregular_verb = True delta: Dict[str, int] = {"irregular_verb": 0, "irregular_noun": 0, "ditransitive": 0, "fem_plural": 0, "plural": 0, "adjective": 0} - # 1) Irregular verb coverage (monotransitive chase past) - if remaining.get("irregular_verb", 0) > 0: + # 0) Force ditransitive if quota equals items_left + if remaining.get("ditransitive", 0) >= items_left: + verb_id = "give" + obj2 = theme_pool[idx % len(theme_pool)] + delta["ditransitive"] = 1 + # 1) Irregular verb coverage (monotransitive chase past), forced if needed + if verb_id == "see" and remaining.get("irregular_verb", 0) >= items_left: + verb_id = "chase" + tense = "PAST" + use_irregular_verb = True + delta["irregular_verb"] = 1 + elif verb_id == "see" and remaining.get("irregular_verb", 0) > 0: verb_id = "chase" tense = "PAST" use_irregular_verb = True delta["irregular_verb"] = 1 # 2) Ditransitive coverage (can overlap with irregular noun) - if verb_id != "chase" and remaining.get("ditransitive", 0) > 0: + if verb_id != "chase" and remaining.get("ditransitive", 0) > 0 and delta.get("ditransitive", 0) == 0: verb_id = "give" obj2 = theme_pool[idx % len(theme_pool)] delta["ditransitive"] = 1 @@ -492,19 +503,33 @@ def _planned_features( if remaining.get("irregular_noun", 0) > 0: obj1 = np_features("boy", RECIPIENT, plural=True, adjectives=["red"], use_irregular=True) delta["irregular_noun"] = 1 + # if we still need more irregular noun items, also set subject to irregular boy plural + if remaining.get("irregular_noun", 0) - delta["irregular_noun"] > 0: + subj = np_features("boy", AGENT, plural=True, adjectives=subj.adjectives or ["tall"], use_irregular=True) # 4) fem plural receiver (if applicable) - if remaining.get("fem_plural", 0) > 0 and obj1.noun_id in {"woman", "girl"}: - obj1 = np_features(obj1.noun_id, RECIPIENT, feminine=True, plural=True, adjectives=obj1.adjectives) - delta["fem_plural"] = 1 + must_fem = remaining.get("fem_plural", 0) >= items_left + if (remaining.get("fem_plural", 0) > 0 or must_fem): + if obj1.noun_id not in {"woman", "girl"}: + obj1 = np_features("woman", RECIPIENT, feminine=True, plural=True, adjectives=obj1.adjectives or ["red"]) + delta["fem_plural"] = 1 + elif not (obj1.feminine and obj1.plural): + obj1 = np_features(obj1.noun_id, RECIPIENT, feminine=True, plural=True, adjectives=obj1.adjectives) + delta["fem_plural"] = 1 # 5) plural coverage (if still needed) - if remaining.get("plural", 0) > 0 and not (obj1.plural or subj.plural or (obj2 and obj2.plural)): + must_plural = remaining.get("plural", 0) >= items_left + if (remaining.get("plural", 0) > 0 or must_plural) and not (obj1.plural or subj.plural or (obj2 and obj2.plural)): obj1 = np_features(obj1.noun_id, obj1.role, feminine=obj1.feminine, plural=True, adjectives=obj1.adjectives) delta["plural"] = 1 + # secondary plural if still needed and possible + if remaining.get("plural", 0) - delta.get("plural", 0) > 0 and not subj.plural: + subj = NPFeature(subj.noun_id, subj.feminine, True, subj.adjectives, subj.role, subj.use_irregular) + delta["plural"] = delta.get("plural", 0) + 1 # 6) adjective coverage - if remaining.get("adjective", 0) > 0: + must_adj = remaining.get("adjective", 0) >= items_left + if (remaining.get("adjective", 0) > 0 or must_adj): if not obj1.adjectives: obj1 = np_features(obj1.noun_id, obj1.role, feminine=obj1.feminine, plural=obj1.plural, adjectives=["red"]) delta["adjective"] = 1 @@ -556,37 +581,48 @@ def _boost_feature_load(sf: SentenceFeatures, target: int) -> SentenceFeatures: verb_id = sf.verb_id tense = sf.tense use_irregular_verb = sf.use_irregular_verb + adj_pool = ["red", "fast", "big", "tall"] - def add_adj(np: NPFeature, adj: str) -> NPFeature: - if adj in np.adjectives: - return np - return NPFeature(np.noun_id, np.feminine, np.plural, np.adjectives + [adj], np.role, np.use_irregular) + def add_adj(np: NPFeature) -> NPFeature: + for adj in adj_pool: + if adj not in np.adjectives: + return NPFeature(np.noun_id, np.feminine, np.plural, np.adjectives + [adj], np.role, np.use_irregular) + return np steps = [ - lambda: ("tense", None), - lambda: ("obj1_adj", None), - lambda: ("subj_adj", None), - lambda: ("obj1_plural", None), - lambda: ("subj_plural", None), - lambda: ("obj2_adj", None), + "tense", + "obj1_adj", + "subj_adj", + "obj1_plural", + "subj_plural", + "obj2_adj", + "obj1_second_adj", + "subj_second_adj", + "obj2_second_adj", ] - for _ in range(12): + for i in range(24): current = _feature_load(SentenceFeatures(subj, obj1, obj2, verb_id, tense, use_irregular_verb)) if current >= target: break - step_name, _ = steps[_ % len(steps)]() + step_name = steps[i % len(steps)] if step_name == "tense" and tense == "PRES": tense = "PAST" elif step_name == "obj1_adj": - obj1 = add_adj(obj1, "red") + obj1 = add_adj(obj1) elif step_name == "subj_adj": - subj = add_adj(subj, "tall") + subj = add_adj(subj) elif step_name == "obj1_plural" and not obj1.plural: obj1 = NPFeature(obj1.noun_id, obj1.feminine, True, obj1.adjectives, obj1.role, obj1.use_irregular) elif step_name == "subj_plural" and not subj.plural: subj = NPFeature(subj.noun_id, subj.feminine, True, subj.adjectives, subj.role, subj.use_irregular) elif step_name == "obj2_adj" and obj2: - obj2 = add_adj(obj2, "red") + obj2 = add_adj(obj2) + elif step_name == "obj1_second_adj": + obj1 = add_adj(obj1) + elif step_name == "subj_second_adj": + subj = add_adj(subj) + elif step_name == "obj2_second_adj" and obj2: + obj2 = add_adj(obj2) return SentenceFeatures(subj, obj1, obj2, verb_id, tense, use_irregular_verb) @@ -671,12 +707,12 @@ def generate_test( irregular_target = cfg.get("min_irregular", 6) # counters for coverage remaining = { - "irregular_noun": irregular_target, - "irregular_verb": irregular_target, - "ditransitive": cfg.get("min_ditransitive", 8), - "fem_plural": cfg.get("min_fem_plural", 4), - "plural": cfg.get("min_plural", 12), - "adjective": cfg.get("min_adjective", 12), + "irregular_noun": min(irregular_target, total_items), + "irregular_verb": min(irregular_target, total_items), + "ditransitive": min(cfg.get("min_ditransitive", 8), total_items), + "fem_plural": min(cfg.get("min_fem_plural", 4), total_items), + "plural": min(cfg.get("min_plural", 12), total_items), + "adjective": min(cfg.get("min_adjective", 12), total_items), } for section in blueprint.sections: @@ -690,7 +726,8 @@ def generate_test( current_number = question_counter + len(questions) difficulty_tag = "early" if current_number <= 8 else "mid" if current_number <= 24 else "late" - sf_override, delta = _planned_features(spec, rng, difficulty_tag, remaining.copy(), current_number) + items_left = total_items - current_number + 1 + sf_override, delta = _planned_features(spec, rng, difficulty_tag, remaining.copy(), current_number, items_left) q = generate_item( spec, section.focus_concepts,