PHON-104 — CSP perf: vectorize enumeration via Polars
Date: 2026-05-08
Branch: feature/csp-iteration
Scope: spike-internal perf optimization (productionization to packages/generators/csp/ is PHON-109's scope)
Frame
The CSP spike's hot path enumerates a Cartesian product over per-slot fillers in solve_shape (skeleton_csp.py), with a structurally similar nested loop in solve() (paradigm_3_csp.py). The largest acceptance probe (melt × spec6) produces 9,120 (nsubj × dobj) pairs × 30 adverbs = 273,600 Cartesian points and runs in 2.19s. At design time the dominant cost was taken to be Python iteration overhead: generator-based recursion in enumerate_assignments, dict construction per assignment, and function-call dispatch.
The original ticket title — "per-slot top-N pruning" — is unsafe: dropping low-PMI candidates before the reranker sees them risks losing candidates that the teacher-distilled quality reranker (Step 15, Spearman 0.633 on held-out (verb, band) groups) would otherwise rank into top-K. The reranker uses MiniLM features on full sentences and cannot be evaluated per-filler.
This ticket reframes PHON-104 as "vectorize the enumeration in Polars". No candidates are dropped; the same full Cartesian product is produced and scored, only faster (the Goal section targets 10×+).
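To make the pruning risk concrete, here is a toy sketch. The words, PMI values, and `rerank_score` function are all invented for illustration; `rerank_score` only stands in for a sentence-level reranker that sees the whole candidate rather than individual fillers.

```python
from itertools import product

# Hypothetical per-slot PMI scores (illustrative values, not from the spike).
nsubj_pmi = {"sun": 0.9, "heat": 0.8, "wizard": 0.1}
dobj_pmi = {"ice": 0.9, "snow": 0.7, "heart": 0.2}


def rerank_score(nsubj: str, dobj: str) -> float:
    """Stand-in for a sentence-level reranker: it can reward a combination
    whose individual fillers have low PMI."""
    bonus = 2.0 if (nsubj, dobj) == ("wizard", "heart") else 0.0
    return nsubj_pmi[nsubj] + dobj_pmi[dobj] + bonus


def top_k(nsubjs, dobjs, k=1):
    """Score the full Cartesian product of the given fillers and keep the best k."""
    pairs = product(nsubjs, dobjs)
    return sorted(pairs, key=lambda p: rerank_score(*p), reverse=True)[:k]


def prune(pmi: dict, n: int) -> list:
    """Per-slot top-N pruning by PMI (the rejected approach)."""
    return sorted(pmi, key=pmi.get, reverse=True)[:n]


full_best = top_k(nsubj_pmi, dobj_pmi)
pruned_best = top_k(prune(nsubj_pmi, 2), prune(dobj_pmi, 2))
print(full_best, pruned_best)
```

With these made-up numbers, the full enumeration ranks ("wizard", "heart") first, while pruning each slot to its top 2 removes both words before the reranker can score the pair.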
Goal
For the largest acceptance probe (melt × spec6), expect a 10×+ wall-clock speedup. Top-K output must be bit-identical between the vectorized and Python paths under the same inputs.
Non-goals
- Per-slot top-N pruning (rejected as unsafe — reranker can re-rank low-PMI candidates into top-K).
- Vectorizing cross-slot scorers (ContrastiveConstraint minpair check). Falls back to Python; PHON-106 reworks contrastive scoring anyway.
- Productionization move to packages/generators/csp/ (PHON-109 scope).
- Numpy/JAX backend variants. Polars is the codebase's data layer.
Architecture
Two-part work, both committed under PHON-104:
Part A — solve() delegation
paradigm_3_csp.solve() is rewritten as a thin wrapper around solve_shape(). It constructs a SkeletonShape from the implicit nsubj,V,dobj[,advmod] arg structure, calls solve_shape, and repackages the return value into the legacy (top, stats) tuple. The public signature is unchanged.

```text
paradigm_3_csp.py
└── solve() → constructs SkeletonShape("nsubj,V,dobj" + optional advmod)
            → calls solve_shape() under the hood
            → wraps return into (top, stats) via _build_solve_stats()
            → preserves existing public signature
```
Part B — solve_shape vectorization

```text
skeleton_csp.py
└── solve_shape (existing API surface, rewritten internals)
    ├── _build_slot_filler_tables  (per-slot fillers + scores → list of polars frames)
    ├── _enumerate_vectorized      (cartesian via cross-join + score columns)
    └── _enumerate_python_fallback (extracted from current enumerate_assignments;
                                    used when ContrastiveConstraint present)
```
Routing decision: at the top of solve_shape, inspect cross_axes — if any scorer is registered (i.e., ContrastiveConstraint is in the request), take the Python fallback. Else, take the vectorized path. ccomp shapes recursively call solve_shape; the inner solve routes by its own constraint set independently.
solve_shape's public contract is unchanged — same args, same return shape (list[dict] with score_components, total_score, sentence, etc.). Only the internals swap.
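The routing rule above can be sketched as a single predicate. This is a minimal sketch, not the spike's code: it assumes `cross_axes` is a mapping from axis name to scorer (the stats code calls `cross_axes.keys()`, which suggests a dict-like object), and `CrossScorer`, `should_use_vectorized_sketch`, and `route` are illustrative names.

```python
from typing import Callable, Mapping

# Hypothetical type: a cross-slot scorer maps a full assignment to a score.
CrossScorer = Callable[[dict], float]


def should_use_vectorized_sketch(cross_axes: Mapping[str, CrossScorer]) -> bool:
    """Vectorize only when no cross-slot scorer is registered."""
    return len(cross_axes) == 0


def route(cross_axes: Mapping[str, CrossScorer]) -> str:
    """Name the path a solve with these cross-slot axes would take."""
    if should_use_vectorized_sketch(cross_axes):
        return "vectorized"
    return "python_fallback"


# An outer ccomp solve that carries a contrastive scorer, and an inner solve
# that does not: each call routes on its own constraint set.
outer_axes = {"contrastive": lambda assignment: 0.0}
inner_axes = {}
outer_route = route(outer_axes)
inner_route = route(inner_axes)
print(outer_route, inner_route)
```

Because the decision depends only on the arguments of the current call, a recursive inner solve needs no knowledge of the path its caller took.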
Vectorized enumeration

```python
import polars as pl


def _enumerate_vectorized(
    shape: SkeletonShape,
    slot_fillers: list[tuple[str, list[str], dict[str, float]]],
    word_axes: dict[str, dict[str, float]],
    weights: dict[str, float] | None,
    locked_slots: dict[str, str],
) -> pl.DataFrame:
    # Build per-slot frames first: one row per filler, plus its PMI column.
    slot_frames = {}
    for slot, fillers, scores in slot_fillers:
        if slot in locked_slots:
            # Locked slot: a single-row frame holding only the locked word.
            w = locked_slots[slot]
            slot_frames[slot] = pl.DataFrame({
                slot: [w],
                f"pmi_{slot}": [scores.get(w, 0.0)],
            })
        else:
            slot_frames[slot] = pl.DataFrame({
                slot: fillers,
                f"pmi_{slot}": [scores.get(f, 0.0) for f in fillers],
            })

    # Cartesian via successive cross joins
    cart = slot_frames[shape.slots[0]]
    for s in shape.slots[1:]:
        cart = cart.join(slot_frames[s], how="cross")

    # nsubj != dobj invariant (only when both present)
    if "nsubj" in shape.slots and "dobj" in shape.slots:
        cart = cart.filter(pl.col("nsubj") != pl.col("dobj"))

    # Per-word soft axes: sum contributions across content slots
    for axis_name, lookup in word_axes.items():
        contributions = [
            pl.col(content_slot).replace_strict(lookup, default=0.0).cast(pl.Float64)
            for content_slot in shape.content_slots
        ]
        cart = cart.with_columns(pl.sum_horizontal(contributions).alias(axis_name))

    # Adverb sentinel (constant; only when shape has advmod with no real PMI)
    if "advmod" in shape.slots:
        cart = cart.with_columns(pl.lit(0.001).alias("adv_sentinel"))

    # Total score = weighted sum of all score columns
    score_cols = [
        c for c in cart.columns
        if c.startswith("pmi_") or c in word_axes or c == "adv_sentinel"
    ]
    weighted = [
        pl.col(c) * (weights.get(c, 1.0) if weights else 1.0)
        for c in score_cols
    ]
    cart = cart.with_columns(pl.sum_horizontal(weighted).alias("total_score"))
    return cart
```
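Dedup + truncate (Polars expressions)

```python
content_keys = list(shape.content_slots)
deduped = (
    cart
    # Stable sort so tied scores keep a deterministic order.
    .sort("total_score", descending=True, maintain_order=True)
    # maintain_order=True keeps the sorted order; without it, unique() may
    # return rows in arbitrary order and head() would not select the top rows.
    .unique(subset=content_keys, keep="first", maintain_order=True)
    .head(top_k * over_fetch)
)
```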
Final step iterates the small deduped frame in Python to build the list[dict] return shape (sentence realization via realize(), ccomp resolution via _resolve_embedded_clause, etc.). Python loop runs over top_k * over_fetch rows (8–32 typically), not the full Cartesian.
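For reference, the intended semantics of the dedup-and-truncate step can be written in plain Python. This is a sketch of the behavior described above, not the spike's fallback code: `dedup_and_truncate` is an illustrative name, and treating `("nsubj", "dobj")` as the content keys is an assumption made for the example.

```python
def dedup_and_truncate(rows, content_keys, top_k, over_fetch):
    """Sort by total_score descending, keep the first row per content key,
    and stop once top_k * over_fetch rows have been kept."""
    limit = top_k * over_fetch
    # sorted() is stable, so rows with equal scores keep their input order.
    ranked = sorted(rows, key=lambda r: r["total_score"], reverse=True)
    seen = set()
    kept = []
    for row in ranked:
        if len(kept) >= limit:
            break
        key = tuple(row[k] for k in content_keys)
        if key in seen:
            continue  # a higher-scoring row with the same content already won
        seen.add(key)
        kept.append(row)
    return kept


# Hypothetical candidates: two share the same (nsubj, dobj) pair.
rows = [
    {"nsubj": "sun", "dobj": "ice", "advmod": "slowly", "total_score": 1.5},
    {"nsubj": "sun", "dobj": "ice", "advmod": "quickly", "total_score": 1.7},
    {"nsubj": "heat", "dobj": "snow", "advmod": "slowly", "total_score": 1.6},
    {"nsubj": "heat", "dobj": "wax", "advmod": "slowly", "total_score": 0.9},
]
result = dedup_and_truncate(rows, ["nsubj", "dobj"], top_k=1, over_fetch=2)
print([(r["nsubj"], r["dobj"], r["advmod"]) for r in result])
```

A reference like this is handy when checking equivalence: any vectorized variant should keep the same rows in the same order on the same input.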
solve() delegation

```python
def solve(
    verb: str,
    spec_id: str,
    spec_words: frozenset[str],
    sel_df: pl.DataFrame,
    *,
    constraints: list[Constraint] | None = None,
    word_df: pl.DataFrame | None = None,
    band: str = BAND,
    top_k: int = TOP_K,
    include_adverb: bool = True,
    weights: dict[str, float] | None = None,
) -> tuple[list[dict], dict]:
    """Solve nsubj-V-dobj (+ optional advmod) shape via solve_shape."""
    constraints = list(constraints or [])
    arg = "nsubj,V,dobj,advmod" if include_adverb else "nsubj,V,dobj"
    shape = SkeletonShape(arg, parse_arg_structure(arg), band_freq=0)
    filtered_spec, trace = _resolve_domain_words(spec_words, constraints, word_df)
    word_axes = get_per_word_axes(constraints, word_df)
    cross_axes = cross_slot_axes(constraints)
    # Capture pre-cartesian domain sizes for stats parity (otherwise
    # delegated stats would only see post-truncation candidates).
    domain_sizes = _peek_domain_sizes(verb, band, filtered_spec, sel_df, include_adverb)
    candidates = solve_shape(
        shape,
        verb=verb, domain_words=filtered_spec, sel_df=sel_df,
        band=band, word_axes=word_axes, cross_axes=cross_axes,
        word_df=word_df, weights=weights, top_k=top_k,
    )
    stats = _build_solve_stats(
        verb=verb, spec_id=spec_id, band=band,
        candidates=candidates, trace=trace,
        word_axes=word_axes, cross_axes=cross_axes,
        domain_sizes=domain_sizes,
    )
    return candidates, stats
```
_peek_domain_sizes(...) does a cheap PMI lookup per slot (already cached via _ADVMOD_PMI_CACHE for adv) and intersects with filtered_spec. Returns {nsubj: int, dobj: int, advmod: int}. Lets stats stay accurate at minimal cost.
_build_solve_stats(...) synthesizes the legacy stats dict shape:
```python
def _build_solve_stats(
    *, verb, spec_id, band, candidates, trace, word_axes, cross_axes, domain_sizes
) -> dict:
    return {
        "verb": verb,
        "spec_id": spec_id,
        "band": band,
        "nsubj_domain_size": domain_sizes.get("nsubj", 0),
        "dobj_domain_size": domain_sizes.get("dobj", 0),
        "adv_domain_size": domain_sizes.get("advmod", 0),
        "candidate_count": len(candidates),
        "unique_pairs": len({(c.get("nsubj"), c.get("dobj")) for c in candidates}),
        "domain_trace": trace,
        "active_axes": list(word_axes.keys()) + list(cross_axes.keys()),
    }
```
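The body of `_peek_domain_sizes` is not shown in this spec. The sketch below illustrates only the intersect-and-count idea it describes, under a simplified, hypothetical signature: `slot_candidates` stands in for the per-slot PMI lookups, and `peek_domain_sizes_sketch` is not the spike's function.

```python
def peek_domain_sizes_sketch(
    slot_candidates: dict,
    filtered_spec: frozenset,
    include_adverb: bool = True,
) -> dict:
    """Count, per slot, the candidates that survive intersection with the spec."""
    sizes = {}
    for slot, candidates in slot_candidates.items():
        if slot == "advmod" and not include_adverb:
            sizes[slot] = 0  # adverb slot not part of this shape
            continue
        sizes[slot] = len(set(candidates) & filtered_spec)
    return sizes


# Hypothetical PMI lookup results and filtered spec lexicon.
slot_candidates = {
    "nsubj": {"sun", "heat", "wizard"},
    "dobj": {"ice", "snow"},
    "advmod": {"slowly", "quickly"},
}
filtered_spec = frozenset({"sun", "heat", "ice", "slowly"})

sizes = peek_domain_sizes_sketch(slot_candidates, filtered_spec)
print(sizes)
```

The counts are taken before any Cartesian product is built, which is why they can differ from `candidate_count` in the stats dict.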
Edge cases & invariants
| Case | Behavior |
|---|---|
| Empty slot fillers | _slot_fillers returns [] for some slot → solve_shape returns []. Vectorized path: cross-join with empty frame → empty cart → empty result. |
| Locked slot whose locked value isn't in the PMI table | 1-row frame uses scores.get(w, 0.0) for that slot's PMI column — 0.0 contribution, candidate still appears. Same outcome as Python path. |
| ccomp recursive solve | _resolve_embedded_clause calls solve_shape recursively. Inner call routes vectorized vs fallback by its own constraint set (no contrastive in inner solve typically → vectorized). |
| Contrastive present + ccomp shape | Outer falls back to Python; inner is independent solve_shape — also routes by its own constraints. Both paths work. |
| nsubj == dobj filter | Vectorized: pl.col("nsubj") != pl.col("dobj") filter after cross-join, only when both slots present. Python fallback: existing inner if. Same outcome. |
| Adverb sentinel + real advmod PMI | _slot_fillers returns verb_pmi OR band_fallback, never both. adv_sentinel column only present when fallback fired. Invariant holds. |
| weights=None (default 1.0 per axis) | Collapses to constant 1.0. Same as Python path. |
| top_k * over_fetch larger than deduped pool | head(top_k * over_fetch) is bounded by frame height. Same outcome. |
Invariants the vectorized path must preserve (verified by tests):
- Top-K candidates by total_score are bit-identical to the Python path under the same inputs (modulo float ordering when scores tie, which is handled by stable sort + content-pair dedup in both paths).
- The score_components dict on each returned candidate has the same keys as the Python path (one entry per active axis).
- Locked slots produce candidates where the locked filler's word is at the locked position, exactly as Python does.
- Contrastive present → vectorized path is NOT taken; cross_axes scorer fires per-candidate as today.
Testing
New test_vectorized_enumeration.py next to other spike tests. Three buckets:
Equivalence (vectorized == Python)
Parameterized over the PHON-95 acceptance probe matrix plus a locked-slot probe and a PP shape:
```python
import pytest


@pytest.mark.parametrize("verb,spec_id,arg_structure", [
    ("cut", "spec1", "nsubj,V,dobj"),
    ("cut", "spec1", "nsubj,V,dobj,advmod"),
    ("chase", "spec1", "nsubj,V,dobj,advmod"),
    ("melt", "spec6", "nsubj,V,dobj,advmod"),
    ("eat", "spec1", "nsubj,V,dobj,advmod"),
    ("fill", "spec1", "nsubj,V,dobj,advmod"),
])
def test_vectorized_matches_python(store, sel_df, verb, spec_id, arg_structure):
    """Bit-identical top-K output between vectorized and python paths."""
    spec_words = paradigm_3_csp.spec_lexicon(store, spec_id)
    shape = SkeletonShape(arg_structure, parse_arg_structure(arg_structure), 0)
    vec_out = solve_shape(shape, verb=verb, domain_words=spec_words, ...)
    with _force_python_path():
        py_out = solve_shape(shape, verb=verb, domain_words=spec_words, ...)
    # Same sentences, in the same order, with identical scores.
    assert [c["sentence"] for c in vec_out] == [c["sentence"] for c in py_out]
    assert [c["total_score"] for c in vec_out] == [c["total_score"] for c in py_out]
    for v, p in zip(vec_out, py_out):
        assert v["score_components"] == p["score_components"]
```
_force_python_path() is a context manager that flips a module-level flag forcing the fallback path; exists only for tests.
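One way such a test-only switch could look is sketched below. The flag name `_FORCE_PYTHON` and the helpers `force_python_path_sketch` and `chosen_path` are assumptions for illustration; the spec states only that a module-level flag exists and that a context manager flips it.

```python
from contextlib import contextmanager

# Module-level test hook (hypothetical name); production code never sets it.
_FORCE_PYTHON = False


@contextmanager
def force_python_path_sketch():
    """Temporarily force the Python fallback, restoring the old value on exit."""
    global _FORCE_PYTHON
    previous = _FORCE_PYTHON
    _FORCE_PYTHON = True
    try:
        yield
    finally:
        # Restore even if the body raises, so one failing test cannot
        # leak the forced path into the next test.
        _FORCE_PYTHON = previous


def chosen_path(cross_axes: dict) -> str:
    """Routing check that honors the test flag before looking at cross_axes."""
    if _FORCE_PYTHON or cross_axes:
        return "python_fallback"
    return "vectorized"


before = chosen_path({})
with force_python_path_sketch():
    during = chosen_path({})
after = chosen_path({})
print(before, during, after)
```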
Routing

```python
def test_contrastive_takes_python_fallback():
    cross = cross_slot_axes([ContrastiveConstraint(pair_type="minpair", phoneme1="k", phoneme2="g")])
    assert _should_use_vectorized(cross_axes=cross) is False


def test_no_contrastive_takes_vectorized():
    cross = cross_slot_axes([IncludeConstraint(phonemes=("k",))])
    assert _should_use_vectorized(cross_axes=cross) is True
```
Stats parity (solve() delegation)

```python
def test_solve_delegation_stats_match(store, sel_df):
    spec_words = paradigm_3_csp.spec_lexicon(store, "spec1")
    top, stats = solve("cut", "spec1", spec_words, sel_df, word_df=store.df)
    expected_keys = {
        "verb", "spec_id", "band",
        "nsubj_domain_size", "dobj_domain_size", "adv_domain_size",
        "candidate_count", "unique_pairs", "domain_trace", "active_axes",
    }
    assert set(stats.keys()) == expected_keys
    assert stats["verb"] == "cut"
    assert stats["candidate_count"] == len(top)
    assert stats["unique_pairs"] >= 1
```
Perf bench
Extend bench_domain_cache.py (or add a sibling bench_enumeration.py) with two conditions:
- melt × spec6 probe, vectorized path
- melt × spec6 probe, forced Python path
Wall-clock per condition + speedup ratio. Numbers recorded back into this spec under "Empirical baseline" after first run.
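A minimal timing harness for the two conditions might look like the sketch below. The helper names (`best_of`, `speedup`) and the stand-in workloads are assumptions; a real bench would call solve_shape once per path instead of the fake functions.

```python
import time
from typing import Callable


def best_of(fn: Callable[[], object], repeats: int = 5) -> float:
    """Return the fastest wall-clock time (seconds) over `repeats` runs."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return min(timings)


def speedup(py_seconds: float, vec_seconds: float) -> float:
    """Ratio above 1.0 means the vectorized path is faster."""
    return py_seconds / vec_seconds


# Stand-in workloads (hypothetical); replace with the two solve_shape paths.
def fake_python_path():
    return sum(i * i for i in range(200_000))


def fake_vectorized_path():
    return sum(range(1_000))


py_s = best_of(fake_python_path)
vec_s = best_of(fake_vectorized_path)
print(f"py={py_s:.4f}s vec={vec_s:.4f}s speedup={speedup(py_s, vec_s):.2f}x")
```

Taking the minimum over several runs is one common way to reduce noise from warm-up and scheduling; the spec does not say which aggregation the actual bench uses.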
Open questions
None.
References
- Spike code: packages/generation/research/2026-05-07-sentence-generation-paradigms/
- PHON-103 caches sit upstream: the vectorized path consumes get_filtered_spec / get_per_word_axes outputs directly.
- Productionization: PHON-109 (blocked on PHON-103–108).
- Existing wall-clock baseline: notebook.md line 20, "9,120 candidates × 30 adverbs ran in 2.19s".
Empirical baseline (recorded 2026-05-08)
From bench_enumeration.py:
| Probe | Vec (s) | Py (s) | Speedup |
|---|---|---|---|
| melt × spec6 | 0.007 | 0.007 | 0.91x |
| cut × spec1 | 0.010 | 0.034 | 3.55x |
| chase × spec1 | 0.007 | 0.006 | 0.89x |
| eat × spec1 | 0.009 | 0.032 | 3.53x |
| fill × spec1 | 0.009 | 0.021 | 2.45x |
| Total | 0.042 | 0.100 | 2.40x |
Note on the speedup vs the design's 10× projection. The 10×+ projection in the Goal section was calibrated against the pre-PHON-103 baseline of 2.19s on melt × spec6 (notebook.md line 20). PHON-103's domain caches (get_filtered_spec, get_spec_lexicon, get_per_word_axes) landed upstream of this bench and absorbed nearly all of that 2.19s: most of the original cost was Polars filter + scan + I/O on the 125K-row WordStore, not the Python-loop overhead PHON-104 was designed to eliminate.

After PHON-103, melt × spec6 runs in ~7ms on either path; the Cartesian Python-iteration cost PHON-104 targets is now small relative to the constant overheads of a solve_shape call. The 2.4× total speedup is real (driven by cut × spec1 and eat × spec1 at ~3.5×) but represents the post-cache enumeration work, not the pre-cache full-pipeline work.

The consolidation refactor (solve() migrating to delegate to solve_shape, eliminating ~70 lines of duplicate enumeration logic) remains a meaningful structural win independent of the wall-clock numbers.
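One detail for anyone re-deriving the table: the Speedup column does not follow exactly from the rounded timing columns. The ratios were presumably computed from unrounded timings; that is an assumption, since the bench output itself is not reproduced here. The sketch below recomputes the ratios from the rounded values only.

```python
# (vec_seconds, py_seconds, reported_speedup) copied from the table above.
rows = {
    "melt x spec6": (0.007, 0.007, 0.91),
    "cut x spec1": (0.010, 0.034, 3.55),
    "chase x spec1": (0.007, 0.006, 0.89),
    "eat x spec1": (0.009, 0.032, 3.53),
    "fill x spec1": (0.009, 0.021, 2.45),
}

vec_total = sum(vec for vec, _, _ in rows.values())
py_total = sum(py for _, py, _ in rows.values())
recomputed_total = round(py_total / vec_total, 2)

for name, (vec, py, reported) in rows.items():
    # Ratios from 3-decimal timings are coarse at the millisecond scale.
    print(f"{name}: recomputed {py / vec:.2f}x, reported {reported:.2f}x")
print(f"total: recomputed {recomputed_total:.2f}x, reported 2.40x")
```

At these magnitudes a rounding step of 0.001s is a large fraction of each measurement, so small mismatches between the timing columns and the ratio column are expected.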