
PHON-104 — CSP perf: vectorize enumeration via Polars

Date: 2026-05-08
Branch: feature/csp-iteration
Scope: spike-internal perf optimization (productionization to packages/generators/csp/ is PHON-109's scope)

Frame

The CSP spike's hot path enumerates the Cartesian product of per-slot fillers in solve_shape (skeleton_csp.py), and solve() (paradigm_3_csp.py) runs a structurally similar nested loop. The largest acceptance probe (melt × spec6) produces 9,120 (nsubj × dobj) pairs × 30 adverbs = 273,600 Cartesian points and runs in 2.19s. The dominant cost is Python iteration overhead: generator-based recursion in enumerate_assignments, one dict constructed per assignment, and function-call dispatch.
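To make the per-point overhead concrete, here is a minimal sketch of generator-based recursive enumeration in the style described above. enumerate_assignments_sketch is hypothetical, not the spike's enumerate_assignments; it only illustrates the costs paid once per Cartesian point (a chain of nested generator frames and one fresh dict per assignment). At melt × spec6 scale that is 9,120 × 30 = 273,600 yields.

```python
from typing import Dict, Iterator, List, Optional


def enumerate_assignments_sketch(
    slots: List[str],
    fillers: Dict[str, List[str]],
    partial: Optional[Dict[str, str]] = None,
) -> Iterator[Dict[str, str]]:
    """Hypothetical sketch: yield one dict per point of the per-slot Cartesian product."""
    partial = {} if partial is None else partial
    if len(partial) == len(slots):
        # A fresh dict per assignment: one of the per-point costs the ticket targets.
        yield dict(partial)
        return
    slot = slots[len(partial)]
    for word in fillers[slot]:
        partial[slot] = word
        # One more generator frame per slot level, resumed for every yielded point.
        yield from enumerate_assignments_sketch(slots, fillers, partial)
    # Undo this level's binding; pop() tolerates an empty filler list.
    partial.pop(slot, None)
```

With an empty filler list for any slot, nothing is yielded, which matches the "empty slot fillers" edge case later in this spec.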

The original ticket title — "per-slot top-N pruning" — is unsafe: dropping low-PMI candidates before the reranker sees them risks losing candidates that the teacher-distilled quality reranker (Step 15, Spearman 0.633 on held-out (verb, band) groups) would otherwise rank into top-K. The reranker uses MiniLM features on full sentences and cannot be evaluated per-filler.

This ticket reframes PHON-104 as "vectorize the enumeration in Polars". No candidates are dropped: the same full Cartesian product is produced and scored, only faster (the Goal below targets 10×+).

Goal

For the largest acceptance probe (melt × spec6), expect a 10×+ wall-clock speedup. Top-K output must be bit-identical between the vectorized and Python paths under the same inputs.

Non-goals

  • Per-slot top-N pruning (rejected as unsafe — reranker can re-rank low-PMI candidates into top-K).
  • Vectorizing cross-slot scorers (ContrastiveConstraint minpair check). Falls back to Python; PHON-106 reworks contrastive scoring anyway.
  • Productionization move to packages/generators/csp/ (PHON-109 scope).
  • NumPy/JAX backend variants. Polars is the codebase's data layer.

Architecture

Two-part work, both committed under PHON-104:

Part A — solve() delegation

paradigm_3_csp.solve() is rewritten as a thin wrapper around solve_shape(). It constructs a SkeletonShape from the implicit nsubj,V,dobj[,advmod] argument structure, calls solve_shape(), and repackages the result into the legacy (top, stats) tuple. The public signature is unchanged.

paradigm_3_csp.py
└── solve()  → constructs SkeletonShape("nsubj,V,dobj" + optional advmod)
              → calls solve_shape() under the hood
              → wraps return into (top, stats) via _build_solve_stats()
              → preserves existing public signature
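A minimal sketch of the shape construction step, assuming a shape is just the comma-separated slot list plus a band frequency. ShapeSketch and shape_for_solve are hypothetical stand-ins for SkeletonShape and parse_arg_structure, and the rule "every slot except V is a content slot" is an assumption, not confirmed spike behavior.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ShapeSketch:
    """Hypothetical stand-in for SkeletonShape."""

    arg_structure: str
    slots: Tuple[str, ...]
    band_freq: int = 0

    @property
    def content_slots(self) -> Tuple[str, ...]:
        # Assumption: every slot except the verb position "V" carries a filler word.
        return tuple(s for s in self.slots if s != "V")


def shape_for_solve(include_adverb: bool) -> ShapeSketch:
    """Build the implicit nsubj,V,dobj[,advmod] shape that solve() delegates with."""
    arg = "nsubj,V,dobj,advmod" if include_adverb else "nsubj,V,dobj"
    return ShapeSketch(arg, tuple(arg.split(",")), band_freq=0)
```

Keeping the shape immutable (frozen=True) means the same shape object can be passed into nested solve_shape calls without risk of one call mutating another's slot list.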

Part B — solve_shape vectorization

skeleton_csp.py
└── solve_shape (existing API surface, rewritten internals)
    ├── _build_slot_filler_tables  (per-slot fillers + scores → list of polars frames)
    ├── _enumerate_vectorized      (cartesian via cross-join + score columns)
    └── _enumerate_python_fallback (extracted from current enumerate_assignments;
                                    used when ContrastiveConstraint present)

Routing decision: at the top of solve_shape, inspect cross_axes — if any scorer is registered (i.e., ContrastiveConstraint is in the request), take the Python fallback. Else, take the vectorized path. ccomp shapes recursively call solve_shape; the inner solve routes by its own constraint set independently.
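The routing rule, including the independence of nested ccomp solves, can be sketched as below. route and solve_routes_sketch are hypothetical; cross_axes is modeled as a plain mapping from axis name to scorer, which is an assumption about its type.

```python
from typing import Any, List, Mapping, Optional


def route(cross_axes: Mapping[str, Any]) -> str:
    """Pick the enumeration path for one solve_shape call."""
    # Any registered cross-slot scorer (e.g. a contrastive minpair check) forces
    # the row-at-a-time Python fallback; otherwise the vectorized path is safe.
    return "python" if cross_axes else "vectorized"


def solve_routes_sketch(
    cross_axes: Mapping[str, Any], embedded: Optional[dict] = None
) -> List[str]:
    """Return the path chosen by the outer solve and by each nested ccomp solve."""
    routes = [route(cross_axes)]
    if embedded is not None:
        # The embedded clause is a separate call that routes on its own
        # constraint set, independently of the outer call's decision.
        routes += solve_routes_sketch(**embedded)
    return routes
```

For example, an outer solve with a contrastive scorer and an unconstrained embedded clause yields ["python", "vectorized"].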

solve_shape's public contract is unchanged — same args, same return shape (list[dict] with score_components, total_score, sentence, etc.). Only the internals swap.

Vectorized enumeration

import polars as pl


def _enumerate_vectorized(
    shape: SkeletonShape,
    slot_fillers: list[tuple[str, list[str], dict[str, float]]],
    word_axes: dict[str, dict[str, float]],
    weights: dict[str, float] | None,
    locked_slots: dict[str, str],
    # Hypothetical flag: True when _slot_fillers used the band fallback for advmod.
    adv_fallback: bool = False,
) -> pl.DataFrame:
    # Build one frame per slot. A locked slot becomes a 1-row frame. The explicit
    # schema keeps dtypes stable when a slot has no fillers (empty frame, so the
    # cross join below yields an empty cart).
    slot_frames = {}
    for slot, fillers, scores in slot_fillers:
        words = [locked_slots[slot]] if slot in locked_slots else list(fillers)
        slot_frames[slot] = pl.DataFrame(
            {slot: words, f"pmi_{slot}": [scores.get(w, 0.0) for w in words]},
            schema={slot: pl.String, f"pmi_{slot}": pl.Float64},
        )

    # Cartesian product via successive cross joins, in shape order. Slots with
    # no filler table (e.g. the verb position "V", if shape.slots lists it) are skipped.
    ordered = [s for s in shape.slots if s in slot_frames]
    cart = slot_frames[ordered[0]]
    for s in ordered[1:]:
        cart = cart.join(slot_frames[s], how="cross")

    # nsubj != dobj invariant (only when both slots are present)
    if "nsubj" in ordered and "dobj" in ordered:
        cart = cart.filter(pl.col("nsubj") != pl.col("dobj"))

    # Per-word soft axes: sum each axis's contribution across content slots;
    # words missing from the lookup contribute 0.0.
    for axis_name, lookup in word_axes.items():
        contributions = [
            pl.col(content_slot).replace_strict(lookup, default=0.0).cast(pl.Float64)
            for content_slot in shape.content_slots
        ]
        cart = cart.with_columns(pl.sum_horizontal(contributions).alias(axis_name))

    # Adverb sentinel: a constant column, added only when advmod fell back to the
    # band list (no real verb-advmod PMI), matching the edge-case invariant.
    if "advmod" in ordered and adv_fallback:
        cart = cart.with_columns(pl.lit(0.001).alias("adv_sentinel"))

    # Total score = weighted sum of all score columns (missing weight -> 1.0)
    score_cols = [
        c for c in cart.columns
        if c.startswith("pmi_") or c in word_axes or c == "adv_sentinel"
    ]
    weighted = [
        pl.col(c) * (weights.get(c, 1.0) if weights else 1.0)
        for c in score_cols
    ]
    cart = cart.with_columns(pl.sum_horizontal(weighted).alias("total_score"))
    return cart
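For debugging the vectorized frame, a row-at-a-time oracle that computes the same columns is handy. The sketch below is hypothetical (reference_enumerate and its parameters are not spike code) and assumes that slots lists content slots only and that each slot's fillers are the keys of its PMI dict in slot_scores.

```python
from itertools import product
from typing import Dict, List, Optional


def reference_enumerate(
    slots: List[str],
    slot_scores: Dict[str, Dict[str, float]],
    word_axes: Dict[str, Dict[str, float]],
    weights: Optional[Dict[str, float]] = None,
    locked: Optional[Dict[str, str]] = None,
    adv_fallback: bool = False,
) -> List[dict]:
    """Hypothetical row-at-a-time equivalent of the vectorized cart."""
    locked = locked or {}
    weights = weights or {}
    # A locked slot contributes exactly one filler, like the 1-row frame.
    domains = [[locked[s]] if s in locked else list(slot_scores[s]) for s in slots]
    rows = []
    for combo in product(*domains):
        row = dict(zip(slots, combo))
        if "nsubj" in row and "dobj" in row and row["nsubj"] == row["dobj"]:
            continue  # same filter as pl.col("nsubj") != pl.col("dobj")
        # Components in the frame's column order: pmi_* per slot, axes, sentinel.
        comps = {f"pmi_{s}": slot_scores[s].get(row[s], 0.0) for s in slots}
        for axis, lookup in word_axes.items():
            comps[axis] = sum(lookup.get(row[s], 0.0) for s in slots)
        if adv_fallback and "advmod" in slots:
            comps["adv_sentinel"] = 0.001
        row["score_components"] = comps
        row["total_score"] = sum(v * weights.get(k, 1.0) for k, v in comps.items())
        rows.append(row)
    return rows
```

Because it builds one dict per point, this oracle has exactly the overhead the ticket removes; it is only meant for small fixtures.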

Dedup + truncate (Polars expressions)

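content_keys = list(shape.content_slots)
deduped = (
    cart
    # maintain_order=True makes the sort stable, so tied scores keep enumeration order
    .sort("total_score", descending=True, maintain_order=True)
    # unique() does not preserve row order by default; maintain_order=True keeps the
    # sorted order so keep="first" retains the best row per content key and head()
    # really takes the top-scoring rows
    .unique(subset=content_keys, keep="first", maintain_order=True)
    .head(top_k * over_fetch)
)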

Final step iterates the small deduped frame in Python to build the list[dict] return shape (sentence realization via realize(), ccomp resolution via _resolve_embedded_clause, etc.). Python loop runs over top_k * over_fetch rows (8–32 typically), not the full Cartesian.
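The dedup-and-truncate semantics above (stable descending sort, keep the first row per content key, then cap at top_k * over_fetch) can be sketched in pure Python, which is also roughly what the Python fallback needs to match. dedup_truncate_sketch is hypothetical; it relies on Python's sorted() staying stable with reverse=True.

```python
from typing import List, Sequence


def dedup_truncate_sketch(
    rows: List[dict], content_keys: Sequence[str], top_k: int, over_fetch: int
) -> List[dict]:
    """Hypothetical pure-Python equivalent of sort -> unique(keep="first") -> head."""
    # sorted() is stable even with reverse=True, so ties keep their input order.
    ordered = sorted(rows, key=lambda r: r["total_score"], reverse=True)
    limit = top_k * over_fetch
    seen = set()
    out = []
    for r in ordered:
        if len(out) >= limit:
            break
        key = tuple(r[k] for k in content_keys)
        if key in seen:
            continue  # a higher-scoring row with the same content words already won
        seen.add(key)
        out.append(r)
    return out
```

If the pool is smaller than the cap, the loop simply runs out of rows, matching the "head() is bounded by frame height" edge case.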

solve() delegation

def solve(
    verb: str,
    spec_id: str,
    spec_words: frozenset[str],
    sel_df: pl.DataFrame,
    *,
    constraints: list[Constraint] | None = None,
    word_df: pl.DataFrame | None = None,
    band: str = BAND,
    top_k: int = TOP_K,
    include_adverb: bool = True,
    weights: dict[str, float] | None = None,
) -> tuple[list[dict], dict]:
    """Solve nsubj-V-dobj (+ optional advmod) shape via solve_shape."""
    constraints = list(constraints or [])
    arg = "nsubj,V,dobj,advmod" if include_adverb else "nsubj,V,dobj"
    shape = SkeletonShape(arg, parse_arg_structure(arg), band_freq=0)

    filtered_spec, trace = _resolve_domain_words(spec_words, constraints, word_df)
    word_axes = get_per_word_axes(constraints, word_df)
    cross_axes = cross_slot_axes(constraints)

    # Capture pre-cartesian domain sizes for stats parity (otherwise
    # delegated stats would only see post-truncation candidates).
    domain_sizes = _peek_domain_sizes(verb, band, filtered_spec, sel_df, include_adverb)

    candidates = solve_shape(
        shape,
        verb=verb, domain_words=filtered_spec, sel_df=sel_df,
        band=band, word_axes=word_axes, cross_axes=cross_axes,
        word_df=word_df, weights=weights, top_k=top_k,
    )

    stats = _build_solve_stats(
        verb=verb, spec_id=spec_id, band=band,
        candidates=candidates, trace=trace,
        word_axes=word_axes, cross_axes=cross_axes,
        domain_sizes=domain_sizes,
    )
    return candidates, stats

_peek_domain_sizes(...) does a cheap PMI lookup per slot (already cached via _ADVMOD_PMI_CACHE for adv) and intersects with filtered_spec. Returns {nsubj: int, dobj: int, advmod: int}. Lets stats stay accurate at minimal cost.
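A minimal sketch of the peek, assuming the per-slot PMI tables are available as dicts keyed by filler word and that every slot (advmod included) is intersected with filtered_spec; the spec says it intersects with filtered_spec but does not confirm the advmod case. peek_domain_sizes_sketch is hypothetical.

```python
from typing import Dict, FrozenSet, Mapping


def peek_domain_sizes_sketch(
    pmi_by_slot: Mapping[str, Mapping[str, float]],
    filtered_spec: FrozenSet[str],
) -> Dict[str, int]:
    """Count how many PMI-table words per slot survive the spec filter."""
    # Only set sizes are computed here; no Cartesian product is built.
    return {slot: len(filtered_spec.intersection(table)) for slot, table in pmi_by_slot.items()}
```

The cost is linear in the PMI table sizes, not in their product, which is why it can run before enumeration without affecting the timings.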

_build_solve_stats(...) synthesizes the legacy stats dict shape:

def _build_solve_stats(*, verb, spec_id, band, candidates, trace, word_axes, cross_axes, domain_sizes) -> dict:
    return {
        "verb": verb,
        "spec_id": spec_id,
        "band": band,
        "nsubj_domain_size": domain_sizes.get("nsubj", 0),
        "dobj_domain_size":  domain_sizes.get("dobj", 0),
        "adv_domain_size":   domain_sizes.get("advmod", 0),
        "candidate_count":   len(candidates),
        "unique_pairs":      len({(c.get("nsubj"), c.get("dobj")) for c in candidates}),
        "domain_trace":      trace,
        "active_axes":       list(word_axes.keys()) + list(cross_axes.keys()),
    }

Edge cases & invariants

| Case | Behavior |
|---|---|
| Empty slot fillers | _slot_fillers returns [] for some slot → solve_shape returns []. Vectorized path: cross join with an empty frame → empty cart → empty result. |
| Locked slot whose locked value isn't in the PMI table | The 1-row frame uses scores.get(w, 0.0) for that slot's PMI column: a 0.0 contribution, and the candidate still appears. Same outcome as the Python path. |
| ccomp recursive solve | _resolve_embedded_clause calls solve_shape recursively. The inner call routes vectorized vs. fallback by its own constraint set (typically no contrastive in the inner solve → vectorized). |
| Contrastive present + ccomp shape | The outer call falls back to Python; the inner call is an independent solve_shape and also routes by its own constraints. Both paths work. |
| nsubj == dobj filter | Vectorized: pl.col("nsubj") != pl.col("dobj") filter after the cross join, only when both slots are present. Python fallback: the existing inner if. Same outcome. |
| Adverb sentinel + real advmod PMI | _slot_fillers returns verb_pmi OR band_fallback, never both. The adv_sentinel column is only present when the fallback fired. Invariant holds. |
| weights=None (default 1.0 per axis) | Every weight collapses to the constant 1.0. Same as the Python path. |
| top_k * over_fetch larger than deduped pool | head(top_k * over_fetch) is bounded by frame height. Same outcome. |
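The "verb_pmi OR band_fallback, never both" row can be sketched as a filler source that also reports whether the fallback fired, so the sentinel column is added only in that case. advmod_fillers_sketch and score_columns_for_advmod are hypothetical, and the trigger condition (fallback when the verb has no advmod PMI rows) is an assumption.

```python
from typing import Dict, List, Tuple


def advmod_fillers_sketch(
    verb_pmi: Dict[str, float], band_fallback: List[str]
) -> Tuple[List[str], Dict[str, float], bool]:
    """Return (fillers, pmi_scores, used_fallback) for the advmod slot."""
    if verb_pmi:
        # Real verb-advmod PMI exists: use it, and no sentinel is needed.
        return list(verb_pmi), dict(verb_pmi), False
    # Assumed trigger: no PMI rows for this verb, so take the band list with
    # empty scores and signal that adv_sentinel should be added.
    return list(band_fallback), {}, True


def score_columns_for_advmod(used_fallback: bool) -> List[str]:
    """Score columns the advmod slot contributes under each branch."""
    cols = ["pmi_advmod"]
    if used_fallback:
        cols.append("adv_sentinel")
    return cols
```

Returning the flag alongside the fillers keeps the "never both" decision in one place, so both enumeration paths read the same signal.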

Invariants the vectorized path must preserve (verified by tests):

  1. Top-K candidates by total_score are bit-identical to the Python path under the same inputs. Ordering among tied scores is deterministic in both paths (stable sort + content-pair dedup).
  2. score_components dict on each returned candidate has the same keys as the Python path (one entry per active axis).
  3. Locked slots produce candidates where the locked filler's word is at the locked position, exactly as Python does.
  4. Contrastive present → vectorized path is NOT taken; cross_axes scorer fires per-candidate as today.
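Invariants 1 to 3 can be checked with a small helper that reports every violation instead of stopping at the first. invariant_violations is hypothetical and assumes each candidate dict carries its slot words under slot-name keys (as the c.get("nsubj") calls in _build_solve_stats suggest).

```python
from typing import Dict, List, Optional


def invariant_violations(
    vec: List[dict], py: List[dict], locked: Optional[Dict[str, str]] = None
) -> List[str]:
    """Hypothetical checker for invariants 1-3; an empty list means all hold."""
    locked = locked or {}
    problems = []
    # Invariant 1: same ranked sentences and bit-identical total scores.
    if [c["sentence"] for c in vec] != [c["sentence"] for c in py]:
        problems.append("sentence order differs")
    if [c["total_score"] for c in vec] != [c["total_score"] for c in py]:
        problems.append("total_score differs")
    # Invariant 2: same score_components keys per rank.
    for i, (v, p) in enumerate(zip(vec, py)):
        if set(v["score_components"]) != set(p["score_components"]):
            problems.append(f"score_components keys differ at rank {i}")
    # Invariant 3: locked words sit at their locked positions.
    for i, c in enumerate(vec):
        for slot, word in locked.items():
            if c.get(slot) != word:
                problems.append(f"locked slot {slot!r} not honored at rank {i}")
    return problems
```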

Testing

New test_vectorized_enumeration.py next to other spike tests. Three buckets:

Equivalence (vectorized == Python)

Parameterized over the PHON-95 acceptance probe matrix plus a locked-slot probe and a PP shape:

@pytest.mark.parametrize("verb,spec_id,arg_structure", [
    ("cut",   "spec1", "nsubj,V,dobj"),
    ("cut",   "spec1", "nsubj,V,dobj,advmod"),
    ("chase", "spec1", "nsubj,V,dobj,advmod"),
    ("melt",  "spec6", "nsubj,V,dobj,advmod"),
    ("eat",   "spec1", "nsubj,V,dobj,advmod"),
    ("fill",  "spec1", "nsubj,V,dobj,advmod"),
])
def test_vectorized_matches_python(store, sel_df, verb, spec_id, arg_structure):
    """Bit-identical top-K output between vectorized and python paths."""
    spec_words = paradigm_3_csp.spec_lexicon(store, spec_id)
    shape = SkeletonShape(arg_structure, parse_arg_structure(arg_structure), 0)

    vec_out = solve_shape(shape, verb=verb, domain_words=spec_words, ...)
    with _force_python_path():
        py_out = solve_shape(shape, verb=verb, domain_words=spec_words, ...)

    assert [c["sentence"] for c in vec_out] == [c["sentence"] for c in py_out]
    assert [c["total_score"] for c in vec_out] == [c["total_score"] for c in py_out]
    for v, p in zip(vec_out, py_out):
        assert v["score_components"] == p["score_components"]

_force_python_path() is a context manager that flips a module-level flag forcing the fallback path; exists only for tests.
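A minimal sketch of such a context manager, assuming the flag is a module-level boolean read by the routing check. _FORCE_PYTHON_PATH, force_python_path_sketch, and should_use_vectorized_sketch are hypothetical names; the try/finally restores the previous value even if the test body raises.

```python
from contextlib import contextmanager
from typing import Any, Iterator, Mapping

_FORCE_PYTHON_PATH = False  # hypothetical module-level flag, test-only


@contextmanager
def force_python_path_sketch() -> Iterator[None]:
    """Temporarily force the Python fallback path."""
    global _FORCE_PYTHON_PATH
    previous = _FORCE_PYTHON_PATH
    _FORCE_PYTHON_PATH = True
    try:
        yield
    finally:
        # Restore rather than reset to False, so nested uses behave correctly.
        _FORCE_PYTHON_PATH = previous


def should_use_vectorized_sketch(cross_axes: Mapping[str, Any]) -> bool:
    """Vectorize only when not forced off and no cross-slot scorer is registered."""
    return not _FORCE_PYTHON_PATH and not cross_axes
```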

Routing

def test_contrastive_takes_python_fallback():
    cross = cross_slot_axes([ContrastiveConstraint(pair_type="minpair", phoneme1="k", phoneme2="g")])
    assert _should_use_vectorized(cross_axes=cross) is False

def test_no_contrastive_takes_vectorized():
    cross = cross_slot_axes([IncludeConstraint(phonemes=("k",))])
    assert _should_use_vectorized(cross_axes=cross) is True

Stats parity (solve() delegation)

def test_solve_delegation_stats_match(store, sel_df):
    spec_words = paradigm_3_csp.spec_lexicon(store, "spec1")
    top, stats = solve("cut", "spec1", spec_words, sel_df, word_df=store.df)
    expected_keys = {
        "verb", "spec_id", "band",
        "nsubj_domain_size", "dobj_domain_size", "adv_domain_size",
        "candidate_count", "unique_pairs", "domain_trace", "active_axes",
    }
    assert set(stats.keys()) == expected_keys
    assert stats["verb"] == "cut"
    assert stats["candidate_count"] == len(top)
    assert stats["unique_pairs"] >= 1

Perf bench

Extend bench_domain_cache.py (or add a sibling bench_enumeration.py) with two conditions:

  • melt × spec6 probe, vectorized path
  • melt × spec6 probe, forced Python path

Wall-clock per condition + speedup ratio. Numbers recorded back into this spec under "Empirical baseline" after first run.
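A minimal timing harness for the two conditions might look like the sketch below. time_call and bench_pair are hypothetical; the callables stand in for the vectorized and forced-Python solve_shape calls. It reports the median of several runs after a warm-up call, so upstream caches (such as PHON-103's) are hot before timing starts.

```python
import time
from statistics import median
from typing import Callable, Dict


def time_call(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Median wall-clock seconds of fn() over `repeats` timed runs."""
    for _ in range(warmup):
        fn()  # populate caches so the timed runs measure steady state
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return median(samples)


def bench_pair(
    vec_fn: Callable[[], object], py_fn: Callable[[], object], repeats: int = 5
) -> Dict[str, float]:
    """Time both paths and report speedup as Python time / vectorized time."""
    vec_s = time_call(vec_fn, repeats)
    py_s = time_call(py_fn, repeats)
    speedup = py_s / vec_s if vec_s > 0 else float("inf")
    return {"vec_s": vec_s, "py_s": py_s, "speedup": speedup}
```

The median is less sensitive than the mean to one-off GC pauses, which matters when individual runs take only a few milliseconds.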

Open questions

None.

References

  • Spike code: packages/generation/research/2026-05-07-sentence-generation-paradigms/
  • PHON-103 caches sit upstream — vectorized path consumes get_filtered_spec / get_per_word_axes outputs directly.
  • Productionization: PHON-109 (blocked on PHON-103–108).
  • Existing wall-clock baseline: notebook.md line 20 — "9,120 candidates × 30 adverbs ran in 2.19s".

Empirical baseline (recorded 2026-05-08)

From bench_enumeration.py:

| Probe | Vec (s) | Py (s) | Speedup (Py / Vec) |
|---|---|---|---|
| melt × spec6 | 0.007 | 0.007 | 0.91x |
| cut × spec1 | 0.010 | 0.034 | 3.55x |
| chase × spec1 | 0.007 | 0.006 | 0.89x |
| eat × spec1 | 0.009 | 0.032 | 3.53x |
| fill × spec1 | 0.009 | 0.021 | 2.45x |
| Total | 0.042 | 0.100 | 2.40x |

Note on the speedup vs. the design's 10× projection. The 10×+ projection in the Goal section was calibrated against the pre-PHON-103 baseline of 2.19s on melt × spec6 (notebook.md line 20). PHON-103's domain caches (get_filtered_spec, get_spec_lexicon, get_per_word_axes) landed upstream of this bench and absorbed nearly all of that 2.19s: most of the original cost was Polars filter + scan + I/O on the 125K-row WordStore, not the Python-loop overhead PHON-104 was designed to eliminate.

After PHON-103, melt × spec6 runs in ~7ms on either path, so the Cartesian Python-iteration cost PHON-104 targets is now small relative to the fixed overhead of a solve_shape call. The 2.4× total speedup is real (driven by cut × spec1 and eat × spec1 at ~3.5×), but it measures the post-cache enumeration work, not the pre-cache full-pipeline work the 10× projection assumed.

The consolidation refactor (solve() delegating to solve_shape, eliminating ~70 lines of duplicate enumeration logic) remains a meaningful structural win independent of the wall-clock numbers.