
Unified Constrained Generation Implementation Plan

For agentic workers: REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (- [ ]) syntax for tracking.

Goal: Replace the token-level governor architecture with a word-list-based BAN/BOOST pipeline, adding bigram lookahead and targeted rollout escalation.

Architecture: Everything is a word list. Constraints resolve via PhonoLex Workers API to word lists, then apply as BAN (bad_words_ids) or BOOST (Reranker coverage modulation). Four escalation layers: BAN/BOOST → GUARD → bigram dead-end filter → targeted rollout. Best-of-N wraps everything.
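
The BAN path ends in `bad_words_ids`, which transformers' `generate()` accepts as a list of token-ID sequences. Below is a minimal sketch of that conversion. The `encode` callable is a hypothetical stand-in for something like `tokenizer.encode(text, add_special_tokens=False)`. Generating four surface variants per word (bare, capitalized, each with and without a leading space) is an assumption, not something the spec prescribes.

```python
from typing import Callable, Iterable


def words_to_bad_words_ids(
    words: Iterable[str],
    encode: Callable[[str], list[int]],
) -> list[list[int]]:
    """Hypothetical helper: encode each surface variant of each banned word."""
    seen: set[tuple[int, ...]] = set()
    bad_words_ids: list[list[int]] = []
    for word in words:
        # Subword tokenizers encode "car", "Car", " car", " Car" differently;
        # dict.fromkeys drops duplicate variants but keeps a stable order.
        variants = dict.fromkeys(
            (word, word.capitalize(), " " + word, " " + word.capitalize())
        )
        for variant in variants:
            ids = tuple(encode(variant))
            if ids and ids not in seen:  # skip empty and repeated sequences
                seen.add(ids)
                bad_words_ids.append(list(ids))
    return bad_words_ids
```

With most subword tokenizers, banning a word that is a single token also blocks longer words that start with that token, for example "car" as the first piece of "card". The real word lists should therefore be checked against the tokenizer's actual splits.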

Tech Stack: Python (FastAPI, PyTorch, transformers), TypeScript (Hono/D1, React/MUI), SSE streaming

Spec: docs/superpowers/specs/2026-04-15-unified-constrained-generation-design.md


File Map

Workers API (new endpoint)

  • Modify: packages/web/workers/src/routes/words.ts — add POST /word-list
  • Test: packages/web/workers/tests/word-list.test.ts

Generation Server (major rewrite)

  • Modify: packages/generation/server/schemas.py — collapse 11 → 5 constraint types + ResolvedConstraint
  • Rewrite: packages/generation/server/governor.py — resolve_constraints() + prepare_generation()
  • Rewrite: packages/generation/server/model.py — four-layer generate_with_checking()
  • Modify: packages/generation/server/routes/generate.py — SSE streaming for generate-single
  • Modify: packages/generation/server/word_norms.py — lazy loading, remove startup precompute
  • Modify: packages/generation/server/main.py — remove GovernorCache init, remove lookup loading
  • Delete: packages/generation/server/routes/sessions.py — session route uses old governor
  • Delete: packages/generation/server/sessions.py — session store
  • Delete: packages/generation/server/profiles.py — profile store
  • Delete: packages/generation/server/routes/profiles.py — profile routes
  • Rewrite: packages/generation/server/tests/test_governor.py — test resolve + prepare
  • Modify: packages/generation/server/tests/test_schemas.py — test new 5-type schema
  • Modify: packages/generation/server/tests/conftest.py — remove lookup fixture dependency

Governor Package (strip to A-team)

  • Rewrite: packages/governors/src/phonolex_governors/__init__.py — export only A-team
  • Rewrite: packages/governors/src/phonolex_governors/generation/reranker.py — set membership + coverage
  • Create: packages/governors/src/phonolex_governors/generation/lookahead.py — bigram filter + rollout
  • Keep: packages/governors/src/phonolex_governors/checking/checker.py (unchanged)
  • Keep: packages/governors/src/phonolex_governors/checking/g2p.py (unchanged)
  • Keep: packages/governors/src/phonolex_governors/checking/phonology.py (unchanged)
  • Delete: packages/governors/src/phonolex_governors/core.py
  • Delete: packages/governors/src/phonolex_governors/gates.py
  • Delete: packages/governors/src/phonolex_governors/boosts.py (module)
  • Delete: packages/governors/src/phonolex_governors/boosts/ (directory)
  • Delete: packages/governors/src/phonolex_governors/cdd.py
  • Delete: packages/governors/src/phonolex_governors/constraints.py
  • Delete: packages/governors/src/phonolex_governors/include.py
  • Delete: packages/governors/src/phonolex_governors/lookups.py
  • Delete: packages/governors/src/phonolex_governors/thematic.py
  • Delete: packages/governors/src/phonolex_governors/constraint_compiler.py
  • Delete: packages/governors/src/phonolex_governors/constraint_types.py
  • Delete: packages/governors/src/phonolex_governors/generation/backtrack.py
  • Delete: packages/governors/src/phonolex_governors/generation/loop.py
  • Delete: packages/governors/src/phonolex_governors/generation/sampling.py
  • Create: packages/governors/tests/test_reranker_v2.py
  • Create: packages/governors/tests/test_lookahead.py

Bigram Matrix Builder

  • Create: packages/generation/scripts/build_bigram_matrix.py
  • Artifact: packages/generation/data/bigram_matrix.npz (gitignored)
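
The plan does not say what `bigram_matrix.npz` contains. One plausible shape, sketched here purely as an assumption, is sparse word-to-word counts stored as COO arrays (`rows`, `cols`, `counts`), because a dense vocabulary-by-vocabulary matrix would not fit in memory. A dead-end check then asks whether any observed successor of a candidate word is still allowed. All function names are hypothetical, as is the choice not to flag words that have no bigram data.

```python
from collections import Counter
from typing import Iterable

import numpy as np


def build_bigram_counts(
    sentences: Iterable[list[str]], vocab: dict[str, int]
) -> Counter:
    """Hypothetical: count adjacent (prev_idx, next_idx) pairs; OOV words break the chain."""
    counts: Counter = Counter()
    for words in sentences:
        idxs = [vocab.get(w.lower()) for w in words]
        for a, b in zip(idxs, idxs[1:]):
            if a is not None and b is not None:
                counts[(a, b)] += 1
    return counts


def save_bigrams(path: str, counts: Counter, min_count: int = 1) -> None:
    """Save as COO arrays; a dense V x V matrix would be far too large."""
    kept = [(a, b, n) for (a, b), n in counts.items() if n >= min_count]
    np.savez_compressed(
        path,
        rows=np.array([a for a, _, _ in kept], dtype=np.int32),
        cols=np.array([b for _, b, _ in kept], dtype=np.int32),
        counts=np.array([n for _, _, n in kept], dtype=np.int32),
    )


def load_successors(path: str) -> dict[int, set[int]]:
    """Rebuild prev_idx -> {next_idx, ...} from the saved arrays."""
    successors: dict[int, set[int]] = {}
    with np.load(path) as data:
        for a, b in zip(data["rows"].tolist(), data["cols"].tolist()):
            successors.setdefault(a, set()).add(b)
    return successors


def is_dead_end(
    word_idx: int, successors: dict[int, set[int]], allowed: set[int]
) -> bool:
    """True if every observed successor of word_idx is disallowed.

    Words with no bigram data are not flagged, so rare words are not banned.
    """
    nexts = successors.get(word_idx)
    if not nexts:
        return False
    return nexts.isdisjoint(allowed)
```

Treating words without bigram data as safe keeps the filter conservative. The opposite choice would ban every rare word the corpus happens to miss.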

Frontend

  • Modify: packages/web/frontend/src/types/governance.ts — 5 constraint types
  • Modify: packages/web/frontend/src/lib/constraintCompiler.ts — handle all 5 types
  • Modify: packages/web/frontend/src/lib/constraintCompiler.test.ts — tests for bound_boost + contrastive
  • Modify: packages/web/frontend/src/lib/generationApi.ts — SSE streaming
  • Modify: packages/web/frontend/src/components/tools/GovernedGenerationTool/index.tsx — status line
  • Modify: packages/web/frontend/src/components/tools/GovernedGenerationTool/OutputCard.tsx — generalized coverage

Task 1: Workers API Word List Endpoint

Files:
  • Modify: packages/web/workers/src/routes/words.ts
  • Create: packages/web/workers/tests/word-list.test.ts

  • [ ] Step 1: Write the failing test

Create packages/web/workers/tests/word-list.test.ts:

import { describe, it, expect } from 'vitest';
import app from '../src/index';

// Uses the same test D1 fixture as other worker tests

describe('POST /api/words/word-list', () => {
  it('returns words containing a phoneme', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ include_phonemes: ['k'] }),
    });
    expect(res.status).toBe(200);
    const data = await res.json();
    expect(data.words).toBeInstanceOf(Array);
    expect(data.total).toBeGreaterThan(0);
    expect(data.words.length).toBe(data.total);
    // Every returned word should be a string
    expect(data.words.every((w: unknown) => typeof w === 'string')).toBe(true);
  });

  it('returns words matching norm bounds', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ filters: { max_aoa_kuperman: 5 } }),
    });
    expect(res.status).toBe(200);
    const data = await res.json();
    expect(data.words).toBeInstanceOf(Array);
    expect(data.total).toBeGreaterThan(0);
  });

  it('returns empty list for impossible filter', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ filters: { min_aoa_kuperman: 999 } }),
    });
    expect(res.status).toBe(200);
    const data = await res.json();
    expect(data.words).toEqual([]);
    expect(data.total).toBe(0);
  });

  it('returns 400 for empty body', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({}),
    });
    expect(res.status).toBe(400);
  });
});
  • [ ] Step 2: Run test to verify it fails

Run: cd packages/web/workers && npx vitest run tests/word-list.test.ts
Expected: FAIL — route not found (404)

  • [ ] Step 3: Implement the endpoint

Add to packages/web/workers/src/routes/words.ts after the existing /batch route:

interface WordListBody {
  include_phonemes?: string[];
  exclude_phonemes?: string[];
  filters?: Record<string, number>;
}

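words.post('/word-list', async (c) => {
  // Malformed or missing JSON returns 400 instead of an unhandled 500
  const body = await c.req.json<WordListBody>().catch(() => null);
  if (!body) {
    return c.json({ error: 'Request body must be a JSON object' }, 400);
  }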

  if (!body.include_phonemes?.length && !body.exclude_phonemes?.length
      && (!body.filters || Object.keys(body.filters).length === 0)) {
    return c.json({ error: 'At least one filter required' }, 400);
  }

  const wordsWhere: string[] = ['w.has_phonology = 1'];
  const propsWhere: string[] = [];
  const params: unknown[] = [];

  // Phoneme inclusion — words containing ANY of these phonemes
  if (body.include_phonemes?.length) {
    const phonemeClauses = body.include_phonemes.map((ph) => {
      const normalized = normalizePhoneme(ph);
      params.push(`%|${normalized}|%`);
      return 'w.phonemes_str LIKE ?';
    });
    wordsWhere.push(`(${phonemeClauses.join(' OR ')})`);
  }

  // Phoneme exclusion — words NOT containing these phonemes
  if (body.exclude_phonemes?.length) {
    for (const ph of body.exclude_phonemes) {
      const normalized = normalizePhoneme(ph);
      wordsWhere.push('w.phonemes_str NOT LIKE ?');
      params.push(`%|${normalized}|%`);
    }
  }

  // Norm filters: only numeric min_*/max_* keys are honored
  if (body.filters && Object.keys(body.filters).length > 0) {
    const mergedFilters: Record<string, number | null> = {};
    for (const [key, value] of Object.entries(body.filters)) {
      if ((key.startsWith('min_') || key.startsWith('max_')) && typeof value === 'number') {
        mergedFilters[key] = value;
      }
    }
    if (Object.keys(mergedFilters).length > 0) {
      const partitioned = partitionFilterColumns(mergedFilters);
      wordsWhere.push(...partitioned.wordsClauses.map((cl) => prefixWordsColumns(cl)));
      params.push(...partitioned.wordsParams);
      propsWhere.push(...partitioned.propsClauses.map((cl) => prefixPropsColumns(cl)));
      params.push(...partitioned.propsParams);
    } else if (!body.include_phonemes?.length && !body.exclude_phonemes?.length) {
      // Every filter key was unrecognized: without this guard the query
      // would return the whole vocabulary.
      return c.json({ error: 'No valid min_/max_ filters provided' }, 400);
    }
  }

  const wordsWhereSQL = wordsWhere.join(' AND ');
  const needsPropsJoin = propsWhere.length > 0;
  const fromClause = needsPropsJoin
    ? 'FROM words w INNER JOIN word_properties wp ON w.word = wp.word'
    : 'FROM words w';
  const fullWhere = needsPropsJoin
    ? ` WHERE ${wordsWhereSQL} AND ${propsWhere.join(' AND ')}`
    : ` WHERE ${wordsWhereSQL}`;

  const sql = `SELECT w.word ${fromClause}${fullWhere}`;
  const { results } = await c.env.DB.prepare(sql).bind(...params).all<{ word: string }>();

  const wordList = results.map((r) => r.word);
  return c.json({ words: wordList, total: wordList.length });
});

Note: prefixWordsColumns and prefixPropsColumns are helpers already used in the /search route, and the handler also calls normalizePhoneme and partitionFilterColumns. Confirm all four are in scope in words.ts. If any of them is defined inline inside another handler, extract it to module scope so both routes share one implementation instead of duplicating it.
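
On the generation-server side, the resolver has to turn constraints into this endpoint's JSON body. A minimal sketch follows, assuming plain dicts in place of the pydantic models and the `min_<norm>`/`max_<norm>` key pattern seen in the tests (`max_aoa_kuperman`). Exclude and include both fetch words containing the phonemes (one list is banned, the other boosted), and bounds fetch the words inside the range. Whether the real resolver uses that in-range list as a complement allow set is an assumption, and contrastive constraints are left out.

```python
from typing import Any


def word_list_body(constraint: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: build a POST /api/words/word-list body for one constraint."""
    kind = constraint["type"]
    if kind in ("exclude", "include"):
        # Both need "words containing ANY of these phonemes":
        # exclude bans that list, include boosts it.
        return {"include_phonemes": list(constraint["phonemes"])}
    if kind in ("bound", "bound_boost"):
        norm = constraint["norm"]
        filters: dict[str, float] = {}
        if constraint.get("min") is not None:
            filters[f"min_{norm}"] = constraint["min"]
        if constraint.get("max") is not None:
            filters[f"max_{norm}"] = constraint["max"]
        if not filters:
            # The endpoint rejects a body with no filters, so fail early here
            raise ValueError(f"{kind} constraint on {norm!r} needs min or max")
        return {"filters": filters}
    raise ValueError(f"no word-list mapping for constraint type {kind!r}")
```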

  • [ ] Step 4: Run test to verify it passes

Run: cd packages/web/workers && npx vitest run tests/word-list.test.ts
Expected: PASS

  • [ ] Step 5: Commit
git add packages/web/workers/src/routes/words.ts packages/web/workers/tests/word-list.test.ts
git commit -m "feat(api): add POST /api/words/word-list endpoint for generation server"

Task 2: Simplify Schema Types

Files:
  • Modify: packages/generation/server/schemas.py
  • Modify: packages/generation/server/tests/test_schemas.py

  • [ ] Step 1: Write the failing test

Replace content of packages/generation/server/tests/test_schemas.py with tests for the new 5-type schema:

import pytest
from pydantic import ValidationError
from server.schemas import (
    ExcludeConstraint, IncludeConstraint, BoundConstraint,
    BoundBoostConstraint, ContrastiveConstraint,
    ResolvedConstraint,
    GenerateSingleRequest,
)


def test_exclude_constraint():
    c = ExcludeConstraint(type="exclude", phonemes=["ɹ", "ɝ"])
    assert c.phonemes == ["ɹ", "ɝ"]


def test_include_constraint():
    c = IncludeConstraint(type="include", phonemes=["k"], target_rate=0.2)
    assert c.target_rate == 0.2


def test_include_constraint_rejects_invalid_rate():
    with pytest.raises(ValidationError):
        IncludeConstraint(type="include", phonemes=["k"], target_rate=1.5)


def test_bound_constraint_min():
    c = BoundConstraint(type="bound", norm="concreteness", min=3.0)
    assert c.min == 3.0
    assert c.max is None


def test_bound_constraint_max():
    c = BoundConstraint(type="bound", norm="aoa_kuperman", max=5.0)
    assert c.max == 5.0


def test_bound_boost_constraint():
    c = BoundBoostConstraint(
        type="bound_boost", norm="concreteness", min=3.0, coverage_target=0.2,
    )
    assert c.coverage_target == 0.2


def test_contrastive_constraint_minpair():
    c = ContrastiveConstraint(
        type="contrastive", pair_type="minpair",
        phoneme1="s", phoneme2="z", position="initial",
    )
    assert c.pair_type == "minpair"
    assert c.position == "initial"


def test_contrastive_constraint_maxopp():
    c = ContrastiveConstraint(
        type="contrastive", pair_type="maxopp",
        phoneme1="m", phoneme2="s", position="any",
    )
    assert c.pair_type == "maxopp"


def test_contrastive_rejects_invalid_position():
    with pytest.raises(ValidationError):
        ContrastiveConstraint(
            type="contrastive", pair_type="minpair",
            phoneme1="s", phoneme2="z", position="nowhere",
        )


def test_resolved_constraint_ban():
    r = ResolvedConstraint(
        mode="ban", words=["car", "card"], strategy="direct",
        label="exclude /ɹ/",
    )
    assert r.mode == "ban"
    assert r.strategy == "direct"
    assert r.coverage_target is None


def test_resolved_constraint_boost():
    r = ResolvedConstraint(
        mode="boost", words=["cat", "kit"], strategy="direct",
        coverage_target=0.2, label="include /k/ 20%",
    )
    assert r.coverage_target == 0.2


def test_generate_single_request_with_new_types():
    req = GenerateSingleRequest(
        prompt="Write a story",
        constraints=[
            {"type": "exclude", "phonemes": ["ɹ"]},
            {"type": "include", "phonemes": ["k"], "target_rate": 0.2},
            {"type": "bound", "norm": "aoa_kuperman", "max": 5.0},
            {"type": "bound_boost", "norm": "concreteness", "min": 3.0, "coverage_target": 0.15},
            {"type": "contrastive", "pair_type": "minpair", "phoneme1": "s", "phoneme2": "z", "position": "initial"},
        ],
    )
    assert len(req.constraints) == 5
  • [ ] Step 2: Run test to verify it fails

Run: cd packages/generation && uv run python -m pytest server/tests/test_schemas.py -v
Expected: FAIL — BoundBoostConstraint, ContrastiveConstraint, ResolvedConstraint not defined

  • [ ] Step 3: Rewrite schemas.py

Replace the constraint types in packages/generation/server/schemas.py. Keep Phono, RichToken, GenerateSingleRequest, SingleGenerationResponse, WordViolation, WordComplianceDetail, IncludeCoverage, ServerStatus. Remove the old 11 constraint types and session/profile types. The new models need Literal (from typing) and field_validator (from pydantic); add those imports if the file lacks them. Add:

class ExcludeConstraint(BaseModel):
    type: Literal["exclude"]
    phonemes: list[str]


class IncludeConstraint(BaseModel):
    type: Literal["include"]
    phonemes: list[str]
    target_rate: float = 0.2

    @field_validator("target_rate")
    @classmethod
    def validate_target_rate(cls, v):
        if not (0.0 <= v <= 1.0):
            raise ValueError("target_rate must be between 0.0 and 1.0")
        return v


class BoundConstraint(BaseModel):
    type: Literal["bound"]
    norm: str
    min: float | None = None
    max: float | None = None


class BoundBoostConstraint(BaseModel):
    type: Literal["bound_boost"]
    norm: str
    min: float | None = None
    max: float | None = None
    coverage_target: float = 0.2

    @field_validator("coverage_target")
    @classmethod
    def validate_coverage_target(cls, v):
        if not (0.0 <= v <= 1.0):
            raise ValueError("coverage_target must be between 0.0 and 1.0")
        return v


class ContrastiveConstraint(BaseModel):
    type: Literal["contrastive"]
    pair_type: Literal["minpair", "maxopp"]
    phoneme1: str
    phoneme2: str
    position: Literal["initial", "medial", "final", "any"] = "any"


Constraint = (
    ExcludeConstraint | IncludeConstraint | BoundConstraint
    | BoundBoostConstraint | ContrastiveConstraint
)


class ResolvedConstraint(BaseModel):
    mode: Literal["ban", "boost"]
    words: list[str]
    strategy: Literal["direct", "complement"] = "direct"
    coverage_target: float | None = None
    label: str
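
Resolving a ContrastiveConstraint to a word list means finding word pairs whose pronunciations differ in exactly one slot, where one word has `phoneme1` and the other has `phoneme2`. The sketch below does this over a local `word -> phoneme tuple` lexicon, which is an assumption (the real resolver may query PhonoLex instead), and it treats "medial" as any slot except the first and last. The same pair search serves both `minpair` and `maxopp`; how a maximal-opposition phoneme pair gets chosen is outside this sketch. Function names are hypothetical.

```python
from collections import defaultdict


def _slots(n: int, position: str) -> list[int]:
    """Phoneme indices that count as the requested position."""
    if n == 0:
        return []
    if position == "initial":
        return [0]
    if position == "final":
        return [n - 1]
    if position == "medial":
        return list(range(1, n - 1))
    return list(range(n))  # "any"


def find_minimal_pairs(
    lexicon: dict[str, tuple[str, ...]],
    phoneme1: str,
    phoneme2: str,
    position: str = "any",
) -> list[tuple[str, str]]:
    """Hypothetical: (word_with_phoneme1, word_with_phoneme2) pairs differing in one slot."""
    # Frame = pronunciation with the contrast slot blanked out. Two words that
    # share a frame differ only in that slot.
    frames: dict[tuple, list[str]] = defaultdict(list)
    for word, phones in lexicon.items():
        for i in _slots(len(phones), position):
            if phones[i] == phoneme1:
                frames[phones[:i] + (None,) + phones[i + 1:]].append(word)

    pairs = []
    for word, phones in lexicon.items():
        for i in _slots(len(phones), position):
            if phones[i] == phoneme2:
                frame = phones[:i] + (None,) + phones[i + 1:]
                for partner in frames.get(frame, []):
                    pairs.append((partner, word))
    return sorted(pairs)
```

Indexing by blanked-out frames keeps the search roughly linear in the total number of phonemes in the lexicon, instead of comparing every pair of words.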

Remove: ExcludeClustersConstraint, VocabBoostConstraint, ComplexityConstraint, VocabOnlyConstraint, MSHConstraint, MinPairBoostConstraint, MaxOppositionBoostConstraint, ThematicConstraint, ConstraintProfile, Turn, TurnUser, TurnAssistant, Session, GenerateRequest, GenerateResponse, UserAnalysis, AssistantResponse, SessionCreateRequest, SessionCreateResponse, ProfileCreateRequest, ProfileUpdateRequest, TokenCensus.

Update GenerateSingleRequest to use the new Constraint union. Add a BoostCoverage model for generalized coverage reporting:

class BoostCoverage(BaseModel):
    label: str
    target_rate: float
    actual_rate: float
    hit_words: list[str]
    total_words: int

Update SingleGenerationResponse — replace include_coverage: list[IncludeCoverage] with boost_coverage: list[BoostCoverage]. Keep IncludeCoverage temporarily as an alias for backwards compatibility if needed, or remove it.
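
Filling in BoostCoverage after generation comes down to counting how many word tokens in the output fall in each boost list. A minimal sketch follows. It assumes the same `[a-zA-Z]+` word split the reranker uses, counts occurrences (a repeated hit word counts each time), and reports `hit_words` as the distinct hits in alphabetical order; both counting choices are assumptions the spec may settle differently. The function name is hypothetical, and it returns a dict with the BoostCoverage field names so it could be passed as `BoostCoverage(**result)`.

```python
import re

# Same crude word split as the reranker's coverage counting (assumption)
_WORD_RE = re.compile(r"[a-zA-Z]+")


def compute_boost_coverage(
    text: str, words: set[str], label: str, target_rate: float
) -> dict:
    """Hypothetical: measure how much of `text` falls in a lowercase boost word list."""
    tokens = [w.lower() for w in _WORD_RE.findall(text)]
    hits = [w for w in tokens if w in words]
    total = len(tokens)
    return {
        "label": label,
        "target_rate": target_rate,
        "actual_rate": len(hits) / total if total else 0.0,
        "hit_words": sorted(set(hits)),  # distinct hits, alphabetical
        "total_words": total,
    }
```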

  • [ ] Step 4: Run test to verify it passes

Run: cd packages/generation && uv run python -m pytest server/tests/test_schemas.py -v
Expected: PASS

  • [ ] Step 5: Commit
git add packages/generation/server/schemas.py packages/generation/server/tests/test_schemas.py
git commit -m "refactor(schemas): collapse 11 constraint types to 5 + ResolvedConstraint"

Task 3: Strip Governor Package to A-Team

Files:
  • Delete: 13 modules plus the boosts/ directory (see file map above)
  • Rewrite: packages/governors/src/phonolex_governors/__init__.py
  • Keep: checking/ directory unchanged
  • Keep: generation/reranker.py (rewritten in Task 5)

  • [ ] Step 1: Delete dead modules
cd packages/governors/src/phonolex_governors
rm -f core.py gates.py cdd.py constraints.py include.py lookups.py thematic.py
rm -f constraint_compiler.py constraint_types.py
rm -f boosts.py  # the flat file
rm -rf boosts/   # the directory
rm -f generation/backtrack.py generation/loop.py generation/sampling.py
  • [ ] Step 2: Rewrite __init__.py
"""PhonoLex Governors — word-level constraint checking and generation steering."""

from phonolex_governors.checking.checker import (
    CheckerConfig,
    CheckResult,
    PhonemeExcludeCheck,
    VocabOnlyCheck,
    check_word,
)
from phonolex_governors.checking.g2p import G2PCache, word_to_phonemes
from phonolex_governors.generation.reranker import Reranker

__all__ = [
    "CheckerConfig",
    "CheckResult",
    "PhonemeExcludeCheck",
    "VocabOnlyCheck",
    "check_word",
    "G2PCache",
    "word_to_phonemes",
    "Reranker",
]
  • [ ] Step 3: Update generation/__init__.py
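Replace the contents of packages/governors/src/phonolex_governors/generation/__init__.py with a docstring only, so nothing in it references the deleted backtrack, loop, or sampling modules:
"""Generation steering — reranker and lookahead."""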
  • [ ] Step 4: Verify the package imports cleanly

Run: cd packages/governors && uv run python -c "from phonolex_governors import check_word, Reranker, G2PCache; print('OK')"
Expected: OK

  • [ ] Step 5: Commit
cd "$(git rev-parse --show-toplevel)"  # back to the repo root
git add -u packages/governors/
git commit -m "refactor(governors): strip to A-team — checker, g2p, reranker"

Task 4: Frontend — Complete the Constraint Compiler

Files:
  • Modify: packages/web/frontend/src/types/governance.ts
  • Modify: packages/web/frontend/src/lib/constraintCompiler.ts
  • Modify: packages/web/frontend/src/lib/constraintCompiler.test.ts

  • [ ] Step 1: Write the failing tests

Add to packages/web/frontend/src/lib/constraintCompiler.test.ts:

it('compiles bound_boost entry', () => {
  const entries: StoreEntry[] = [
    { type: 'bound_boost', norm: 'concreteness', direction: 'min', value: 3, coverageTarget: 15 },
  ];
  const result = compileConstraints(entries);
  expect(result).toEqual([
    { type: 'bound_boost', norm: 'concreteness', min: 3, coverage_target: 0.15 },
  ]);
});

it('compiles contrastive minpair entry', () => {
  const entries: StoreEntry[] = [
    { type: 'contrastive', pairType: 'minpair', phoneme1: 's', phoneme2: 'z', position: 'initial' },
  ];
  const result = compileConstraints(entries);
  expect(result).toEqual([
    { type: 'contrastive', pair_type: 'minpair', phoneme1: 's', phoneme2: 'z', position: 'initial' },
  ]);
});

it('compiles contrastive maxopp entry', () => {
  const entries: StoreEntry[] = [
    { type: 'contrastive', pairType: 'maxopp', phoneme1: 'm', phoneme2: 's', position: 'any' },
  ];
  const result = compileConstraints(entries);
  expect(result).toEqual([
    { type: 'contrastive', pair_type: 'maxopp', phoneme1: 'm', phoneme2: 's', position: 'any' },
  ]);
});

it('compiles all 5 types together', () => {
  const entries: StoreEntry[] = [
    { type: 'exclude', phoneme: 'ɹ' },
    { type: 'include', phoneme: 'k', strength: 2.0, targetRate: 20 },
    { type: 'bound', norm: 'aoa_kuperman', direction: 'max', value: 5 },
    { type: 'bound_boost', norm: 'concreteness', direction: 'min', value: 3, coverageTarget: 15 },
    { type: 'contrastive', pairType: 'minpair', phoneme1: 's', phoneme2: 'z', position: 'initial' },
  ];
  const result = compileConstraints(entries);
  expect(result).toHaveLength(5);
  expect(result.map((c) => c.type)).toEqual([
    'exclude', 'include', 'bound', 'bound_boost', 'contrastive',
  ]);
});
  • [ ] Step 2: Run test to verify it fails

Run: cd packages/web/frontend && npx vitest run src/lib/constraintCompiler.test.ts
Expected: FAIL — bound_boost and contrastive entries silently dropped

  • [ ] Step 3: Update governance.ts types

Simplify the Constraint union in packages/web/frontend/src/types/governance.ts. Replace the 11 API constraint types with 5:

export interface ExcludeConstraint {
  type: "exclude";
  phonemes: string[];
}

export interface IncludeConstraint {
  type: "include";
  phonemes: string[];
  target_rate?: number;
}

export interface BoundConstraint {
  type: "bound";
  norm: string;
  min?: number;
  max?: number;
}

export interface BoundBoostConstraint {
  type: "bound_boost";
  norm: string;
  min?: number;
  max?: number;
  coverage_target: number;
}

export interface ContrastiveConstraint {
  type: "contrastive";
  pair_type: "minpair" | "maxopp";
  phoneme1: string;
  phoneme2: string;
  position: string;
}

export type Constraint =
  | ExcludeConstraint
  | IncludeConstraint
  | BoundConstraint
  | BoundBoostConstraint
  | ContrastiveConstraint;

Remove the old types: ExcludeClustersConstraint, VocabBoostConstraint, ComplexityConstraint, VocabOnlyConstraint, MSHConstraint, MinPairBoostConstraint, MaxOppositionBoostConstraint, ThematicConstraint.

Add BoostCoverage to replace IncludeCoverage:

export interface BoostCoverage {
  label: string;
  target_rate: number;
  actual_rate: number;
  hit_words: string[];
  total_words: number;
}

Update SingleGenerationResponse to use boost_coverage: BoostCoverage[] instead of include_coverage: IncludeCoverage[].

  • [ ] Step 4: Update constraintCompiler.ts

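Add bound_boost and contrastive handling to compileConstraints(), after the existing exclude, include, and bound blocks and before the function returns, so the output order matches the "compiles all 5 types together" test: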

// --- Bound boost: one constraint per entry ---
const boundBoosts = entries.filter(
  (e): e is Extract<StoreEntry, { type: 'bound_boost' }> => e.type === 'bound_boost',
);
for (const bb of boundBoosts) {
  result.push({
    type: 'bound_boost',
    norm: bb.norm,
    ...(bb.direction === 'min' ? { min: bb.value } : { max: bb.value }),
    coverage_target: bb.coverageTarget / 100,
  });
}

// --- Contrastive: one constraint per entry ---
const contrastives = entries.filter(
  (e): e is Extract<StoreEntry, { type: 'contrastive' }> => e.type === 'contrastive',
);
for (const ct of contrastives) {
  result.push({
    type: 'contrastive',
    pair_type: ct.pairType,
    phoneme1: ct.phoneme1,
    phoneme2: ct.phoneme2,
    position: ct.position,
  });
}
  • [ ] Step 5: Run tests to verify they pass

Run: cd packages/web/frontend && npx vitest run src/lib/constraintCompiler.test.ts
Expected: PASS (all tests including new ones)

  • [ ] Step 6: Commit
git add packages/web/frontend/src/types/governance.ts packages/web/frontend/src/lib/constraintCompiler.ts packages/web/frontend/src/lib/constraintCompiler.test.ts
git commit -m "feat(frontend): complete constraint compiler — all 5 types wired"

Task 5: Rewrite the Reranker

Files:
  • Rewrite: packages/governors/src/phonolex_governors/generation/reranker.py
  • Create: packages/governors/tests/test_reranker_v2.py

  • [ ] Step 1: Write the failing tests

Create packages/governors/tests/test_reranker_v2.py:

import pytest
import torch
from unittest.mock import MagicMock
from phonolex_governors.generation.reranker import Reranker, BoostList


@pytest.fixture
def mock_tokenizer():
    tok = MagicMock()
    tok.all_special_ids = [0, 1, 2]
    # Map token IDs to decoded text
    decode_map = {
        10: "▁cat",
        11: "▁dog",
        12: "▁car",
        13: "▁kit",
        14: "▁the",
        15: "▁run",
    }
    tok.decode = MagicMock(side_effect=lambda ids, **kw: decode_map.get(ids[0], "▁unk"))
    return tok


def test_complement_ban_penalizes_non_members(mock_tokenizer):
    """Words not in the allow set should be penalized."""
    allow_set = {"cat", "dog", "the"}
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[allow_set],
        boost_lists=[],
        top_k=10,
        penalize=10.0,
    )
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat — allowed
    logits[0, 12] = 5.0  # car — not allowed
    input_ids = torch.tensor([[14]])  # one prior token: "▁the"
    result = reranker(input_ids, logits)
    assert result[0, 10].item() == 5.0  # cat untouched
    assert result[0, 12].item() < 0.0   # car penalized


def test_boost_list_increases_target_logits(mock_tokenizer):
    """Words in a boost list should get a positive logit adjustment."""
    boost = BoostList(words={"cat", "kit"}, target_rate=0.3, label="include /k/")
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[],
        boost_lists=[boost],
        top_k=10,
    )
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat — in boost set
    logits[0, 11] = 5.0  # dog — not in boost set
    input_ids = torch.tensor([[14]])  # one prior token: "▁the"
    result = reranker(input_ids, logits)
    assert result[0, 10].item() > 5.0  # cat boosted
    assert result[0, 11].item() == 5.0  # dog untouched


def test_boost_self_regulates(mock_tokenizer):
    """When coverage is at target, boost should ease off."""
    boost = BoostList(words={"cat"}, target_rate=0.5, label="test")
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[],
        boost_lists=[boost],
        top_k=10,
    )
    # Simulate: 10 words generated, 5 are "cat" → 50% coverage = at target
    boost.update_coverage(hit=5, total=10)
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat
    input_ids = torch.tensor([[14]])  # one prior token: "▁the"
    result = reranker(input_ids, logits)
    # At target — boost should be zero or near-zero
    assert abs(result[0, 10].item() - 5.0) < 1.0


def test_multiple_boost_lists_compose(mock_tokenizer):
    """Multiple boost lists each apply independently."""
    b1 = BoostList(words={"cat"}, target_rate=0.2, label="boost1")
    b2 = BoostList(words={"cat", "dog"}, target_rate=0.3, label="boost2")
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[],
        boost_lists=[b1, b2],
        top_k=10,
    )
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat — in both
    logits[0, 11] = 5.0  # dog — in b2 only
    input_ids = torch.tensor([[14]])  # one prior token: "▁the"
    result = reranker(input_ids, logits)
    # cat gets boosted by both lists
    assert result[0, 10].item() > result[0, 11].item()
  • [ ] Step 2: Run test to verify it fails

Run: cd packages/governors && uv run python -m pytest tests/test_reranker_v2.py -v
Expected: FAIL — BoostList not defined, Reranker constructor changed

  • [ ] Step 3: Rewrite reranker.py

Replace packages/governors/src/phonolex_governors/generation/reranker.py:

"""Word-list reranker — set membership checking + coverage modulation.

Two responsibilities:
1. Complement ban: penalize candidates not in any allow set
2. Boost coverage: modulate logits to converge on target rates for boost word lists

No G2P. All checks are set membership — O(1) per candidate.
"""

from __future__ import annotations

from dataclasses import dataclass, field

import torch
from transformers import LogitsProcessor


WORD_BOUNDARY_CHAR = "\u2581"


@dataclass
class BoostList:
    """A word list with a coverage target."""
    words: set[str]
    target_rate: float
    label: str
    max_boost: float = 5.0
    _hits: int = 0
    _total: int = 0

    def update_coverage(self, hit: int, total: int) -> None:
        self._hits = hit
        self._total = total

    @property
    def coverage(self) -> float:
        return self._hits / self._total if self._total > 0 else 0.0

    @property
    def gap(self) -> float:
        return max(0.0, self.target_rate - self.coverage)

    @property
    def boost_strength(self) -> float:
        return self.max_boost * self.gap


@dataclass
class RerankerStats:
    calls: int = 0
    tokens_checked: int = 0
    tokens_penalized: int = 0
    tokens_boosted: int = 0


class Reranker(LogitsProcessor):
    """Word-list reranker as a HuggingFace LogitsProcessor.

    Args:
        tokenizer: HuggingFace tokenizer for decoding token IDs to text.
        allow_sets: List of allowed word sets for complement banning.
            If non-empty, words not in ANY allow set are penalized.
        boost_lists: List of BoostList objects for coverage modulation.
        top_k: Number of top candidates to evaluate per step.
        penalize: Logit penalty for non-allowed words.
    """

    def __init__(
        self,
        tokenizer,
        allow_sets: list[set[str]],
        boost_lists: list[BoostList],
        top_k: int = 200,
        penalize: float = 15.0,
    ):
        self.tokenizer = tokenizer
        self.allow_sets = allow_sets
        self.boost_lists = boost_lists
        self.top_k = top_k
        self.penalize = penalize
        self.stats = RerankerStats()
        # Precompute union of all allow sets for fast membership check
        self._allow_union: set[str] = set()
        for s in allow_sets:
            self._allow_union |= s
        self._has_allow = len(self._allow_union) > 0

    def _get_partial_word(self, input_ids: torch.LongTensor | None) -> str:
        """Decode the trailing partial word from generated tokens so far."""
        if input_ids is None:
            return ""
        ids = input_ids[0].tolist() if input_ids.dim() > 1 else input_ids.tolist()
        if not ids:
            return ""
        special = set(self.tokenizer.all_special_ids)
        partial_ids = []
        start = max(0, len(ids) - 10)
        for i in range(len(ids) - 1, start - 1, -1):
            tid = ids[i]
            if tid in special:
                break
            text = self.tokenizer.decode([tid], skip_special_tokens=False)
            partial_ids.insert(0, tid)
            if text.startswith(WORD_BOUNDARY_CHAR) or text.startswith(" "):
                break
        if not partial_ids:
            return ""
        return self.tokenizer.decode(partial_ids, skip_special_tokens=True).strip()

    def _update_boost_coverage(self, input_ids: torch.LongTensor | None) -> None:
        """Recompute per-boost-list coverage from generated tokens so far."""
        import re
        if input_ids is None or not self.boost_lists:
            return
        text = self.tokenizer.decode(
            input_ids[0] if input_ids.dim() > 1 else input_ids,
            skip_special_tokens=True,
        )
        words = re.findall(r'[a-zA-Z]+', text)
        total = len(words)
        if total < 3:
            return
        word_set = {w.lower() for w in words}
        for bl in self.boost_lists:
            hits = len(word_set & bl.words)
            bl.update_coverage(hit=hits, total=total)

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Rerank logits via set membership checks."""
        self.stats.calls += 1
        self._update_boost_coverage(input_ids)

        partial = self._get_partial_word(input_ids)

        k = min(self.top_k, scores.shape[-1])
        top_vals, top_ids = torch.topk(scores[0], k)

        for i in range(k):
            tid = top_ids[i].item()
            text = self.tokenizer.decode([tid], skip_special_tokens=False)

            is_word_start = text.startswith(WORD_BOUNDARY_CHAR) or text.startswith(" ")
            fragment = text.lstrip(WORD_BOUNDARY_CHAR).lstrip(" ").strip()

            if not fragment or not fragment.isalpha():
                continue

            word = fragment if is_word_start else (partial + fragment if partial else fragment)
            if not word or len(word) < 2:
                continue

            word_lower = word.lower()
            self.stats.tokens_checked += 1

            # Complement ban: penalize words not in allow set
            if self._has_allow and word_lower not in self._allow_union:
                scores[0, tid] -= self.penalize
                self.stats.tokens_penalized += 1

            # Boost: add logits for words in boost lists (proportional to gap)
            for bl in self.boost_lists:
                if word_lower in bl.words and bl.gap > 0:
                    scores[0, tid] += bl.boost_strength
                    self.stats.tokens_boosted += 1

        return scores
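The per-candidate logic above can be examined without a tokenizer or tensors. The sketch below mirrors the loop on a dict of candidate pieces → logits. It applies the complement ban to words outside the allow set and boosts words whose list is still below its coverage target. `ToyBoostList`, `rerank` and the default strengths are hypothetical stand-ins. The real `BoostList` from Task 5 may compute `boost_strength` differently.

```python
from __future__ import annotations

from dataclasses import dataclass

WORD_BOUNDARY_CHAR = "▁"  # SentencePiece word-start marker


@dataclass
class ToyBoostList:
    """Hypothetical stand-in for BoostList: a word set plus a coverage target."""
    words: set[str]
    target_rate: float
    actual_rate: float = 0.0
    boost_strength: float = 3.0

    @property
    def gap(self) -> float:
        # Positive while coverage is still below target
        return self.target_rate - self.actual_rate


def rerank(
    candidates: dict[str, float],
    partial: str,
    allow: set[str] | None,
    boosts: list[ToyBoostList],
    penalize: float = 10.0,
) -> dict[str, float]:
    """Apply complement-ban and boost adjustments to top-k candidate pieces."""
    out = dict(candidates)
    for piece in candidates:
        is_word_start = piece.startswith(WORD_BOUNDARY_CHAR) or piece.startswith(" ")
        fragment = piece.lstrip(WORD_BOUNDARY_CHAR).lstrip(" ").strip()
        if not fragment.isalpha():
            continue  # punctuation, digits, whitespace: never constrained
        # A continuation piece extends the word already in progress
        word = fragment if is_word_start else partial + fragment
        if len(word) < 2:
            continue
        word = word.lower()
        if allow is not None and word not in allow:
            out[piece] -= penalize  # complement ban: word is outside the allow set
        for bl in boosts:
            if word in bl.words and bl.gap > 0:
                out[piece] += bl.boost_strength
    return out
```

For example, with `partial="do"` and `allow={"dog"}`, the continuation piece `"g"` is left alone because it completes "dog". The piece `"s"` is penalized because "dos" is outside the allow set.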
  • [ ] Step 4: Run test to verify it passes

Run: cd packages/governors && uv run python -m pytest tests/test_reranker_v2.py -v
Expected: PASS

  • [ ] Step 5: Commit
git add packages/governors/src/phonolex_governors/generation/reranker.py packages/governors/tests/test_reranker_v2.py
git commit -m "feat(reranker): rewrite to set membership + coverage modulation — no G2P"

Task 6: Rewrite governor.py — Constraint Resolution

Files:
  • Rewrite: packages/generation/server/governor.py
  • Rewrite: packages/generation/server/tests/test_governor.py

  • [ ] Step 1: Write the failing tests

Replace packages/generation/server/tests/test_governor.py:

import pytest
from unittest.mock import patch, AsyncMock
from server.governor import resolve_constraints, prepare_generation
from server.schemas import (
    ExcludeConstraint, IncludeConstraint, BoundConstraint,
    BoundBoostConstraint, ContrastiveConstraint, ResolvedConstraint,
)


@pytest.fixture
def mock_api_response():
    """Mock the Workers API word-list response."""
    return {"words": ["cat", "dog", "bat", "hat"], "total": 4}


@pytest.mark.asyncio
async def test_resolve_exclude(mock_api_response):
    constraints = [ExcludeConstraint(type="exclude", phonemes=["ɹ"])]
    with patch("server.governor._call_word_list_api", new_callable=AsyncMock,
               return_value=mock_api_response):
        resolved = await resolve_constraints(constraints)
    assert len(resolved) == 1
    assert resolved[0].mode == "ban"
    assert resolved[0].words == ["cat", "dog", "bat", "hat"]
    assert "ɹ" in resolved[0].label


@pytest.mark.asyncio
async def test_resolve_include(mock_api_response):
    constraints = [IncludeConstraint(type="include", phonemes=["k"], target_rate=0.2)]
    with patch("server.governor._call_word_list_api", new_callable=AsyncMock,
               return_value=mock_api_response):
        resolved = await resolve_constraints(constraints)
    assert len(resolved) == 1
    assert resolved[0].mode == "boost"
    assert resolved[0].coverage_target == 0.2


@pytest.mark.asyncio
async def test_resolve_bound(mock_api_response):
    constraints = [BoundConstraint(type="bound", norm="aoa_kuperman", max=5.0)]
    with patch("server.governor._call_word_list_api", new_callable=AsyncMock,
               return_value=mock_api_response):
        resolved = await resolve_constraints(constraints)
    assert len(resolved) == 1
    assert resolved[0].mode == "ban"
    # Bounds use complement strategy
    assert resolved[0].strategy == "complement"


@pytest.fixture(scope="module")
def tokenizer():
    """Load the T5Gemma tokenizer, skipping (not erroring) when it is unavailable."""
    transformers = pytest.importorskip("transformers")
    try:
        return transformers.AutoTokenizer.from_pretrained("google/t5gemma-9b-2b-ul2-it")
    except OSError as exc:  # offline, gated, or not cached locally
        pytest.skip(f"T5Gemma tokenizer not available: {exc}")


def test_prepare_generation_ban_direct(tokenizer):
    resolved = [
        ResolvedConstraint(mode="ban", words=["car", "card"], strategy="direct", label="test"),
    ]
    bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)
    assert len(bad_words_ids) > 0  # "car" and "card" tokenized
    assert len(allow_sets) == 0
    assert len(boost_lists) == 0


def test_prepare_generation_ban_complement(tokenizer):
    resolved = [
        ResolvedConstraint(mode="ban", words=["cat", "dog"], strategy="complement", label="test"),
    ]
    bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)
    assert len(bad_words_ids) == 0
    assert len(allow_sets) == 1
    assert "cat" in allow_sets[0]


def test_prepare_generation_boost(tokenizer):
    resolved = [
        ResolvedConstraint(mode="boost", words=["cat", "kit"], strategy="direct",
                           coverage_target=0.2, label="include /k/"),
    ]
    bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)
    assert len(bad_words_ids) == 0
    assert len(allow_sets) == 0
    assert len(boost_lists) == 1
    assert boost_lists[0].target_rate == 0.2
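The async tests above mock `_call_word_list_api` with a fixed response, so they check pass-through rather than filter semantics. One option is a fake that filters a small in-memory lexicon. The sketch below infers the payload shape from `governor.py`'s own calls: `include_phonemes`, and `filters` keyed `min_<norm>` / `max_<norm>`. Its matching rules are assumptions about the Workers endpoint: match on any listed phoneme, and treat bounds as inclusive. `fake_word_list_api` and `LEXICON` are hypothetical, and the norm values are made up, not real Kuperman AoA data.

```python
from __future__ import annotations

# Hypothetical lexicon: word -> phonemes + norm values (made-up numbers)
LEXICON: dict[str, dict] = {
    "cat": {"phonemes": ["k", "æ", "t"], "aoa_kuperman": 3.0},
    "car": {"phonemes": ["k", "ɑ", "ɹ"], "aoa_kuperman": 4.0},
    "rabbit": {"phonemes": ["ɹ", "æ", "b", "ɪ", "t"], "aoa_kuperman": 4.0},
    "philosophy": {"phonemes": ["f", "ɪ", "l", "ɑ", "s", "ə", "f", "i"], "aoa_kuperman": 11.0},
}


def fake_word_list_api(payload: dict, lexicon: dict[str, dict] = LEXICON) -> dict:
    """Mimic the assumed /api/words/word-list semantics against a local dict."""
    wanted = set(payload.get("include_phonemes", []))
    filters = payload.get("filters", {})
    words = []
    for word, entry in sorted(lexicon.items()):
        # Assumed: a word matches if it contains ANY requested phoneme
        if wanted and not wanted & set(entry["phonemes"]):
            continue
        ok = True
        for key, bound in filters.items():
            kind, norm = key.split("_", 1)  # "min_aoa_kuperman" -> ("min", "aoa_kuperman")
            value = entry.get(norm)
            # Assumed: inclusive bounds; words lacking the norm are dropped
            if value is None or (kind == "min" and value < bound) or (kind == "max" and value > bound):
                ok = False
                break
        if ok:
            words.append(word)
    return {"words": words, "total": len(words)}
```

To use it in the tests, patch with `patch("server.governor._call_word_list_api", new_callable=AsyncMock, side_effect=fake_word_list_api)`. An `AsyncMock` whose `side_effect` is a plain function returns that function's result when awaited.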
  • [ ] Step 2: Run test to verify it fails

Run: cd packages/generation && uv run python -m pytest server/tests/test_governor.py -v
Expected: FAIL with an ImportError at collection (resolve_constraints, prepare_generation, and _call_word_list_api are not defined yet)

  • [ ] Step 3: Rewrite governor.py

Replace packages/generation/server/governor.py:

"""Constraint resolution and generation preparation.

Resolves semantic constraints to word lists via PhonoLex Workers API,
then prepares BAN/BOOST inputs for the generation pipeline.
"""
from __future__ import annotations

import logging
import os

import httpx

from server.schemas import (
    Constraint, ExcludeConstraint, IncludeConstraint,
    BoundConstraint, BoundBoostConstraint, ContrastiveConstraint,
    ResolvedConstraint,
)
from phonolex_governors.generation.reranker import BoostList

log = logging.getLogger("phonolex.governor")

PHONOLEX_API_URL = os.environ.get("PHONOLEX_API_URL", "http://localhost:8788")


async def _call_word_list_api(payload: dict) -> dict:
    """Call the PhonoLex Workers API word-list endpoint."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(f"{PHONOLEX_API_URL}/api/words/word-list", json=payload)
        resp.raise_for_status()
        return resp.json()


async def _call_contrastive_api(endpoint: str, payload: dict) -> list[str]:
    """Call a PhonoLex contrastive endpoint and extract word strings."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(f"{PHONOLEX_API_URL}/api/contrastive/{endpoint}", json=payload)
        resp.raise_for_status()
        data = resp.json()
    words = set()
    for pair in data.get("pairs", data.get("items", [])):
        if "word1" in pair:
            words.add(pair["word1"]["word"] if isinstance(pair["word1"], dict) else pair["word1"])
        if "word2" in pair:
            words.add(pair["word2"]["word"] if isinstance(pair["word2"], dict) else pair["word2"])
    return list(words)


async def resolve_constraints(constraints: list[Constraint]) -> list[ResolvedConstraint]:
    """Resolve semantic constraints to word lists via PhonoLex Workers API."""
    resolved: list[ResolvedConstraint] = []

    for c in constraints:
        if isinstance(c, ExcludeConstraint):
            data = await _call_word_list_api({"include_phonemes": c.phonemes})
            resolved.append(ResolvedConstraint(
                mode="ban", words=data["words"], strategy="direct",
                label=f"exclude /{', '.join(c.phonemes)}/",
            ))

        elif isinstance(c, IncludeConstraint):
            data = await _call_word_list_api({"include_phonemes": c.phonemes})
            resolved.append(ResolvedConstraint(
                mode="boost", words=data["words"],
                coverage_target=c.target_rate,
                label=f"include /{', '.join(c.phonemes)}/ {c.target_rate:.0%}",
            ))

        elif isinstance(c, BoundConstraint):
            # Get the allowed words (passing the bound) — complement ban
            filters = {}
            if c.min is not None:
                filters[f"min_{c.norm}"] = c.min
            if c.max is not None:
                filters[f"max_{c.norm}"] = c.max
            data = await _call_word_list_api({"filters": filters})
            label_parts = []
            if c.min is not None:
                label_parts.append(f"{c.norm} ≥ {c.min}")
            if c.max is not None:
                label_parts.append(f"{c.norm} ≤ {c.max}")
            resolved.append(ResolvedConstraint(
                mode="ban", words=data["words"], strategy="complement",
                label=", ".join(label_parts),
            ))

        elif isinstance(c, BoundBoostConstraint):
            filters = {}
            if c.min is not None:
                filters[f"min_{c.norm}"] = c.min
            if c.max is not None:
                filters[f"max_{c.norm}"] = c.max
            data = await _call_word_list_api({"filters": filters})
            label_parts = []
            if c.min is not None:
                label_parts.append(f"{c.norm} ≥ {c.min}")
            if c.max is not None:
                label_parts.append(f"{c.norm} ≤ {c.max}")
            resolved.append(ResolvedConstraint(
                mode="boost", words=data["words"],
                coverage_target=c.coverage_target,
                label=f"boost {', '.join(label_parts)} {c.coverage_target:.0%}",
            ))

        elif isinstance(c, ContrastiveConstraint):
            if c.pair_type == "minpair":
                words = await _call_contrastive_api("minimal-pairs", {
                    "phoneme1": c.phoneme1, "phoneme2": c.phoneme2,
                    "position": c.position,
                })
            else:
                words = await _call_contrastive_api("maximal-opposition/word-lists", {
                    "phoneme1": c.phoneme1, "phoneme2": c.phoneme2,
                    "position": c.position,
                })
            resolved.append(ResolvedConstraint(
                mode="boost", words=words,
                coverage_target=0.15,
                label=f"{c.pair_type} /{c.phoneme1}/↔/{c.phoneme2}/ {c.position}",
            ))

    return resolved


def prepare_generation(
    resolved: list[ResolvedConstraint],
    tokenizer,
) -> tuple[list[list[int]], list[set[str]], list[BoostList]]:
    """Convert resolved constraints to generation inputs.

    Returns:
        bad_words_ids: Token ID sequences for HuggingFace bad_words_ids
        allow_sets: Word sets for Reranker complement banning
        boost_lists: BoostList objects for Reranker coverage modulation
    """
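    bad_words_ids: list[list[int]] = []
    seen_bad: set[tuple[int, ...]] = set()  # O(1) dedup; direct ban lists can be large
    allow_sets: list[set[str]] = []
    boost_lists: list[BoostList] = []

    for rc in resolved:
        if rc.mode == "ban":
            if rc.strategy == "direct":
                # Encode each banned word in bare and space-prefixed forms, since the
                # tokenizer typically assigns different ids to word-initial pieces
                for word in rc.words:
                    for variant in [word, " " + word]:
                        tids = tokenizer.encode(variant, add_special_tokens=False)
                        key = tuple(tids)
                        if tids and key not in seen_bad:
                            seen_bad.add(key)
                            bad_words_ids.append(tids)
            elif rc.strategy == "complement":
                # Lowercase to match the Reranker's word_lower membership checks
                allow_sets.append({w.lower() for w in rc.words})

        elif rc.mode == "boost":
            boost_lists.append(BoostList(
                words={w.lower() for w in rc.words},
                target_rate=rc.coverage_target or 0.2,
                label=rc.label,
            ))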

    return bad_words_ids, allow_sets, boost_lists
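Hugging Face's `bad_words_ids` does not ban a word atomically. For each banned sequence, it blocks the final token only when the preceding tokens already match the end of the output. The function below is a simplified, hypothetical model of that rule, not the transformers implementation. It shows two consequences for this pipeline. First, both spacing variants have to be encoded. Second, the prefix tokens of a multi-token banned word stay legal, so the model can still steer into a different word that shares the prefix. The GUARD layer exists to catch that case.

```python
from __future__ import annotations


def banned_next_tokens(generated: list[int], bad_words_ids: list[list[int]]) -> set[int]:
    """Token ids a bad_words_ids-style filter would block at the next decoding step.

    Simplified model: a one-token sequence is always blocked; a longer sequence
    blocks its last token only when its other tokens end the generated output.
    """
    banned: set[int] = set()
    for seq in bad_words_ids:
        prefix, last = seq[:-1], seq[-1]
        if not prefix or generated[-len(prefix):] == prefix:
            banned.add(last)
    return banned
```

With `bad_words_ids=[[5], [7, 8]]`, token 5 is blocked at every step. Token 8 is blocked only directly after a 7, and nothing stops the model from emitting 7 itself.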
  • [ ] Step 4: Run test to verify it passes

Run: cd packages/generation && uv run python -m pytest server/tests/test_governor.py -v
Expected: PASS (async tests pass, tokenizer tests may skip if model not available)

  • [ ] Step 5: Commit
git add packages/generation/server/governor.py packages/generation/server/tests/test_governor.py
git commit -m "feat(governor): rewrite to word-list resolution via Workers API"

Task 7: Rewrite generate-single Route with SSE

Files:
  • Modify: packages/generation/server/routes/generate.py
  • Modify: packages/generation/server/model.py

  • [ ] Step 1: Rewrite the generate-single route

Replace the generate_single function in packages/generation/server/routes/generate.py with an SSE streaming version. The old session-based /generate route and all session/profile imports are removed.

import json
import logging
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from server import model
from server.schemas import (
    GenerateSingleRequest, SingleGenerationResponse,
    ExcludeConstraint, IncludeConstraint, BoundConstraint,
    BoundBoostConstraint, ContrastiveConstraint, BoostCoverage,
    WordViolation, WordComplianceDetail,
)
from server.governor import resolve_constraints, prepare_generation

log = logging.getLogger("phonolex.generation")
router = APIRouter(prefix="/api")


def init(ps=None, ss=None, gc=None):
    """Legacy init — no-op. Kept for compatibility with main.py."""
    pass


async def _generate_sse(req: GenerateSingleRequest):
    """SSE generator: yields status events, then the final response."""
    import re
    from phonolex_governors.checking.checker import check_word
    from phonolex_governors.checking.g2p import G2PCache
    from phonolex_governors.generation.reranker import Reranker

    def emit(msg: str) -> str:
        return f"data: {json.dumps({'status': msg})}\n\n"

    try:
        # --- Resolve constraints to word lists ---
        yield emit("Resolving constraints...")
        resolved = await resolve_constraints(req.constraints)
        yield emit(f"Fetched word lists ({len(resolved)} constraints)")

        # --- Prepare BAN/BOOST ---
        tokenizer = model.get_tokenizer()
        bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)

        # --- Build Reranker ---
        reranker = Reranker(
            tokenizer=tokenizer,
            allow_sets=allow_sets,
            boost_lists=boost_lists,
        ) if (allow_sets or boost_lists) else None

        # --- Build CheckerConfig for GUARD ---
        from phonolex_governors.checking.checker import CheckerConfig, PhonemeExcludeCheck, VocabOnlyCheck
        checks = []
        bound_norms = []
        for c in req.constraints:
            if isinstance(c, ExcludeConstraint):
                checks.append(PhonemeExcludeCheck(excluded=set(c.phonemes)))
            elif isinstance(c, BoundConstraint):
                # For GUARD: build VocabOnlyCheck from complement allow set
                for rc in resolved:
                    if rc.strategy == "complement" and rc.label and c.norm in rc.label:
                        checks.append(VocabOnlyCheck(allowed_words=set(rc.words)))
                        break
                bound_norms.append(c.norm)

        from server.word_norms import get_word_norms, get_vocab_memberships
        norms_data = get_word_norms() or {}
        vocab_data = get_vocab_memberships() or {}
        stop_words = {w for w, m in vocab_data.items() if any('stop' in s for s in m)}
        checker_config = CheckerConfig(
            checks=checks, norm_lookup=norms_data,
            vocab_lookup=vocab_data, stop_words=stop_words,
        ) if checks else None

        # --- Generate with four-layer pipeline ---
        g2p_cache = G2PCache()
        gen_kwargs = dict(
            max_new_tokens=128,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.3,
        )
        if bad_words_ids:
            gen_kwargs["bad_words_ids"] = bad_words_ids

        import asyncio
        import time
        from transformers import LogitsProcessorList

        t0 = time.time()
        best_text = ""
        best_ids = []
        best_violation_count = None  # fewest violations seen across all drafts
        total_retries = 0
        max_retries = 4
        n_drafts = 2
        processors = LogitsProcessorList([reranker]) if reranker is not None else None

        for draft_idx in range(n_drafts):
            draft_bad = list(bad_words_ids)  # copy per draft; retries extend it
            for attempt in range(max_retries + 1):
                yield emit(f"Generating draft {draft_idx + 1} (attempt {attempt + 1})...")

                current_kwargs = {**gen_kwargs}
                if draft_bad:
                    current_kwargs["bad_words_ids"] = draft_bad

                # Run the blocking model call off the event loop so SSE events keep
                # flowing. Assumes generate_single forwards extra kwargs (including
                # bad_words_ids) to model.generate(); see the model.py update.
                gen_ids, text, _ = await asyncio.to_thread(
                    model.generate_single,
                    req.prompt,
                    logits_processor=processors,
                    **current_kwargs,
                )

                yield emit("Checking compliance...")
                words = re.findall(r'[a-zA-Z]+', text)
                violations = []
                if checker_config:
                    for w in words:
                        result = check_word(w, checker_config, g2p_cache)
                        if not result.passed:
                            violations.append(w)

                # Best-of-N: keep the attempt with the fewest violations, so a
                # response is returned even if no attempt is fully compliant
                if best_violation_count is None or len(violations) < best_violation_count:
                    best_text, best_ids = text, gen_ids
                    best_violation_count = len(violations)

                if not violations:
                    yield emit(f"Draft {draft_idx + 1}: compliant")
                    break

                yield emit(f"Draft {draft_idx + 1}: {len(violations)} violations, retrying...")
                # GUARD escalation: ban each violating word (both spacing variants) on retry
                for v_word in violations:
                    for variant in [v_word, " " + v_word]:
                        tids = tokenizer.encode(variant, add_special_tokens=False)
                        if tids and tids not in draft_bad:
                            draft_bad.append(tids)
                total_retries += 1

        gen_time_ms = (time.time() - t0) * 1000

        # --- Compute compliance details ---
        word_violations = []
        word_compliance = []
        all_words = re.findall(r'[a-zA-Z]+', best_text)
        if checker_config:
            for w in all_words:
                result = check_word(w, checker_config, g2p_cache)
                clean = w.strip().lower()
                word_norms_entry = norms_data.get(clean, {})
                relevant_values = {n: word_norms_entry.get(n) for n in bound_norms}
                if not result.passed:
                    details = [v.details for v in result.violations]
                    word_violations.append(WordViolation(word=w, details=details))
                    word_compliance.append(WordComplianceDetail(
                        word=w, passed=False, values=relevant_values, violations=details,
                    ))
                else:
                    word_compliance.append(WordComplianceDetail(
                        word=w, passed=True, values=relevant_values,
                    ))

        # --- Compute boost coverage ---
        boost_coverage = []
        if boost_lists and all_words:
            for bl in boost_lists:
                hit_words = [w for w in all_words if w.lower() in bl.words]
                rate = len(hit_words) / len(all_words) if all_words else 0.0
                boost_coverage.append(BoostCoverage(
                    label=bl.label,
                    target_rate=round(bl.target_rate * 100, 1),
                    actual_rate=round(rate * 100, 1),
                    hit_words=hit_words,
                    total_words=len(all_words),
                ))

        # --- Enrich tokens ---
        gen_tokens = model.enrich_tokens(best_ids, {}, [], tokenizer)

        response = SingleGenerationResponse(
            tokens=gen_tokens, text=best_text,
            gen_time_ms=gen_time_ms,
            compliant=len(word_violations) == 0,
            violation_count=len(word_violations),
            violation_words=[wv.word for wv in word_violations],
            word_violations=word_violations,
            word_compliance=word_compliance,
            boost_coverage=boost_coverage,
        )

        yield f"data: {json.dumps({'result': response.model_dump(mode='json')})}\n\n"

    except Exception as e:
        log.exception("Generation failed")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"


@router.post("/generate-single")
async def generate_single(req: GenerateSingleRequest):
    if not model.is_ready():
        raise HTTPException(503, "Model not ready")
    return StreamingResponse(
        _generate_sse(req),
        media_type="text/event-stream",
    )

Note: This is a structural sketch. The actual implementation will need to pass bad_words_ids through to model.generate_single() and handle the generate_with_checking refactor. The key pattern is: SSE stream wrapping the four-layer pipeline, emitting status events at each stage.
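The wire format used by `emit()` and by the final result and error events is plain Server-Sent Events. Each event is a single `data:` line holding a JSON object, followed by a blank line. Below is a Python sketch of the framing plus an incremental parser equivalent to the frontend reader loop. It can be used to test the endpoint without a browser. `sse_event` and `SSEParser` are hypothetical helpers. The parser only handles single-line `data:` events, which is sufficient here because `json.dumps` without `indent` never emits newlines.

```python
from __future__ import annotations

import json


def sse_event(payload: dict) -> str:
    """Frame one JSON payload as a Server-Sent Events 'data:' event."""
    return f"data: {json.dumps(payload)}\n\n"


class SSEParser:
    """Incremental parser: buffers a partial line across network chunks."""

    def __init__(self) -> None:
        self.buffer = ""

    def feed(self, chunk: str) -> list[dict]:
        self.buffer += chunk
        # Every element but the last is a complete line; keep the tail for later
        *lines, self.buffer = self.buffer.split("\n")
        events = []
        for line in lines:
            if not line.startswith("data: "):
                continue  # blank separators, comments, other SSE fields
            payload = line[6:].strip()
            if not payload:
                continue
            try:
                events.append(json.loads(payload))
            except json.JSONDecodeError:
                continue  # skip malformed events, as the frontend does
        return events
```

Feeding the stream in arbitrary slices, including slices that cut through `data: `, still yields the same ordered list of events.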

  • [ ] Step 2: Remove session/profile routes from main.py

Update packages/generation/server/main.py to remove session/profile initialization and routes. Remove imports of ProfileStore, SessionStore, GovernorCache. Remove lookup loading (no longer needed). Keep model loading and word_norms loading.

  • [ ] Step 3: Delete dead server files
rm packages/generation/server/sessions.py
rm packages/generation/server/profiles.py
rm packages/generation/server/routes/sessions.py
rm packages/generation/server/routes/profiles.py
rm packages/generation/server/tests/test_sessions.py
rm packages/generation/server/tests/test_profiles.py
  • [ ] Step 4: Verify server starts

Run: cd packages/generation && PHONOLEX_API_URL=http://localhost:8788 uv run uvicorn server.main:app --host 0.0.0.0 --port 3003
Expected: Server starts without import errors. It may fail on model loading (expected if no GPU) but should not crash on import.

  • [ ] Step 5: Commit
git add -u packages/generation/server/
git commit -m "feat(generation): SSE streaming, word-list pipeline, remove sessions/profiles"

Task 8: Frontend SSE Client

Files:
  • Modify: packages/web/frontend/src/lib/generationApi.ts
  • Modify: packages/web/frontend/src/components/tools/GovernedGenerationTool/index.tsx

  • [ ] Step 1: Rewrite generationApi.ts for SSE
import { useState, useEffect, useRef } from 'react';
import type { Constraint, SingleGenerationResponse, ServerStatus } from '../types/governance';
import { freshRequestId, getRequestId, logError } from './logger';

const GENERATION_API_URL = import.meta.env.VITE_GENERATION_API_URL || 'http://localhost:8000';

export interface GenerationCallbacks {
  onStatus: (message: string) => void;
  onResult: (response: SingleGenerationResponse) => void;
  onError: (error: string) => void;
}

export async function generateContent(
  prompt: string,
  constraints: Constraint[],
  callbacks: GenerationCallbacks,
): Promise<void> {
  const rid = freshRequestId();
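  let res: Response;
  try {
    res = await fetch(`${GENERATION_API_URL}/api/generate-single`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'X-Request-ID': rid,
      },
      body: JSON.stringify({ prompt, constraints }),
    });
  } catch (err) {
    // Network failure or CORS rejection: fetch rejects instead of returning a Response
    callbacks.onError(`Could not reach generation server: ${String(err)}`);
    return;
  }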

  if (!res.ok) {
    const detail = await res.text().catch(() => res.statusText);
    logError('Generation API request failed', {
      method: 'POST', url: '/api/generate-single', status: res.status, detail,
    });
    callbacks.onError(`Generation failed (${res.status}): ${detail}`);
    return;
  }

  const reader = res.body?.getReader();
  if (!reader) {
    callbacks.onError('No response body');
    return;
  }

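  const decoder = new TextDecoder();
  let buffer = '';
  let settled = false; // set once a result or error event has been delivered

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;

    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split('\n');
    buffer = lines.pop() || ''; // keep the trailing partial line for the next chunk

    for (const line of lines) {
      if (!line.startsWith('data: ')) continue;
      const payload = line.slice(6).trim();
      if (!payload) continue;

      let event: { status?: string; result?: SingleGenerationResponse; error?: string };
      try {
        event = JSON.parse(payload);
      } catch {
        continue; // skip malformed events
      }

      // Callbacks run outside the try so UI errors are not swallowed as parse errors
      if (event.status) {
        callbacks.onStatus(event.status);
      } else if (event.result) {
        settled = true;
        callbacks.onResult(event.result);
      } else if (event.error) {
        settled = true;
        callbacks.onError(event.error);
      }
    }
  }

  // The server always finishes with a result or error event; reaching the end
  // without one means the stream was cut off mid-generation.
  if (!settled) {
    callbacks.onError('Generation stream ended without a result');
  }
}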

export async function fetchServerStatus(): Promise<ServerStatus> {
  const res = await fetch(`${GENERATION_API_URL}/api/server/status`, {
    headers: { 'X-Request-ID': getRequestId() },
  });
  if (!res.ok) throw new Error(`Server status check failed (${res.status})`);
  return res.json();
}

export function useServerStatus(pollIntervalMs = 5000): ServerStatus | null {
  const [status, setStatus] = useState<ServerStatus | null>(null);
  const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  useEffect(() => {
    let cancelled = false;
    async function poll() {
      try {
        const s = await fetchServerStatus();
        if (!cancelled) setStatus(s);
      } catch {
        if (!cancelled) setStatus(null);
      }
    }
    poll();
    intervalRef.current = setInterval(poll, pollIntervalMs);
    return () => {
      cancelled = true;
      if (intervalRef.current) clearInterval(intervalRef.current);
    };
  }, [pollIntervalMs]);

  return status;
}
  • [ ] Step 2: Update GovernedGenerationTool index.tsx

Add a statusMessage state and pass onStatus callback to generateContent. Display the status line below the Generate button:

const [statusMessage, setStatusMessage] = useState<string | null>(null);

// In the generate handler:
setStatusMessage('Starting...');
await generateContent(prompt, compiled, {
  onStatus: (msg) => setStatusMessage(msg),
  onResult: (response) => {
    setStatusMessage(null);
    // ... create GenerationResult from response and add to feed
  },
  onError: (err) => {
    setStatusMessage(null);
    // ... show error
  },
});

Render the status line:

{statusMessage && (
  <Typography variant="body2" color="text.secondary" sx={{ mt: 1, fontStyle: 'italic' }}>
    {statusMessage}
  </Typography>
)}
  • [ ] Step 3: Update OutputCard.tsx for generalized coverage

Replace includeCoverage rendering with boostCoverage — same visual pattern but using BoostCoverage.label instead of hardcoded phoneme display.

  • [ ] Step 4: Run frontend build

Run: cd packages/web/frontend && npm run build
Expected: No TypeScript errors

  • [ ] Step 5: Commit
git add packages/web/frontend/src/lib/generationApi.ts packages/web/frontend/src/components/tools/GovernedGenerationTool/
git commit -m "feat(frontend): SSE streaming, status line, generalized coverage display"

Task 9: Build Bigram Transition Matrix

Files:
  • Create: packages/generation/scripts/build_bigram_matrix.py

  • [ ] Step 1: Write the builder script
"""Build a BPE bigram transition matrix from a text corpus.

Tokenizes a corpus with the T5Gemma tokenizer and counts token bigrams.
Outputs a sparse matrix: for each token, the top-k most likely next tokens
and their conditional probabilities.

Usage:
    python build_bigram_matrix.py --corpus data/corpus.txt --output data/bigram_matrix.npz
"""

import argparse
import logging
import numpy as np
from collections import Counter
from pathlib import Path
from transformers import AutoTokenizer

log = logging.getLogger(__name__)

MODEL_NAME = "google/t5gemma-9b-2b-ul2-it"
TOP_K = 100  # transitions per token


def build_from_corpus(corpus_path: str, tokenizer, top_k: int = TOP_K) -> dict:
    """Build bigram counts from a text file."""
    log.info("Tokenizing corpus: %s", corpus_path)
    bigram_counts: dict[int, Counter] = {}

    with open(corpus_path, encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            ids = tokenizer.encode(line, add_special_tokens=False)
            for a, b in zip(ids, ids[1:]):
                if a not in bigram_counts:
                    bigram_counts[a] = Counter()
                bigram_counts[a][b] += 1
            if (i + 1) % 100_000 == 0:
                log.info("Processed %d lines", i + 1)

    log.info("Computing conditional probabilities for %d tokens", len(bigram_counts))
    matrix = {}
    for token_id, counts in bigram_counts.items():
        total = sum(counts.values())
        top = counts.most_common(top_k)
        matrix[token_id] = {
            next_id: count / total for next_id, count in top
        }

    return matrix


def save_matrix(matrix: dict, output_path: str) -> None:
    """Save as compressed numpy arrays (sparse representation)."""
    token_ids = []
    next_ids = []
    probs = []

    for tid, transitions in matrix.items():
        for next_tid, prob in transitions.items():
            token_ids.append(tid)
            next_ids.append(next_tid)
            probs.append(prob)

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)  # e.g. create data/
    np.savez_compressed(
        output_path,
        token_ids=np.array(token_ids, dtype=np.int32),
        next_ids=np.array(next_ids, dtype=np.int32),
        probs=np.array(probs, dtype=np.float32),
    )
    log.info("Saved matrix to %s (%d entries)", output_path, len(token_ids))


def main():
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description="Build BPE bigram matrix")
    parser.add_argument("--corpus", type=str, required=True, help="Path to text corpus file")
    parser.add_argument("--output", type=str, default="data/bigram_matrix.npz")
    parser.add_argument("--top-k", type=int, default=TOP_K)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    matrix = build_from_corpus(args.corpus, tokenizer, top_k=args.top_k)
    save_matrix(matrix, args.output)


if __name__ == "__main__":
    main()
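On a toy input, the counting and normalisation steps reduce to a few lines. The sketch below uses plain lists of token ids in place of tokenizer output, and `bigram_triplets` is a hypothetical helper. It shows that P(B|A) is normalised over all observed successors before truncation, so each row sums to at most 1 once only the top-k are kept. It also shows that rows flatten into (token_id, next_id, prob) triplets, matching the three arrays `save_matrix` writes.

```python
from __future__ import annotations

from collections import Counter


def bigram_triplets(sequences: list[list[int]], top_k: int) -> list[tuple[int, int, float]]:
    """Count token bigrams and emit the top-k P(next | token) per row as triplets."""
    counts: dict[int, Counter] = {}
    for ids in sequences:
        # Bigrams are counted within a sequence only, never across line breaks
        for a, b in zip(ids, ids[1:]):
            counts.setdefault(a, Counter())[b] += 1

    triplets: list[tuple[int, int, float]] = []
    for a in sorted(counts):
        total = sum(counts[a].values())  # normalise over ALL successors, pre-truncation
        for b, n in counts[a].most_common(top_k):
            triplets.append((a, b, n / total))
    return triplets
```

With sequences `[[1, 2, 3], [1, 2], [1, 4]]` and `top_k=1`, token 1 keeps only successor 2 with probability 2/3. The dropped 1/3 mass belongs to the rare successor 4.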
  • [ ] Step 2: Add data directory to gitignore

Check that packages/generation/data/ is in .gitignore. If not, add packages/generation/data/*.npz.

  • [ ] Step 3: Commit
git add packages/generation/scripts/build_bigram_matrix.py
git commit -m "feat(generation): bigram transition matrix builder script"

Task 10: Bigram Dead-End Filter (Lookahead Layer 3)

Files:
  • Create: packages/governors/src/phonolex_governors/generation/lookahead.py
  • Create: packages/governors/tests/test_lookahead.py

  • [ ] Step 1: Write the failing tests

Create packages/governors/tests/test_lookahead.py:

import pytest
import numpy as np
import torch
from phonolex_governors.generation.lookahead import BigramDeadEndFilter, load_bigram_matrix


@pytest.fixture
def tiny_matrix(tmp_path):
    """Create a tiny bigram matrix for testing."""
    # Token 10 → Token 11 (p=0.8), Token 12 (p=0.2)
    # Token 13 → Token 14 (p=0.9), Token 15 (p=0.1)
    token_ids = np.array([10, 10, 13, 13], dtype=np.int32)
    next_ids = np.array([11, 12, 14, 15], dtype=np.int32)
    probs = np.array([0.8, 0.2, 0.9, 0.1], dtype=np.float32)
    path = tmp_path / "test_bigram.npz"
    np.savez_compressed(path, token_ids=token_ids, next_ids=next_ids, probs=probs)
    return path


@pytest.fixture
def mock_tokenizer():
    from unittest.mock import MagicMock
    tok = MagicMock()
    tok.decode = MagicMock(side_effect=lambda ids, **kw: {
        (10,): "▁dre", (11,): "am", (12,): "ss",
        (13,): "▁ca", (14,): "t", (15,): "r",
        (10, 11): "dream", (10, 12): "dress",
        (13, 14): "cat", (13, 15): "car",
    }.get(tuple(ids), "unk"))
    return tok


def test_load_bigram_matrix(tiny_matrix):
    matrix = load_bigram_matrix(str(tiny_matrix))
    assert 10 in matrix
    assert 11 in matrix[10]
    assert abs(matrix[10][11] - 0.8) < 0.01


def test_dead_end_score(tiny_matrix, mock_tokenizer):
    matrix = load_bigram_matrix(str(tiny_matrix))
    ban_words = {"dream"}  # Token 10→11 = "dream" is banned
    filt = BigramDeadEndFilter(matrix, mock_tokenizer, ban_words, vocab_size=20)
    # Token 10 should have dead_end_score = P(11|10) = 0.8 (since dream is banned)
    scores = filt.dead_end_scores
    assert scores[10] > 0.5


def test_filter_penalizes_dead_ends(tiny_matrix, mock_tokenizer):
    matrix = load_bigram_matrix(str(tiny_matrix))
    ban_words = {"dream"}
    filt = BigramDeadEndFilter(matrix, mock_tokenizer, ban_words,
                                vocab_size=20, penalty=10.0)
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0
    logits[0, 13] = 5.0
    input_ids = torch.tensor([[0]])
    result = filt(input_ids, logits)
    # Token 10 should be penalized (high dead-end), token 13 should not
    assert result[0, 10].item() < result[0, 13].item()
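The expected values in these tests follow directly from the dead-end formula. With "dream" banned, token 10 scores P(11|10) = 0.8 and token 13 scores 0, so with penalty 10 the logit for token 10 drops from 5.0 to -3.0 while token 13 stays at 5.0. As a sketch only (not part of the planned module), the same scores can be computed straight from the .npz COO arrays with a scatter-add instead of a nested dict. The names dead_end_scores_from_coo and is_banned_pair are hypothetical; is_banned_pair stands in for decoding the token pair with the tokenizer and checking the ban set.

```python
from typing import Callable

import numpy as np


def dead_end_scores_from_coo(
    token_ids: np.ndarray,
    next_ids: np.ndarray,
    probs: np.ndarray,
    is_banned_pair: Callable[[int, int], bool],
    vocab_size: int,
) -> np.ndarray:
    """dead_end[A] = sum of P(B|A) over entries where A+B forms a banned word."""
    banned = np.fromiter(
        (is_banned_pair(int(a), int(b)) for a, b in zip(token_ids, next_ids)),
        dtype=bool,
        count=len(token_ids),
    )
    # Keep only banned pairs whose ids fall inside the model vocabulary.
    keep = banned & (token_ids < vocab_size) & (next_ids < vocab_size)
    scores = np.zeros(vocab_size, dtype=np.float32)
    # np.add.at accumulates correctly when the same token id repeats.
    np.add.at(scores, token_ids[keep], probs[keep])
    return scores


# Same data as the tiny_matrix fixture: 10 -> {11: 0.8, 12: 0.2}, 13 -> {14: 0.9, 15: 0.1}
token_ids = np.array([10, 10, 13, 13], dtype=np.int32)
next_ids = np.array([11, 12, 14, 15], dtype=np.int32)
probs = np.array([0.8, 0.2, 0.9, 0.1], dtype=np.float32)
pair_to_word = {(10, 11): "dream", (10, 12): "dress", (13, 14): "cat", (13, 15): "car"}
ban = {"dream"}

scores = dead_end_scores_from_coo(
    token_ids, next_ids, probs,
    lambda a, b: pair_to_word.get((a, b), "") in ban,
    vocab_size=20,
)
logits = np.zeros(20, dtype=np.float32)
logits[[10, 13]] = 5.0
penalized = logits - 10.0 * scores  # logits[A] -= penalty * dead_end[A]
print(scores[10], scores[13], penalized[10], penalized[13])
```

This skips building the dict-of-dicts entirely; the per-pair decode remains the dominant cost either way.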
  • [ ] Step 2: Run test to verify it fails

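Run: cd packages/governors && uv run python -m pytest tests/test_lookahead.py -v
Expected: FAIL at collection with an import error, because phonolex_governors.generation.lookahead (and therefore BigramDeadEndFilter and load_bigram_matrix) does not exist yet.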

  • [ ] Step 3: Implement lookahead.py

Create packages/governors/src/phonolex_governors/generation/lookahead.py:

"""Lookahead mechanisms for constrained generation escalation.

Layer 3: BigramDeadEndFilter — precomputed P(B|A) scored against ban lists.
Layer 4: TargetedRolloutProcessor — short greedy rollouts for context-aware checking.

Based on:
- FUDGE (Yang & Klein, 2021) — prefix-level constraint prediction
- NeuroLogic A*esque Decoding (Lu et al., 2022) — lookahead heuristics
"""

from __future__ import annotations

import logging
from typing import Optional

import numpy as np
import torch
from transformers import LogitsProcessor

log = logging.getLogger(__name__)


def load_bigram_matrix(path: str) -> dict[int, dict[int, float]]:
    """Load a sparse bigram matrix from an .npz file.

    The file stores three parallel arrays (COO layout): token_ids[i] -> next_ids[i]
    with probability probs[i], i.e. P(next | token).
    """
    # The context manager closes the underlying zip file handle.
    with np.load(path) as data:
        token_ids = data["token_ids"].tolist()
        next_ids = data["next_ids"].tolist()
        probs = data["probs"].tolist()

    matrix: dict[int, dict[int, float]] = {}
    for tid, nid, p in zip(token_ids, next_ids, probs):
        matrix.setdefault(tid, {})[nid] = p
    return matrix


class BigramDeadEndFilter(LogitsProcessor):
    """Penalize tokens whose probable continuations form banned words.

    Computes dead_end_score[A] = Σ P(B|A) for all B where decode(A+B) ∈ ban_set.
    Applied as a penalty vector: logits[A] -= penalty * dead_end_score[A].

    Based on FUDGE (Yang & Klein, 2021).
    """

    def __init__(
        self,
        bigram_matrix: dict[int, dict[int, float]],
        tokenizer,
        ban_words: set[str],
        vocab_size: int,
        penalty: float = 10.0,
    ):
        self.penalty = penalty
        self.dead_end_scores = self._compute_scores(
            bigram_matrix, tokenizer, ban_words, vocab_size,
        )

    def _compute_scores(
        self,
        matrix: dict[int, dict[int, float]],
        tokenizer,
        ban_words: set[str],
        vocab_size: int,
    ) -> torch.Tensor:
        """Precompute per-token dead-end scores."""
        scores = torch.zeros(vocab_size)
        ban_lower = {w.lower() for w in ban_words}

        for tid, transitions in matrix.items():
            if tid >= vocab_size:
                continue
            dead_prob = 0.0
            for next_tid, prob in transitions.items():
                if next_tid >= vocab_size:
                    continue
                # Decode the two-token sequence to see what word it forms.
                # Only words spelled by exactly A+B are caught; longer words
                # that merely start with A+B are not scored here.
                word = tokenizer.decode([tid, next_tid], skip_special_tokens=True).strip().lower()
                if word in ban_lower:
                    dead_prob += prob
            scores[tid] = dead_prob

        nonzero = (scores > 0).sum().item()
        log.info("Dead-end filter: %d/%d tokens have nonzero scores", nonzero, vocab_size)
        return scores

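    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Match the logits' device and dtype (e.g. fp16 on GPU).
        penalty_vec = self.dead_end_scores.to(device=scores.device, dtype=scores.dtype)
        v = min(scores.shape[-1], penalty_vec.shape[0])
        # Penalize every batch row, not just row 0 (Best-of-N / num_return_sequences > 1).
        scores[:, :v] -= self.penalty * penalty_vec[:v]
        return scores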


class TargetedRolloutProcessor(LogitsProcessor):
    """Context-aware lookahead via short greedy rollouts.

    For each top-k candidate, runs `rollout_depth` greedy forward passes and
    checks whether the resulting words hit the ban set (or fall outside the
    allow set). Penalizes candidates whose rollouts produce violations.

    Cost: up to top_k * rollout_depth uncached forward passes per generated
    token and batch row (150 with the defaults), which is why this is the
    last escalation layer.

    Based on NeuroLogic A*esque Decoding (Lu et al., 2022).
    """

    def __init__(
        self,
        model,
        tokenizer,
        ban_words: set[str],
        allow_words: Optional[set[str]] = None,
        rollout_depth: int = 3,
        top_k: int = 50,
        penalty: float = 50.0,
        context_tokens: int = 8,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.ban_words = {w.lower() for w in ban_words}
        self.allow_words = {w.lower() for w in allow_words} if allow_words else None
        self.rollout_depth = rollout_depth
        self.top_k = top_k
        self.penalty = penalty
        # Trailing context tokens decoded together with each rollout, so a
        # candidate that completes a word already in progress ("▁dre" + "am")
        # is checked as the whole word. 8 is an illustrative default, not tuned.
        self.context_tokens = context_tokens

    def _is_violation(self, word: str) -> bool:
        w = word.lower().strip()
        if not w or not w.isalpha() or len(w) < 2:
            return False
        if w in self.ban_words:
            return True
        if self.allow_words is not None and w not in self.allow_words:
            return True
        return False

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        import re

        k = min(self.top_k, scores.shape[-1])
        for row in range(scores.shape[0]):  # every batch row, not just row 0
            prefix = input_ids[row]
            tail = prefix[-self.context_tokens:].tolist() if self.context_tokens > 0 else []
            tail_len = len(self.tokenizer.decode(tail, skip_special_tokens=True))

            for tid in torch.topk(scores[row], k).indices.tolist():
                # Greedy rollout from prefix + candidate
                rollout_ids = [tid]
                current = torch.cat([prefix, prefix.new_tensor([tid])]).unsqueeze(0)
                with torch.no_grad():
                    for _ in range(self.rollout_depth):
                        next_logits = self.model(current, use_cache=False).logits[0, -1, :]
                        next_id = next_logits.argmax().item()
                        rollout_ids.append(next_id)
                        current = torch.cat([current, current.new_tensor([[next_id]])], dim=1)

                # Decode context tail + rollout and check only words that reach
                # into the rollout. Assumes decode(tail) is a prefix of
                # decode(tail + rollout), which usually holds for subword tokenizers.
                # The last rollout word may be truncated, which can cause false
                # positives when allow_words is set.
                text = self.tokenizer.decode(tail + rollout_ids, skip_special_tokens=True)
                new_words = [
                    m.group() for m in re.finditer(r"[a-zA-Z]+", text) if m.end() > tail_len
                ]
                if any(self._is_violation(w) for w in new_words):
                    scores[row, tid] -= self.penalty

        return scores
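The rollout processor runs its forward passes one candidate at a time. Because all k candidate sequences share the same prefix and have the same length, they can be stacked into a single batch with no padding, turning top_k * rollout_depth passes into rollout_depth batched passes. A minimal sketch, assuming a Hugging Face style causal LM whose call returns an object with a .logits tensor; batched_greedy_rollouts and the ToyModel stand-in are hypothetical and exist only for illustration.

```python
from types import SimpleNamespace

import torch
import torch.nn.functional as F


def batched_greedy_rollouts(
    model,
    prefix: torch.LongTensor,      # [seq_len] tokens so far (one batch row)
    candidates: torch.LongTensor,  # [k] candidate next tokens (e.g. top-k ids)
    depth: int,
) -> torch.LongTensor:
    """Return [k, 1 + depth]: each candidate followed by its greedy continuation."""
    k = candidates.shape[0]
    # All rows share the prefix and length, so no padding or attention mask is needed.
    seqs = torch.cat([prefix.unsqueeze(0).expand(k, -1), candidates.unsqueeze(1)], dim=1)
    start = prefix.shape[0]
    with torch.no_grad():
        for _ in range(depth):
            # One forward pass advances all k rollouts by one token.
            logits = model(seqs, use_cache=False).logits[:, -1, :]
            next_ids = logits.argmax(dim=-1, keepdim=True)
            seqs = torch.cat([seqs, next_ids], dim=1)
    return seqs[:, start:]


class ToyModel:
    """Stand-in LM for illustration: always predicts (last token + 1) mod vocab."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size

    def __call__(self, input_ids, use_cache=False):
        logits = F.one_hot((input_ids + 1) % self.vocab_size, self.vocab_size).float()
        return SimpleNamespace(logits=logits)


rollouts = batched_greedy_rollouts(
    ToyModel(10), prefix=torch.tensor([0, 1]), candidates=torch.tensor([3, 7]), depth=2,
)
print(rollouts.tolist())
```

Each row can then be decoded and checked with _is_violation as in the per-candidate loop. The trade-off is peak memory: the batch holds k full sequences at once, so top_k may need to shrink for long prompts.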
  • [ ] Step 4: Run test to verify it passes

Run: cd packages/governors && uv run python -m pytest tests/test_lookahead.py -v
Expected: PASS (all three tests)

  • [ ] Step 5: Commit
git add packages/governors/src/phonolex_governors/generation/lookahead.py packages/governors/tests/test_lookahead.py
git commit -m "feat(lookahead): bigram dead-end filter + targeted rollout processors"

Task 11: Integration — Wire Lookahead into Generation Pipeline

Files:
  • Modify: packages/generation/server/routes/generate.py
  • Modify: packages/generation/server/model.py

  • [ ] Step 1: Add lookahead escalation to the SSE pipeline

In the generation loop in routes/generate.py, after the retry budget is exhausted, add escalation:

# At the top of routes/generate.py:
from pathlib import Path

# After max_retries reached for a draft:
if attempt == max_retries and violations:
    # Collect direct BAN words once; both escalation layers use them.
    all_ban_words: set[str] = set()
    for rc in resolved:
        if rc.mode == "ban" and rc.strategy == "direct":
            all_ban_words.update(rc.words)
    still_failing = True  # update from the re-checked draft after each layer

    # --- Layer 3: Bigram dead-end filter ---
    from phonolex_governors.generation.lookahead import BigramDeadEndFilter, load_bigram_matrix
    bigram_path = Path(__file__).resolve().parents[2] / "data" / "bigram_matrix.npz"
    if bigram_path.exists():
        yield emit("Activating bigram dead-end filter...")
        matrix = load_bigram_matrix(str(bigram_path))
        dead_end_filter = BigramDeadEndFilter(
            matrix, tokenizer, all_ban_words, vocab_size=model.get_vocab_size(),
        )
        # Add to processors and generate one more draft
        yield emit("Generating with lookahead...")
        # ... generate with dead_end_filter added to LogitsProcessorList,
        #     then set still_failing from the new draft's violations

    # --- Layer 4: Targeted rollout (if still failing) ---
    if still_failing and model_ref is not None:
        yield emit("Activating targeted rollout...")
        from phonolex_governors.generation.lookahead import TargetedRolloutProcessor
        rollout = TargetedRolloutProcessor(
            model=model_ref, tokenizer=tokenizer,
            ban_words=all_ban_words,
        )
        # ... generate with rollout added to LogitsProcessorList

The exact wiring depends on how model.generate_single() is refactored to accept additional LogitsProcessors and bad_words_ids. The key pattern: escalation layers are lazy — only instantiated when earlier layers fail.
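The lazy-escalation pattern can be sketched independently of the server code. This assumes generate takes the list of extra logits processors and returns draft text, and check returns the violating words; escalate, fake_generate and fake_check are hypothetical names. It also assumes layers stack cumulatively (Layer 4 keeps Layer 3's filter), which the plan does not state either way.

```python
from typing import Callable, Sequence


def escalate(
    generate: Callable[[list], str],
    check: Callable[[str], list[str]],
    layers: Sequence[tuple[str, Callable[[], object]]],
) -> tuple[str, list[str], list[str]]:
    """Generate a draft, then add escalation layers one at a time until it passes.

    Each layer is a (name, factory) pair. A factory runs only when the previous
    attempt still has violations, so expensive processors are built lazily.
    Processors accumulate: later layers run on top of earlier ones.
    """
    processors: list = []
    used: list[str] = []
    text = generate(processors)
    violations = check(text)
    for name, factory in layers:
        if not violations:
            break
        processors.append(factory())
        used.append(name)
        text = generate(processors)
        violations = check(text)
    return text, violations, used


built: list[str] = []  # records which factories actually ran


def make(name: str) -> Callable[[], str]:
    def factory() -> str:
        built.append(name)
        return name
    return factory


def fake_generate(procs: list) -> str:
    return "a clean draft" if "dead_end" in procs else "a dream draft"


def fake_check(text: str) -> list[str]:
    return [w for w in text.split() if w == "dream"]


text, violations, used = escalate(
    fake_generate, fake_check,
    [("dead_end", make("dead_end")), ("rollout", make("rollout"))],
)
print(text, violations, used, built)
```

Here the dead-end layer clears the violation, so the rollout factory is never called; in the server, that is the difference between one cheap penalty vector and hundreds of extra forward passes per token.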

  • [ ] Step 2: Update model.py to accept bad_words_ids

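Modify generate_single() to accept an optional bad_words_ids parameter and add it to gen_kwargs when it is non-empty:

from transformers import LogitsProcessorList

def generate_single(
    prompt: str,
    logits_processor: LogitsProcessorList | None = None,
    bad_words_ids: list[list[int]] | None = None,
) -> tuple[list[int], str, float]:
    ...  # existing body builds gen_kwargs
    # Only forward non-empty lists; transformers rejects an empty bad_words_ids.
    if bad_words_ids:
        gen_kwargs["bad_words_ids"] = bad_words_ids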
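Something still has to turn a resolved BAN word list into bad_words_ids. A minimal sketch, assuming the resolved words are plain strings: tokenize each word lowercase and capitalized, with and without a leading space, because subword tokenizers typically assign different ids to "cat", " cat" and " Cat". The helper name words_to_bad_words_ids is hypothetical; with a Hugging Face tokenizer, encode could be lambda s: tokenizer.encode(s, add_special_tokens=False).

```python
from typing import Callable, Iterable


def words_to_bad_words_ids(
    words: Iterable[str],
    encode: Callable[[str], list[int]],
) -> list[list[int]]:
    """Expand each banned word into surface variants and tokenize each one.

    Variants: lowercase and capitalized, each with and without a leading space.
    Duplicate token sequences are removed; empty encodings are skipped.
    """
    seqs: set[tuple[int, ...]] = set()
    for word in words:
        for form in {word.lower(), word.capitalize()}:
            for surface in (form, " " + form):
                ids = encode(surface)
                if ids:
                    seqs.add(tuple(ids))
    # Sorted for deterministic output (helps caching and tests).
    return [list(s) for s in sorted(seqs)]


# Toy encoder for illustration only: one "token" per character.
def toy_encode(s: str) -> list[int]:
    return [ord(c) for c in s]


bad_ids = words_to_bad_words_ids(["cat"], toy_encode)
print(len(bad_ids), bad_ids[0])
```

For a multi-token sequence, bad_words_ids bans the final token only when the preceding tokens match exactly, so a word that tokenizes differently in context can still slip through; that is the gap the later escalation layers are meant to cover.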

  • [ ] Step 3: Test the full pipeline manually

Start all three servers (Workers API, frontend, generation server). Set a phoneme exclusion constraint in the UI, generate, and verify that:

  1. Status messages appear below the Generate button.
  2. Compliance highlighting works on the output.
  3. Word-list resolution happens (check the generation server logs).
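To inspect the status stream without the UI, the raw SSE output of the generate-single endpoint can be captured (for example with curl -N, which disables output buffering) and parsed. A minimal parser sketch following the SSE event-stream rules: "event:" sets the type, "data:" lines accumulate, ":" lines are comments, and a blank line dispatches the event. The event names and payloads that emit() produces are not specified in this plan, so the sample lines below are hypothetical.

```python
from typing import Iterable, Iterator


def iter_sse_events(lines: Iterable[str]) -> Iterator[tuple[str, str]]:
    """Yield (event_type, data) pairs from decoded SSE lines.

    Multiple data lines are joined with newlines; the event type defaults to
    "message". An incomplete event at end of stream is discarded, as in the
    SSE specification.
    """
    event, data = "message", []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if not line:  # blank line: dispatch the pending event, if any
            if data:
                yield event, "\n".join(data)
            event, data = "message", []
            continue
        if line.startswith(":"):  # comment / keep-alive
            continue
        field, _, value = line.partition(":")
        if value.startswith(" "):
            value = value[1:]
        if field == "event":
            event = value
        elif field == "data":
            data.append(value)


sample = [
    ": keep-alive\n",
    "event: status\n",
    "data: Activating bigram dead-end filter...\n",
    "\n",
    "data: done\n",
    "\n",
]
events = list(iter_sse_events(sample))
print(events)
```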

  • [ ] Step 4: Commit
git add packages/generation/server/routes/generate.py packages/generation/server/model.py
git commit -m "feat(generation): wire lookahead escalation into SSE pipeline"

Task 12: Update main.py and Clean Up Dead Code

Files:
  • Modify: packages/generation/server/main.py
  • Delete: dead test files
  • Modify: packages/generation/server/tests/conftest.py

  • [ ] Step 1: Simplify main.py

Remove GovernorCache initialization, lookup loading, and session/profile store setup. Keep model loading and word_norms loading. Update route initialization.

  • [ ] Step 2: Update conftest.py

Remove the lookup and small_lookup fixtures (no longer needed — no token-level governor tests). Add any new fixtures needed for the word-list-based tests.

  • [ ] Step 3: Delete dead test files
rm packages/generation/server/tests/test_enrichment.py  # if it tests old compliance
rm packages/generation/server/tests/test_compliance.py   # if it tests old checker path

Review each before deleting — keep any tests that test GUARD-relevant behavior.

  • [ ] Step 4: Run full test suite

Run: cd packages/generation && uv run python -m pytest server/tests/ -v
Expected: all remaining tests pass

  • [ ] Step 5: Commit
git add -u packages/generation/
git commit -m "chore: clean up dead code — remove sessions, profiles, lookup fixtures"

Task 13: Frontend Build Verification

Files:
  • Modify: various frontend files as needed to fix type errors

  • [ ] Step 1: Run TypeScript type check

Run: cd packages/web/frontend && npx tsc --noEmit
Expected: no errors. If the governance.ts type changes leave removed types still referenced elsewhere, fix those references.

  • [ ] Step 2: Run frontend tests

Run: cd packages/web/frontend && npx vitest run
Expected: all tests pass

  • [ ] Step 3: Run Workers API tests

Run: cd packages/web/workers && npm test
Expected: all tests pass, including the new word-list tests

  • [ ] Step 4: Commit any fixes
git add -u packages/web/
git commit -m "fix: resolve type errors from governance.ts simplification"