Unified Constrained Generation Implementation Plan
For agentic workers: REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (- [ ]) syntax for tracking.
Goal: Replace the token-level governor architecture with a word-list-based BAN/BOOST pipeline, adding bigram lookahead and targeted rollout escalation.
Architecture: Everything is a word list. Constraints resolve via PhonoLex Workers API to word lists, then apply as BAN (bad_words_ids) or BOOST (Reranker coverage modulation). Four escalation layers: BAN/BOOST → GUARD → bigram dead-end filter → targeted rollout. Best-of-N wraps everything.
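Best-of-N, the outermost wrapper, can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the planned implementation: `generate_candidate` stands in for one full pass through the four layers, and the ranking rule (fewest banned-word occurrences first, then the smallest summed shortfall against boost targets) is a hypothetical scoring choice.

```python
import re
from dataclasses import dataclass
from typing import Callable

WORD_RE = re.compile(r"[a-zA-Z]+")


@dataclass
class BoostTarget:
    """Hypothetical stand-in for a resolved BOOST list."""
    words: set[str]
    target_rate: float


def score(text: str, banned: set[str], boosts: list[BoostTarget]) -> tuple[int, float]:
    """Lower is better: (banned-word occurrences, summed coverage shortfall)."""
    tokens = [w.lower() for w in WORD_RE.findall(text)]
    total = max(len(tokens), 1)
    violations = sum(1 for w in tokens if w in banned)
    shortfall = 0.0
    for b in boosts:
        rate = sum(1 for w in tokens if w in b.words) / total
        shortfall += max(0.0, b.target_rate - rate)
    return violations, shortfall


def best_of_n(generate_candidate: Callable[[int], str], n: int,
              banned: set[str], boosts: list[BoostTarget]) -> str:
    # Each seed yields one complete candidate; keep the lowest-scoring one.
    candidates = [generate_candidate(seed) for seed in range(n)]
    return min(candidates, key=lambda t: score(t, banned, boosts))


# Usage with canned candidates instead of a model
canned = ["the red car ran", "the cat sat", "a kit and a cat"]
best = best_of_n(lambda seed: canned[seed], n=3, banned={"car", "red"},
                 boosts=[BoostTarget(words={"cat", "kit"}, target_rate=0.4)])
print(best)  # a kit and a cat
```

Ranking by a tuple makes a ban violation outweigh any amount of boost shortfall, which matches treating BAN as hard and BOOST as soft.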
Tech Stack: Python (FastAPI, PyTorch, transformers), TypeScript (Hono/D1, React/MUI), SSE streaming
Spec: docs/superpowers/specs/2026-04-15-unified-constrained-generation-design.md
File Map
Workers API (new endpoint)
- Modify: packages/web/workers/src/routes/words.ts (add POST /word-list)
- Test: packages/web/workers/tests/word-list.test.ts
Generation Server (major rewrite)
- Modify: packages/generation/server/schemas.py (collapse 11 → 5 constraint types + ResolvedConstraint)
- Rewrite: packages/generation/server/governor.py (resolve_constraints() + prepare_generation())
- Rewrite: packages/generation/server/model.py (four-layer generate_with_checking())
- Modify: packages/generation/server/routes/generate.py (SSE streaming for generate-single)
- Modify: packages/generation/server/word_norms.py (lazy loading, remove startup precompute)
- Modify: packages/generation/server/main.py (remove GovernorCache init, remove lookup loading)
- Delete: packages/generation/server/routes/sessions.py (session route uses old governor)
- Delete: packages/generation/server/sessions.py (session store)
- Delete: packages/generation/server/profiles.py (profile store)
- Delete: packages/generation/server/routes/profiles.py (profile routes)
- Rewrite: packages/generation/server/tests/test_governor.py (test resolve + prepare)
- Modify: packages/generation/server/tests/test_schemas.py (test new 5-type schema)
- Modify: packages/generation/server/tests/conftest.py (remove lookup fixture dependency)
Governor Package (strip to A-team)
- Rewrite: packages/governors/src/phonolex_governors/__init__.py (export only A-team)
- Rewrite: packages/governors/src/phonolex_governors/generation/reranker.py (set membership + coverage)
- Create: packages/governors/src/phonolex_governors/generation/lookahead.py (bigram filter + rollout)
- Keep: packages/governors/src/phonolex_governors/checking/checker.py (unchanged)
- Keep: packages/governors/src/phonolex_governors/checking/g2p.py (unchanged)
- Keep: packages/governors/src/phonolex_governors/checking/phonology.py (unchanged)
- Delete: packages/governors/src/phonolex_governors/core.py
- Delete: packages/governors/src/phonolex_governors/gates.py
- Delete: packages/governors/src/phonolex_governors/boosts.py (module)
- Delete: packages/governors/src/phonolex_governors/boosts/ (directory)
- Delete: packages/governors/src/phonolex_governors/cdd.py
- Delete: packages/governors/src/phonolex_governors/constraints.py
- Delete: packages/governors/src/phonolex_governors/include.py
- Delete: packages/governors/src/phonolex_governors/lookups.py
- Delete: packages/governors/src/phonolex_governors/thematic.py
- Delete: packages/governors/src/phonolex_governors/constraint_compiler.py
- Delete: packages/governors/src/phonolex_governors/constraint_types.py
- Delete: packages/governors/src/phonolex_governors/generation/backtrack.py
- Delete: packages/governors/src/phonolex_governors/generation/loop.py
- Delete: packages/governors/src/phonolex_governors/generation/sampling.py
- Create: packages/governors/tests/test_reranker_v2.py
- Create: packages/governors/tests/test_lookahead.py
Bigram Matrix Builder
- Create: packages/generation/scripts/build_bigram_matrix.py
- Artifact: packages/generation/data/bigram_matrix.npz (gitignored)
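The Architecture section names a bigram dead-end filter but this plan chunk does not define it, so the sketch below is one plausible reading, not the planned lookahead.py: a candidate word is a dead end if every successor observed in bigram counts is banned. The dictionary-of-Counters representation, the function names, and the `min_count` threshold are all hypothetical; the real artifact is a `.npz` matrix.

```python
from collections import Counter, defaultdict


def build_bigrams(sentences: list[list[str]]) -> dict[str, Counter]:
    """Count word -> next-word transitions (lowercased)."""
    successors: dict[str, Counter] = defaultdict(Counter)
    for sent in sentences:
        for a, b in zip(sent, sent[1:]):
            successors[a.lower()][b.lower()] += 1
    return successors


def is_dead_end(word: str, successors: dict[str, Counter],
                banned: set[str], min_count: int = 1) -> bool:
    """True if every sufficiently frequent successor of `word` is banned."""
    nexts = successors.get(word.lower())  # .get() does not create entries
    if not nexts:
        return False  # no evidence either way, so do not filter
    viable = [w for w, n in nexts.items() if n >= min_count and w not in banned]
    return not viable


def filter_dead_ends(candidates: list[str], successors: dict[str, Counter],
                     banned: set[str], min_count: int = 1) -> list[str]:
    return [w for w in candidates if not is_dead_end(w, successors, banned, min_count)]


corpus = [["the", "red", "car"], ["the", "cat", "sat"], ["red", "car", "ran"]]
succ = build_bigrams(corpus)
print(filter_dead_ends(["red", "the", "cat"], succ, banned={"car"}))  # ['the', 'cat']
```

Here "red" is dropped because its only observed continuation, "car", is banned; unseen words pass through untouched rather than being filtered on no evidence.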
Frontend
- Modify: packages/web/frontend/src/types/governance.ts (5 constraint types)
- Modify: packages/web/frontend/src/lib/constraintCompiler.ts (handle all 5 types)
- Modify: packages/web/frontend/src/lib/constraintCompiler.test.ts (tests for bound_boost + contrastive)
- Modify: packages/web/frontend/src/lib/generationApi.ts (SSE streaming)
- Modify: packages/web/frontend/src/components/tools/GovernedGenerationTool/index.tsx (status line)
- Modify: packages/web/frontend/src/components/tools/GovernedGenerationTool/OutputCard.tsx (generalized coverage)
Task 1: Workers API Word List Endpoint
Files:
- Modify: packages/web/workers/src/routes/words.ts
- Create: packages/web/workers/tests/word-list.test.ts
- [ ] Step 1: Write the failing test
Create packages/web/workers/tests/word-list.test.ts:
import { describe, it, expect } from 'vitest';
import app from '../src/index';
// Uses the same test D1 fixture as other worker tests
describe('POST /api/words/word-list', () => {
  it('returns words containing a phoneme', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ include_phonemes: ['k'] }),
    });
    expect(res.status).toBe(200);
    const data = await res.json();
    expect(data.words).toBeInstanceOf(Array);
    expect(data.total).toBeGreaterThan(0);
    expect(data.words.length).toBe(data.total);
    // Every returned word should be a string
    expect(data.words.every((w: unknown) => typeof w === 'string')).toBe(true);
  });
  it('returns words matching norm bounds', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ filters: { max_aoa_kuperman: 5 } }),
    });
    expect(res.status).toBe(200);
    const data = await res.json();
    expect(data.words).toBeInstanceOf(Array);
    expect(data.total).toBeGreaterThan(0);
  });
  it('returns empty list for impossible filter', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ filters: { min_aoa_kuperman: 999 } }),
    });
    expect(res.status).toBe(200);
    const data = await res.json();
    expect(data.words).toEqual([]);
    expect(data.total).toBe(0);
  });
  it('returns 400 for empty body', async () => {
    const res = await app.request('/api/words/word-list', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({}),
    });
    expect(res.status).toBe(400);
  });
});
- [ ] Step 2: Run test to verify it fails
Run: cd packages/web/workers && npm test -- tests/word-list.test.ts
Expected: FAIL — route not found (404)
- [ ] Step 3: Implement the endpoint
Add to packages/web/workers/src/routes/words.ts after the existing /batch route:
interface WordListBody {
  include_phonemes?: string[];
  exclude_phonemes?: string[];
  filters?: Record<string, number>;
}
words.post('/word-list', async (c) => {
  const body = await c.req.json<WordListBody>();
  // Keep only numeric min_*/max_* bounds. Validating before the emptiness check
  // means a body whose only filters are malformed gets a 400 instead of the whole lexicon.
  const mergedFilters: Record<string, number | null> = {};
  for (const [key, value] of Object.entries(body.filters ?? {})) {
    if ((key.startsWith('min_') || key.startsWith('max_')) && typeof value === 'number') {
      mergedFilters[key] = value;
    }
  }
  if (!body.include_phonemes?.length && !body.exclude_phonemes?.length
    && Object.keys(mergedFilters).length === 0) {
    return c.json({ error: 'At least one filter required' }, 400);
  }
  const wordsWhere: string[] = ['w.has_phonology = 1'];
  const propsWhere: string[] = [];
  const params: unknown[] = [];
  // Phoneme inclusion: words containing ANY of these phonemes
  if (body.include_phonemes?.length) {
    const phonemeClauses = body.include_phonemes.map((ph) => {
      const normalized = normalizePhoneme(ph);
      params.push(`%|${normalized}|%`);
      return 'w.phonemes_str LIKE ?';
    });
    wordsWhere.push(`(${phonemeClauses.join(' OR ')})`);
  }
  // Phoneme exclusion: words containing NONE of these phonemes
  if (body.exclude_phonemes?.length) {
    for (const ph of body.exclude_phonemes) {
      const normalized = normalizePhoneme(ph);
      wordsWhere.push('w.phonemes_str NOT LIKE ?');
      params.push(`%|${normalized}|%`);
    }
  }
  // Norm filters: params are pushed in the order their clauses appear in the SQL
  // (all words-table clauses first, then all word_properties clauses)
  if (Object.keys(mergedFilters).length > 0) {
    const partitioned = partitionFilterColumns(mergedFilters);
    wordsWhere.push(...partitioned.wordsClauses.map((cl) => prefixWordsColumns(cl)));
    params.push(...partitioned.wordsParams);
    propsWhere.push(...partitioned.propsClauses.map((cl) => prefixPropsColumns(cl)));
    params.push(...partitioned.propsParams);
  }
  const wordsWhereSQL = wordsWhere.join(' AND ');
  const needsPropsJoin = propsWhere.length > 0;
  const fromClause = needsPropsJoin
    ? 'FROM words w INNER JOIN word_properties wp ON w.word = wp.word'
    : 'FROM words w';
  const fullWhere = needsPropsJoin
    ? ` WHERE ${wordsWhereSQL} AND ${propsWhere.join(' AND ')}`
    : ` WHERE ${wordsWhereSQL}`;
  const sql = `SELECT w.word ${fromClause}${fullWhere}`;
  const { results } = await c.env.DB.prepare(sql).bind(...params).all<{ word: string }>();
  const wordList = results.map((r) => r.word);
  return c.json({ words: wordList, total: wordList.length });
});
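Note: prefixWordsColumns and prefixPropsColumns are helper functions already used in the /search route. Confirm they are in scope for this handler: if they are defined inline inside the /search handler rather than at module level, extract them to module scope so both routes share one copy. The handler also calls normalizePhoneme and partitionFilterColumns; confirm those are in scope as well.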
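The `%|<phoneme>|%` patterns only work if phonemes_str stores each phoneme between pipe separators, for example `|k|æ|t|`. That storage format is inferred from the pattern, not confirmed by this plan. The sketch below reproduces the handler's phoneme clauses (norm filters omitted) against a toy in-memory SQLite table to show why the delimiters matter: without them, /t/ would also match /tʃ/. The table contents and helper names are hypothetical.

```python
import sqlite3
from typing import Optional


def build_word_list_query(include: Optional[list[str]] = None,
                          exclude: Optional[list[str]] = None) -> tuple[str, list[str]]:
    """Mirror of the handler's phoneme clauses."""
    where = ["has_phonology = 1"]
    params: list[str] = []
    if include:
        # ANY of the included phonemes: one OR group
        where.append("(" + " OR ".join("phonemes_str LIKE ?" for _ in include) + ")")
        params.extend(f"%|{ph}|%" for ph in include)
    for ph in exclude or []:
        # NONE of the excluded phonemes: one AND-ed clause each
        where.append("phonemes_str NOT LIKE ?")
        params.append(f"%|{ph}|%")
    return "SELECT word FROM words WHERE " + " AND ".join(where) + " ORDER BY word", params


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE words (word TEXT, phonemes_str TEXT, has_phonology INTEGER)")
conn.executemany(
    "INSERT INTO words VALUES (?, ?, ?)",
    [("car", "|k|ɑ|ɹ|", 1), ("cat", "|k|æ|t|", 1),
     ("chin", "|tʃ|ɪ|n|", 1), ("dog", "|d|ɔ|ɡ|", 1)],
)


def run(include: Optional[list[str]] = None, exclude: Optional[list[str]] = None) -> list[str]:
    sql, params = build_word_list_query(include, exclude)
    return [row[0] for row in conn.execute(sql, params)]


print(run(include=["k"], exclude=["ɹ"]))  # ['cat']
print(run(include=["t"]))                 # ['cat'], not 'chin'
```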
- [ ] Step 4: Run test to verify it passes
Run: cd packages/web/workers && npm test -- tests/word-list.test.ts
Expected: PASS
- [ ] Step 5: Commit
git add packages/web/workers/src/routes/words.ts packages/web/workers/tests/word-list.test.ts
git commit -m "feat(api): add POST /api/words/word-list endpoint for generation server"
Task 2: Simplify Schema Types
Files:
- Modify: packages/generation/server/schemas.py
- Modify: packages/generation/server/tests/test_schemas.py
- [ ] Step 1: Write the failing test
Replace content of packages/generation/server/tests/test_schemas.py with tests for the new 5-type schema:
import pytest
from pydantic import ValidationError
from server.schemas import (
    ExcludeConstraint, IncludeConstraint, BoundConstraint,
    BoundBoostConstraint, ContrastiveConstraint,
    ResolvedConstraint,
    GenerateSingleRequest,
)

def test_exclude_constraint():
    c = ExcludeConstraint(type="exclude", phonemes=["ɹ", "ɝ"])
    assert c.phonemes == ["ɹ", "ɝ"]

def test_include_constraint():
    c = IncludeConstraint(type="include", phonemes=["k"], target_rate=0.2)
    assert c.target_rate == 0.2

def test_include_constraint_rejects_invalid_rate():
    with pytest.raises(ValidationError):
        IncludeConstraint(type="include", phonemes=["k"], target_rate=1.5)

def test_bound_constraint_min():
    c = BoundConstraint(type="bound", norm="concreteness", min=3.0)
    assert c.min == 3.0
    assert c.max is None

def test_bound_constraint_max():
    c = BoundConstraint(type="bound", norm="aoa_kuperman", max=5.0)
    assert c.max == 5.0

def test_bound_boost_constraint():
    c = BoundBoostConstraint(
        type="bound_boost", norm="concreteness", min=3.0, coverage_target=0.2,
    )
    assert c.coverage_target == 0.2

def test_contrastive_constraint_minpair():
    c = ContrastiveConstraint(
        type="contrastive", pair_type="minpair",
        phoneme1="s", phoneme2="z", position="initial",
    )
    assert c.pair_type == "minpair"
    assert c.position == "initial"

def test_contrastive_constraint_maxopp():
    c = ContrastiveConstraint(
        type="contrastive", pair_type="maxopp",
        phoneme1="m", phoneme2="s", position="any",
    )
    assert c.pair_type == "maxopp"

def test_contrastive_rejects_invalid_position():
    with pytest.raises(ValidationError):
        ContrastiveConstraint(
            type="contrastive", pair_type="minpair",
            phoneme1="s", phoneme2="z", position="nowhere",
        )

def test_resolved_constraint_ban():
    r = ResolvedConstraint(
        mode="ban", words=["car", "card"], strategy="direct",
        label="exclude /ɹ/",
    )
    assert r.mode == "ban"
    assert r.strategy == "direct"
    assert r.coverage_target is None

def test_resolved_constraint_boost():
    r = ResolvedConstraint(
        mode="boost", words=["cat", "kit"], strategy="direct",
        coverage_target=0.2, label="include /k/ 20%",
    )
    assert r.coverage_target == 0.2

def test_generate_single_request_with_new_types():
    req = GenerateSingleRequest(
        prompt="Write a story",
        constraints=[
            {"type": "exclude", "phonemes": ["ɹ"]},
            {"type": "include", "phonemes": ["k"], "target_rate": 0.2},
            {"type": "bound", "norm": "aoa_kuperman", "max": 5.0},
            {"type": "bound_boost", "norm": "concreteness", "min": 3.0, "coverage_target": 0.15},
            {"type": "contrastive", "pair_type": "minpair", "phoneme1": "s", "phoneme2": "z", "position": "initial"},
        ],
    )
    assert len(req.constraints) == 5
- [ ] Step 2: Run test to verify it fails
Run: cd packages/generation && uv run python -m pytest server/tests/test_schemas.py -v
Expected: FAIL — BoundBoostConstraint, ContrastiveConstraint, ResolvedConstraint not defined
- [ ] Step 3: Rewrite schemas.py
Replace the constraint types in packages/generation/server/schemas.py. Keep Phono, RichToken, GenerateSingleRequest, SingleGenerationResponse, WordViolation, WordComplianceDetail, IncludeCoverage, ServerStatus. Remove the old 11 constraint types and session/profile types. Add:
# Ensure these imports exist at the top of schemas.py
from typing import Literal

from pydantic import BaseModel, field_validator

class ExcludeConstraint(BaseModel):
    type: Literal["exclude"]
    phonemes: list[str]

class IncludeConstraint(BaseModel):
    type: Literal["include"]
    phonemes: list[str]
    target_rate: float = 0.2

    @field_validator("target_rate")
    @classmethod
    def validate_target_rate(cls, v: float) -> float:
        if not (0.0 <= v <= 1.0):
            raise ValueError("target_rate must be between 0.0 and 1.0")
        return v

class BoundConstraint(BaseModel):
    type: Literal["bound"]
    norm: str
    min: float | None = None
    max: float | None = None

class BoundBoostConstraint(BaseModel):
    type: Literal["bound_boost"]
    norm: str
    min: float | None = None
    max: float | None = None
    coverage_target: float = 0.2

    @field_validator("coverage_target")
    @classmethod
    def validate_coverage_target(cls, v: float) -> float:
        if not (0.0 <= v <= 1.0):
            raise ValueError("coverage_target must be between 0.0 and 1.0")
        return v

class ContrastiveConstraint(BaseModel):
    type: Literal["contrastive"]
    pair_type: Literal["minpair", "maxopp"]
    phoneme1: str
    phoneme2: str
    position: Literal["initial", "medial", "final", "any"] = "any"

Constraint = (
    ExcludeConstraint | IncludeConstraint | BoundConstraint
    | BoundBoostConstraint | ContrastiveConstraint
)

class ResolvedConstraint(BaseModel):
    mode: Literal["ban", "boost"]
    words: list[str]
    strategy: Literal["direct", "complement"] = "direct"
    coverage_target: float | None = None
    label: str
Remove: ExcludeClustersConstraint, VocabBoostConstraint, ComplexityConstraint, VocabOnlyConstraint, MSHConstraint, MinPairBoostConstraint, MaxOppositionBoostConstraint, ThematicConstraint, ConstraintProfile, Turn, TurnUser, TurnAssistant, Session, GenerateRequest, GenerateResponse, UserAnalysis, AssistantResponse, SessionCreateRequest, SessionCreateResponse, ProfileCreateRequest, ProfileUpdateRequest, TokenCensus.
Update GenerateSingleRequest to use the new Constraint union. Add a BoostCoverage model for generalized coverage reporting:
class BoostCoverage(BaseModel):
    label: str
    target_rate: float
    actual_rate: float
    hit_words: list[str]
    total_words: int
Update SingleGenerationResponse: replace include_coverage: list[IncludeCoverage] with boost_coverage: list[BoostCoverage]. Either remove IncludeCoverage now or keep it briefly as a backwards-compatible alias until the frontend switches to boost_coverage in Task 4.
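One way to populate BoostCoverage after generation is sketched below. It assumes actual_rate means the fraction of alphabetic word tokens, counting repeats, that fall in the boost list; the plan does not pin that definition down. A plain dataclass stands in for the Pydantic model so the sketch runs standalone, and compute_boost_coverage is a hypothetical helper name.

```python
import re
from dataclasses import dataclass


@dataclass
class BoostCoverage:  # stand-in for the Pydantic model of the same name
    label: str
    target_rate: float
    actual_rate: float
    hit_words: list[str]
    total_words: int


def compute_boost_coverage(text: str, words: set[str], target_rate: float,
                           label: str) -> BoostCoverage:
    tokens = [w.lower() for w in re.findall(r"[a-zA-Z]+", text)]
    hits = [w for w in tokens if w in words]  # occurrences, repeats included
    total = len(tokens)
    return BoostCoverage(
        label=label,
        target_rate=target_rate,
        actual_rate=len(hits) / total if total else 0.0,
        hit_words=sorted(set(hits)),  # distinct hit words, for display
        total_words=total,
    )


cov = compute_boost_coverage("The cat saw a kit and a cat.", {"cat", "kit"}, 0.2, "include /k/")
print(cov.actual_rate, cov.hit_words)  # 0.375 ['cat', 'kit']
```

Counting repeats keeps the reported rate on the same footing as a target rate such as 0.2, which is naturally read as "about one word in five".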
- [ ] Step 4: Run test to verify it passes
Run: cd packages/generation && uv run python -m pytest server/tests/test_schemas.py -v
Expected: PASS
- [ ] Step 5: Commit
git add packages/generation/server/schemas.py packages/generation/server/tests/test_schemas.py
git commit -m "refactor(schemas): collapse 11 constraint types to 5 + ResolvedConstraint"
Task 3: Strip Governor Package to A-Team
Files:
- Delete: 13 files plus the boosts/ directory (see file map above)
- Rewrite: packages/governors/src/phonolex_governors/__init__.py
- Keep: checking/ directory unchanged
- Keep: generation/reranker.py (rewritten in Task 5)
- [ ] Step 1: Delete dead modules
cd packages/governors/src/phonolex_governors
rm -f core.py gates.py cdd.py constraints.py include.py lookups.py thematic.py
rm -f constraint_compiler.py constraint_types.py
rm -f boosts.py # the flat file
rm -rf boosts/ # the directory
rm -f generation/backtrack.py generation/loop.py generation/sampling.py
- [ ] Step 2: Rewrite __init__.py
"""PhonoLex Governors — word-level constraint checking and generation steering."""
from phonolex_governors.checking.checker import (
CheckerConfig,
CheckResult,
PhonemeExcludeCheck,
VocabOnlyCheck,
check_word,
)
from phonolex_governors.checking.g2p import G2PCache, word_to_phonemes
from phonolex_governors.generation.reranker import Reranker
__all__ = [
"CheckerConfig",
"CheckResult",
"PhonemeExcludeCheck",
"VocabOnlyCheck",
"check_word",
"G2PCache",
"word_to_phonemes",
"Reranker",
]
- [ ] Step 3: Update generation/__init__.py
"""Generation steering — reranker and lookahead."""
- [ ] Step 4: Verify the package imports cleanly
Run: cd packages/governors && uv run python -c "from phonolex_governors import check_word, Reranker, G2PCache; print('OK')"
Expected: OK
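Step 4 only proves the package itself imports. Code elsewhere in the repo (for example the generation server, until Task 6 lands) may still import the deleted modules. A stdlib-only scan such as the hypothetical find_stale_imports below can list those call sites; the module set mirrors the deletions in Step 1, and relative imports are not resolved.

```python
import ast
from pathlib import Path

PKG = "phonolex_governors"
REMOVED = {
    "core", "gates", "boosts", "cdd", "constraints", "include", "lookups",
    "thematic", "constraint_compiler", "constraint_types",
    "generation.backtrack", "generation.loop", "generation.sampling",
}


def _is_removed(module: str) -> bool:
    if not module.startswith(PKG + "."):
        return False
    rest = module[len(PKG) + 1:]
    return any(rest == r or rest.startswith(r + ".") for r in REMOVED)


def find_stale_imports(root: Path) -> list[tuple[str, int, str]]:
    """Return (file, line, module) for each absolute import of a removed module."""
    hits = []
    for path in sorted(root.rglob("*.py")):
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                names = [a.name for a in node.names]
            elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
                if _is_removed(node.module):
                    names = [node.module]
                else:
                    # e.g. `from phonolex_governors import cdd` imports a removed submodule
                    names = [f"{node.module}.{a.name}" for a in node.names]
            else:
                continue
            for name in names:
                if _is_removed(name):
                    hits.append((str(path), node.lineno, name))
    return hits


if __name__ == "__main__":
    for file, line, mod in find_stale_imports(Path("packages")):
        print(f"{file}:{line}: {mod}")
```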
- [ ] Step 5: Commit
cd "$(git rev-parse --show-toplevel)"
git add -u packages/governors/
git commit -m "refactor(governors): strip to A-team — checker, g2p, reranker"
Task 4: Frontend — Complete the Constraint Compiler
Files:
- Modify: packages/web/frontend/src/types/governance.ts
- Modify: packages/web/frontend/src/lib/constraintCompiler.ts
- Modify: packages/web/frontend/src/lib/constraintCompiler.test.ts
- [ ] Step 1: Write the failing tests
Add to packages/web/frontend/src/lib/constraintCompiler.test.ts:
it('compiles bound_boost entry', () => {
  const entries: StoreEntry[] = [
    { type: 'bound_boost', norm: 'concreteness', direction: 'min', value: 3, coverageTarget: 15 },
  ];
  const result = compileConstraints(entries);
  expect(result).toEqual([
    { type: 'bound_boost', norm: 'concreteness', min: 3, coverage_target: 0.15 },
  ]);
});
it('compiles contrastive minpair entry', () => {
  const entries: StoreEntry[] = [
    { type: 'contrastive', pairType: 'minpair', phoneme1: 's', phoneme2: 'z', position: 'initial' },
  ];
  const result = compileConstraints(entries);
  expect(result).toEqual([
    { type: 'contrastive', pair_type: 'minpair', phoneme1: 's', phoneme2: 'z', position: 'initial' },
  ]);
});
it('compiles contrastive maxopp entry', () => {
  const entries: StoreEntry[] = [
    { type: 'contrastive', pairType: 'maxopp', phoneme1: 'm', phoneme2: 's', position: 'any' },
  ];
  const result = compileConstraints(entries);
  expect(result).toEqual([
    { type: 'contrastive', pair_type: 'maxopp', phoneme1: 'm', phoneme2: 's', position: 'any' },
  ]);
});
it('compiles all 5 types together', () => {
  const entries: StoreEntry[] = [
    { type: 'exclude', phoneme: 'ɹ' },
    { type: 'include', phoneme: 'k', strength: 2.0, targetRate: 20 },
    { type: 'bound', norm: 'aoa_kuperman', direction: 'max', value: 5 },
    { type: 'bound_boost', norm: 'concreteness', direction: 'min', value: 3, coverageTarget: 15 },
    { type: 'contrastive', pairType: 'minpair', phoneme1: 's', phoneme2: 'z', position: 'initial' },
  ];
  const result = compileConstraints(entries);
  expect(result).toHaveLength(5);
  expect(result.map((c) => c.type)).toEqual([
    'exclude', 'include', 'bound', 'bound_boost', 'contrastive',
  ]);
});
- [ ] Step 2: Run test to verify it fails
Run: cd packages/web/frontend && npx vitest run src/lib/constraintCompiler.test.ts
Expected: FAIL — bound_boost and contrastive entries silently dropped
- [ ] Step 3: Update governance.ts types
Simplify the Constraint union in packages/web/frontend/src/types/governance.ts. Replace the 11 API constraint types with 5:
export interface ExcludeConstraint {
  type: "exclude";
  phonemes: string[];
}
export interface IncludeConstraint {
  type: "include";
  phonemes: string[];
  target_rate?: number;
}
export interface BoundConstraint {
  type: "bound";
  norm: string;
  min?: number;
  max?: number;
}
export interface BoundBoostConstraint {
  type: "bound_boost";
  norm: string;
  min?: number;
  max?: number;
  coverage_target: number;
}
export interface ContrastiveConstraint {
  type: "contrastive";
  pair_type: "minpair" | "maxopp";
  phoneme1: string;
  phoneme2: string;
  position: string;
}
export type Constraint =
  | ExcludeConstraint
  | IncludeConstraint
  | BoundConstraint
  | BoundBoostConstraint
  | ContrastiveConstraint;
Remove the old types: ExcludeClustersConstraint, VocabBoostConstraint, ComplexityConstraint, VocabOnlyConstraint, MSHConstraint, MinPairBoostConstraint, MaxOppositionBoostConstraint, ThematicConstraint.
Add BoostCoverage to replace IncludeCoverage:
export interface BoostCoverage {
  label: string;
  target_rate: number;
  actual_rate: number;
  hit_words: string[];
  total_words: number;
}
Update SingleGenerationResponse to use boost_coverage: BoostCoverage[] instead of include_coverage: IncludeCoverage[].
- [ ] Step 4: Update constraintCompiler.ts
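Add bound_boost and contrastive handling to compileConstraints(), after the existing bound handling and before the function returns, so the output order matches the "compiles all 5 types together" test: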
// --- Bound boost: one constraint per entry ---
const boundBoosts = entries.filter(
  (e): e is Extract<StoreEntry, { type: 'bound_boost' }> => e.type === 'bound_boost',
);
for (const bb of boundBoosts) {
  result.push({
    type: 'bound_boost',
    norm: bb.norm,
    ...(bb.direction === 'min' ? { min: bb.value } : { max: bb.value }),
    coverage_target: bb.coverageTarget / 100,
  });
}
// --- Contrastive: one constraint per entry ---
const contrastives = entries.filter(
  (e): e is Extract<StoreEntry, { type: 'contrastive' }> => e.type === 'contrastive',
);
for (const ct of contrastives) {
  result.push({
    type: 'contrastive',
    pair_type: ct.pairType,
    phoneme1: ct.phoneme1,
    phoneme2: ct.phoneme2,
    position: ct.position,
  });
}
- [ ] Step 5: Run tests to verify they pass
Run: cd packages/web/frontend && npx vitest run src/lib/constraintCompiler.test.ts
Expected: PASS (all tests including new ones)
- [ ] Step 6: Commit
git add packages/web/frontend/src/types/governance.ts packages/web/frontend/src/lib/constraintCompiler.ts packages/web/frontend/src/lib/constraintCompiler.test.ts
git commit -m "feat(frontend): complete constraint compiler — all 5 types wired"
Task 5: Rewrite the Reranker
Files:
- Rewrite: packages/governors/src/phonolex_governors/generation/reranker.py
- Create: packages/governors/tests/test_reranker_v2.py
- [ ] Step 1: Write the failing tests
Create packages/governors/tests/test_reranker_v2.py:
import pytest
import torch
from unittest.mock import MagicMock
from phonolex_governors.generation.reranker import Reranker, BoostList

# LogitsProcessors receive (input_ids, scores). A lone special token stands for
# "nothing generated yet", which keeps the coverage recount a no-op.
NO_CONTEXT = torch.tensor([[0]])

@pytest.fixture
def mock_tokenizer():
    tok = MagicMock()
    tok.all_special_ids = [0, 1, 2]
    # Map token IDs to decoded text
    decode_map = {
        10: "▁cat",
        11: "▁dog",
        12: "▁car",
        13: "▁kit",
        14: "▁the",
        15: "▁run",
    }
    # Accept lists or tensors of any length, like a real tokenizer
    tok.decode = MagicMock(
        side_effect=lambda ids, **kw: "".join(decode_map.get(int(i), "▁unk") for i in ids)
    )
    return tok

def test_complement_ban_penalizes_non_members(mock_tokenizer):
    """Words not in the allow set should be penalized."""
    allow_set = {"cat", "dog", "the"}
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[allow_set],
        boost_lists=[],
        top_k=10,
        penalize=10.0,
    )
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat: allowed
    logits[0, 12] = 5.0  # car: not allowed
    result = reranker(NO_CONTEXT, logits)
    assert result[0, 10].item() == 5.0  # cat untouched
    assert result[0, 12].item() < 0.0  # car penalized

def test_boost_list_increases_target_logits(mock_tokenizer):
    """Words in a boost list should get a positive logit adjustment."""
    boost = BoostList(words={"cat", "kit"}, target_rate=0.3, label="include /k/")
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[],
        boost_lists=[boost],
        top_k=10,
    )
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat: in boost set
    logits[0, 11] = 5.0  # dog: not in boost set
    result = reranker(NO_CONTEXT, logits)
    assert result[0, 10].item() > 5.0  # cat boosted
    assert result[0, 11].item() == 5.0  # dog untouched

def test_boost_self_regulates(mock_tokenizer):
    """When coverage is at target, boost should ease off."""
    boost = BoostList(words={"cat"}, target_rate=0.5, label="test")
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[],
        boost_lists=[boost],
        top_k=10,
    )
    # Simulate: 10 words generated, 5 are "cat" → 50% coverage = at target
    boost.update_coverage(hit=5, total=10)
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat
    result = reranker(NO_CONTEXT, logits)
    # At target: boost should be zero or near-zero
    assert abs(result[0, 10].item() - 5.0) < 1.0

def test_coverage_counts_repeats(mock_tokenizer):
    """Coverage counts occurrences, so three "cat" out of four words is 75%."""
    boost = BoostList(words={"cat"}, target_rate=0.5, label="test")
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[],
        boost_lists=[boost],
        top_k=10,
    )
    reranker(torch.tensor([[10, 10, 10, 11]]), torch.zeros(1, 20))
    assert boost.coverage == 0.75

def test_multiple_boost_lists_compose(mock_tokenizer):
    """Multiple boost lists each apply independently."""
    b1 = BoostList(words={"cat"}, target_rate=0.2, label="boost1")
    b2 = BoostList(words={"cat", "dog"}, target_rate=0.3, label="boost2")
    reranker = Reranker(
        tokenizer=mock_tokenizer,
        allow_sets=[],
        boost_lists=[b1, b2],
        top_k=10,
    )
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0  # cat: in both
    logits[0, 11] = 5.0  # dog: in b2 only
    result = reranker(NO_CONTEXT, logits)
    # cat gets boosted by both lists
    assert result[0, 10].item() > result[0, 11].item()
- [ ] Step 2: Run test to verify it fails
Run: cd packages/governors && uv run python -m pytest tests/test_reranker_v2.py -v
Expected: FAIL — BoostList not defined, Reranker constructor changed
- [ ] Step 3: Rewrite reranker.py
Replace packages/governors/src/phonolex_governors/generation/reranker.py:
"""Word-list reranker: set membership checking + coverage modulation.

Two responsibilities:
1. Complement ban: penalize candidates not in any allow set
2. Boost coverage: modulate logits to converge on target rates for boost word lists

No G2P. All checks are set membership, O(1) per candidate.
"""
from __future__ import annotations

import re
from dataclasses import dataclass, field

import torch
from transformers import LogitsProcessor

WORD_BOUNDARY_CHAR = "\u2581"  # SentencePiece word-start marker
WORD_RE = re.compile(r"[a-zA-Z]+")


@dataclass
class BoostList:
    """A word list with a coverage target."""
    words: set[str]
    target_rate: float
    label: str
    max_boost: float = 5.0
    # Running counts, set via update_coverage(); not constructor arguments
    _hits: int = field(default=0, init=False, repr=False)
    _total: int = field(default=0, init=False, repr=False)

    def update_coverage(self, hit: int, total: int) -> None:
        self._hits = hit
        self._total = total

    @property
    def coverage(self) -> float:
        return self._hits / self._total if self._total > 0 else 0.0

    @property
    def gap(self) -> float:
        return max(0.0, self.target_rate - self.coverage)

    @property
    def boost_strength(self) -> float:
        # Scales linearly with the shortfall, so the boost fades as coverage nears target
        return self.max_boost * self.gap


@dataclass
class RerankerStats:
    calls: int = 0
    tokens_checked: int = 0
    tokens_penalized: int = 0
    tokens_boosted: int = 0


class Reranker(LogitsProcessor):
    """Word-list reranker as a HuggingFace LogitsProcessor.

    Assumes batch size 1: only row 0 of input_ids and scores is read or modified.

    Args:
        tokenizer: HuggingFace tokenizer for decoding token IDs to text.
        allow_sets: List of allowed word sets for complement banning.
            If non-empty, words not in ANY allow set are penalized.
        boost_lists: List of BoostList objects for coverage modulation.
        top_k: Number of top candidates to evaluate per step.
        penalize: Logit penalty for non-allowed words.
    """

    def __init__(
        self,
        tokenizer,
        allow_sets: list[set[str]],
        boost_lists: list[BoostList],
        top_k: int = 200,
        penalize: float = 15.0,
    ):
        self.tokenizer = tokenizer
        self.allow_sets = allow_sets
        self.boost_lists = boost_lists
        self.top_k = top_k
        self.penalize = penalize
        self.stats = RerankerStats()
        # Precompute union of all allow sets for fast membership check
        self._allow_union: set[str] = set()
        for s in allow_sets:
            self._allow_union |= s
        self._has_allow = len(self._allow_union) > 0

    def _get_partial_word(self, input_ids: torch.LongTensor | None) -> str:
        """Decode the trailing partial word from generated tokens so far."""
        if input_ids is None:
            return ""
        ids = input_ids[0].tolist() if input_ids.dim() > 1 else input_ids.tolist()
        if not ids:
            return ""
        special = set(self.tokenizer.all_special_ids)
        partial_ids = []
        start = max(0, len(ids) - 10)  # look back at most 10 tokens
        for i in range(len(ids) - 1, start - 1, -1):
            tid = ids[i]
            if tid in special:
                break
            text = self.tokenizer.decode([tid], skip_special_tokens=False)
            partial_ids.insert(0, tid)
            if text.startswith(WORD_BOUNDARY_CHAR) or text.startswith(" "):
                break  # reached the token that started the current word
        if not partial_ids:
            return ""
        return self.tokenizer.decode(partial_ids, skip_special_tokens=True).strip()

    def _update_boost_coverage(self, input_ids: torch.LongTensor | None) -> None:
        """Recompute per-boost-list coverage from generated tokens so far.

        Counts occurrences, not distinct words: 5 "cat" out of 10 words is 50%.
        """
        if input_ids is None or not self.boost_lists:
            return
        text = self.tokenizer.decode(
            input_ids[0] if input_ids.dim() > 1 else input_ids,
            skip_special_tokens=True,
        )
        words = [w.lower() for w in WORD_RE.findall(text)]
        total = len(words)
        if total < 3:
            return  # too few words for a stable rate; keep previous coverage
        for bl in self.boost_lists:
            hits = sum(1 for w in words if w in bl.words)
            bl.update_coverage(hit=hits, total=total)

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Rerank logits via set membership checks (modifies scores in place)."""
        self.stats.calls += 1
        self._update_boost_coverage(input_ids)
        partial = self._get_partial_word(input_ids)
        k = min(self.top_k, scores.shape[-1])
        _, top_ids = torch.topk(scores[0], k)
        for i in range(k):
            tid = top_ids[i].item()
            text = self.tokenizer.decode([tid], skip_special_tokens=False)
            is_word_start = text.startswith(WORD_BOUNDARY_CHAR) or text.startswith(" ")
            fragment = text.lstrip(WORD_BOUNDARY_CHAR).lstrip(" ").strip()
            if not fragment or not fragment.isalpha():
                continue
            # A word-start token begins a new word; otherwise it extends the partial one
            word = fragment if is_word_start else (partial + fragment if partial else fragment)
            if not word or len(word) < 2:
                continue
            word_lower = word.lower()
            self.stats.tokens_checked += 1
            # Complement ban: penalize words not in allow set
            if self._has_allow and word_lower not in self._allow_union:
                scores[0, tid] -= self.penalize
                self.stats.tokens_penalized += 1
            # Boost: add logits for words in boost lists (proportional to gap)
            for bl in self.boost_lists:
                if word_lower in bl.words and bl.gap > 0:
                    scores[0, tid] += bl.boost_strength
                    self.stats.tokens_boosted += 1
        return scores
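For reference, the per-step adjustment above can be written out for a top-k candidate whose reconstructed word is w with logit z_w. Here A is the union of the allow sets (the penalty applies only when that union is non-empty), λ is penalize, and each boost list ℓ with word set W_ℓ contributes only if w ∈ W_ℓ, using its max_boost β_ℓ, target_rate r_ℓ, and current coverage h_ℓ/n (taken as 0 before the first update, which happens once at least three words have been generated). This restates the code; it adds no behavior.

```latex
z'_w = z_w
  - \lambda \, \mathbb{1}\bigl[A \neq \varnothing \,\wedge\, w \notin A\bigr]
  + \sum_{\ell \,:\, w \in W_\ell} \beta_\ell \max\Bigl(0,\; r_\ell - \frac{h_\ell}{n}\Bigr)
```

As a check against test_boost_list_increases_target_logits: with β = 5.0, r = 0.3, and no coverage yet, "cat" moves from 5.0 to 5.0 + 5.0 × 0.3 = 6.5.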
- [ ] Step 4: Run test to verify it passes
Run: cd packages/governors && uv run python -m pytest tests/test_reranker_v2.py -v
Expected: PASS
- [ ] Step 5: Commit
git add packages/governors/src/phonolex_governors/generation/reranker.py packages/governors/tests/test_reranker_v2.py
git commit -m "feat(reranker): rewrite to set membership + coverage modulation — no G2P"
Task 6: Rewrite governor.py — Constraint Resolution
Files:
- Rewrite: packages/generation/server/governor.py
- Rewrite: packages/generation/server/tests/test_governor.py
- [ ] Step 1: Write the failing tests
Replace packages/generation/server/tests/test_governor.py:
import pytest
from unittest.mock import patch, AsyncMock
from server.governor import resolve_constraints, prepare_generation
from server.schemas import (
    ExcludeConstraint, IncludeConstraint, BoundConstraint,
    BoundBoostConstraint, ContrastiveConstraint, ResolvedConstraint,
)

@pytest.fixture
def mock_api_response():
    """Mock the Workers API word-list response."""
    return {"words": ["cat", "dog", "bat", "hat"], "total": 4}

@pytest.mark.asyncio
async def test_resolve_exclude(mock_api_response):
    constraints = [ExcludeConstraint(type="exclude", phonemes=["ɹ"])]
    with patch("server.governor._call_word_list_api", new_callable=AsyncMock,
               return_value=mock_api_response):
        resolved = await resolve_constraints(constraints)
    assert len(resolved) == 1
    assert resolved[0].mode == "ban"
    assert resolved[0].words == ["cat", "dog", "bat", "hat"]
    assert "ɹ" in resolved[0].label

@pytest.mark.asyncio
async def test_resolve_include(mock_api_response):
    constraints = [IncludeConstraint(type="include", phonemes=["k"], target_rate=0.2)]
    with patch("server.governor._call_word_list_api", new_callable=AsyncMock,
               return_value=mock_api_response):
        resolved = await resolve_constraints(constraints)
    assert len(resolved) == 1
    assert resolved[0].mode == "boost"
    assert resolved[0].coverage_target == 0.2

@pytest.mark.asyncio
async def test_resolve_bound(mock_api_response):
    constraints = [BoundConstraint(type="bound", norm="aoa_kuperman", max=5.0)]
    with patch("server.governor._call_word_list_api", new_callable=AsyncMock,
               return_value=mock_api_response):
        resolved = await resolve_constraints(constraints)
    assert len(resolved) == 1
    assert resolved[0].mode == "ban"
    # Bounds use complement strategy
    assert resolved[0].strategy == "complement"

def test_prepare_generation_ban_direct():
    tokenizer = pytest.importorskip("transformers").AutoTokenizer.from_pretrained(
        "google/t5gemma-9b-2b-ul2-it"
    )
    resolved = [
        ResolvedConstraint(mode="ban", words=["car", "card"], strategy="direct", label="test"),
    ]
    bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)
    assert len(bad_words_ids) > 0  # "car" and "card" tokenized
    assert len(allow_sets) == 0
    assert len(boost_lists) == 0

def test_prepare_generation_ban_complement():
    tokenizer = pytest.importorskip("transformers").AutoTokenizer.from_pretrained(
        "google/t5gemma-9b-2b-ul2-it"
    )
    resolved = [
        ResolvedConstraint(mode="ban", words=["cat", "dog"], strategy="complement", label="test"),
    ]
    bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)
    assert len(bad_words_ids) == 0
    assert len(allow_sets) == 1
    assert "cat" in allow_sets[0]

def test_prepare_generation_boost():
    tokenizer = pytest.importorskip("transformers").AutoTokenizer.from_pretrained(
        "google/t5gemma-9b-2b-ul2-it"
    )
    resolved = [
        ResolvedConstraint(mode="boost", words=["cat", "kit"], strategy="direct",
                           coverage_target=0.2, label="include /k/"),
    ]
    bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)
    assert len(bad_words_ids) == 0
    assert len(allow_sets) == 0
    assert len(boost_lists) == 1
    assert boost_lists[0].target_rate == 0.2
- [ ] Step 2: Run test to verify it fails
Run: cd packages/generation && uv run python -m pytest server/tests/test_governor.py -v
Expected: FAIL — resolve_constraints, prepare_generation, _call_word_list_api not defined
- [ ] Step 3: Rewrite governor.py
Replace packages/generation/server/governor.py:
"""Constraint resolution and generation preparation.

Resolves semantic constraints to word lists via PhonoLex Workers API,
then prepares BAN/BOOST inputs for the generation pipeline.
"""
from __future__ import annotations

import logging
import os

import httpx

from server.schemas import (
    Constraint, ExcludeConstraint, IncludeConstraint,
    BoundConstraint, BoundBoostConstraint, ContrastiveConstraint,
    ResolvedConstraint,
)
from phonolex_governors.generation.reranker import BoostList

log = logging.getLogger("phonolex.governor")

PHONOLEX_API_URL = os.environ.get("PHONOLEX_API_URL", "http://localhost:8788")
CONTRASTIVE_COVERAGE_TARGET = 0.15  # default boost rate for contrastive word lists


async def _call_word_list_api(payload: dict) -> dict:
    """Call the PhonoLex Workers API word-list endpoint."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(f"{PHONOLEX_API_URL}/api/words/word-list", json=payload)
        resp.raise_for_status()
        return resp.json()


async def _call_contrastive_api(endpoint: str, payload: dict) -> list[str]:
    """Call a PhonoLex contrastive endpoint and extract the unique word strings."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(f"{PHONOLEX_API_URL}/api/contrastive/{endpoint}", json=payload)
        resp.raise_for_status()
        data = resp.json()
    words: set[str] = set()
    for pair in data.get("pairs", data.get("items", [])):
        for key in ("word1", "word2"):
            if key in pair:
                entry = pair[key]
                # Entries may be bare strings or {"word": ...} objects.
                words.add(entry["word"] if isinstance(entry, dict) else entry)
    return sorted(words)


def _bound_filters(norm: str, lo: float | None, hi: float | None) -> tuple[dict, str]:
    """Build Workers API filters and a human-readable label for a norm bound."""
    filters: dict[str, float] = {}
    label_parts: list[str] = []
    if lo is not None:
        filters[f"min_{norm}"] = lo
        label_parts.append(f"{norm} ≥ {lo}")
    if hi is not None:
        filters[f"max_{norm}"] = hi
        label_parts.append(f"{norm} ≤ {hi}")
    return filters, ", ".join(label_parts)


async def resolve_constraints(constraints: list[Constraint]) -> list[ResolvedConstraint]:
    """Resolve semantic constraints to word lists via the PhonoLex Workers API.

    Returns exactly one ResolvedConstraint per input constraint, in input order.
    """
    resolved: list[ResolvedConstraint] = []
    for c in constraints:
        if isinstance(c, ExcludeConstraint):
            # Fetch the words that CONTAIN the phonemes; those are the ones to ban.
            data = await _call_word_list_api({"include_phonemes": c.phonemes})
            resolved.append(ResolvedConstraint(
                mode="ban", words=data["words"], strategy="direct",
                label=f"exclude /{', '.join(c.phonemes)}/",
            ))
        elif isinstance(c, IncludeConstraint):
            data = await _call_word_list_api({"include_phonemes": c.phonemes})
            resolved.append(ResolvedConstraint(
                mode="boost", words=data["words"], strategy="direct",
                coverage_target=c.target_rate,
                label=f"include /{', '.join(c.phonemes)}/ {c.target_rate:.0%}",
            ))
        elif isinstance(c, BoundConstraint):
            # Fetch the words that PASS the bound; everything else is banned (complement).
            filters, label = _bound_filters(c.norm, c.min, c.max)
            data = await _call_word_list_api({"filters": filters})
            resolved.append(ResolvedConstraint(
                mode="ban", words=data["words"], strategy="complement", label=label,
            ))
        elif isinstance(c, BoundBoostConstraint):
            filters, label = _bound_filters(c.norm, c.min, c.max)
            data = await _call_word_list_api({"filters": filters})
            resolved.append(ResolvedConstraint(
                mode="boost", words=data["words"], strategy="direct",
                coverage_target=c.coverage_target,
                label=f"boost {label} {c.coverage_target:.0%}",
            ))
        elif isinstance(c, ContrastiveConstraint):
            endpoint = ("minimal-pairs" if c.pair_type == "minpair"
                        else "maximal-opposition/word-lists")
            words = await _call_contrastive_api(endpoint, {
                "phoneme1": c.phoneme1, "phoneme2": c.phoneme2, "position": c.position,
            })
            resolved.append(ResolvedConstraint(
                mode="boost", words=words, strategy="direct",
                coverage_target=CONTRASTIVE_COVERAGE_TARGET,
                label=f"{c.pair_type} /{c.phoneme1}/↔/{c.phoneme2}/ {c.position}",
            ))
        else:
            raise TypeError(f"Unsupported constraint type: {type(c).__name__}")
    return resolved


def prepare_generation(
    resolved: list[ResolvedConstraint],
    tokenizer,
) -> tuple[list[list[int]], list[set[str]], list[BoostList]]:
    """Convert resolved constraints to generation inputs.

    Returns:
        bad_words_ids: Token ID sequences for HuggingFace bad_words_ids
        allow_sets: Word sets for Reranker complement banning
        boost_lists: BoostList objects for Reranker coverage modulation
    """
    bad_words_ids: list[list[int]] = []
    seen: set[tuple[int, ...]] = set()  # O(1) dedup; ban lists can hold thousands of words
    allow_sets: list[set[str]] = []
    boost_lists: list[BoostList] = []
    for rc in resolved:
        if rc.mode == "ban":
            if rc.strategy == "direct":
                # Encode each banned word without and with a leading space: the word
                # at the start of the text and mid-sentence usually tokenize differently.
                for word in rc.words:
                    for variant in (word, " " + word):
                        tids = tokenizer.encode(variant, add_special_tokens=False)
                        if tids and tuple(tids) not in seen:
                            seen.add(tuple(tids))
                            bad_words_ids.append(tids)
            elif rc.strategy == "complement":
                allow_sets.append(set(rc.words))
        elif rc.mode == "boost":
            boost_lists.append(BoostList(
                words=set(rc.words),
                # Only fall back when unset; an explicit 0.0 target is kept as-is.
                target_rate=rc.coverage_target if rc.coverage_target is not None else 0.2,
                label=rc.label,
            ))
    return bad_words_ids, allow_sets, boost_lists
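A toy tokenizer shows why prepare_generation encodes every banned word twice and deduplicates the results. The sketch below assumes a hypothetical `ToyTokenizer` that imitates SentencePiece's `▁` word-boundary marker. `ban_sequences` is also an illustrative name, and real T5Gemma token IDs will differ.

```python
class ToyTokenizer:
    """Hypothetical SentencePiece-like tokenizer: a leading space becomes '▁'."""

    VOCAB = {"car": 1, "▁car": 2, "d": 3}

    def encode(self, text, add_special_tokens=False):
        ids, rest = [], text.replace(" ", "▁")
        while rest:
            # Greedy longest match against the toy vocabulary.
            for end in range(len(rest), 0, -1):
                if rest[:end] in self.VOCAB:
                    ids.append(self.VOCAB[rest[:end]])
                    rest = rest[end:]
                    break
            else:
                raise ValueError(f"cannot tokenize {rest!r}")
        return ids


def ban_sequences(words, tokenizer):
    """Encode each word bare and space-prefixed; skip sequences already seen."""
    seen, out = set(), []
    for word in words:
        for variant in (word, " " + word):
            tids = tokenizer.encode(variant, add_special_tokens=False)
            if tids and tuple(tids) not in seen:
                seen.add(tuple(tids))
                out.append(tids)
    return out


print(ban_sequences(["car", "card", "car"], ToyTokenizer()))  # [[1], [2], [1, 3], [2, 3]]
```

In this toy vocabulary, "car" bans a single token. HuggingFace's `bad_words_ids` blocks single-token entries at every step, so that ban also rules out "card" and any other word that begins with the same token.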
- [ ] Step 4: Run test to verify it passes
Run: cd packages/generation && uv run python -m pytest server/tests/test_governor.py -v
Expected: PASS (async tests pass, tokenizer tests may skip if model not available)
- [ ] Step 5: Commit
git add packages/generation/server/governor.py packages/generation/server/tests/test_governor.py
git commit -m "feat(governor): rewrite to word-list resolution via Workers API"
Task 7: Rewrite generate-single Route with SSE¶
Files:
- Modify: packages/generation/server/routes/generate.py
- Modify: packages/generation/server/model.py
- [ ] Step 1: Rewrite the generate-single route
Replace the generate_single function in packages/generation/server/routes/generate.py with an SSE streaming version. The old session-based /generate route and all session/profile imports are removed.
import asyncio
import json
import logging
import re
import time

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from transformers import LogitsProcessorList

from phonolex_governors.checking.checker import (
    CheckerConfig, PhonemeExcludeCheck, VocabOnlyCheck, check_word,
)
from phonolex_governors.checking.g2p import G2PCache
from phonolex_governors.generation.reranker import Reranker

from server import model
from server.schemas import (
    GenerateSingleRequest, SingleGenerationResponse,
    ExcludeConstraint, BoundConstraint,
    BoostCoverage, WordViolation, WordComplianceDetail,
)
from server.governor import resolve_constraints, prepare_generation
from server.word_norms import get_word_norms, get_vocab_memberships

log = logging.getLogger("phonolex.generation")
router = APIRouter(prefix="/api")

WORD_RE = re.compile(r"[a-zA-Z]+")


def init(ps=None, ss=None, gc=None):
    """Legacy init: no-op. Kept for compatibility with main.py."""


def _sse(payload: dict) -> str:
    """Frame one Server-Sent Event carrying a JSON payload."""
    return f"data: {json.dumps(payload)}\n\n"


def _find_violations(text: str, checker_config, g2p_cache) -> list[str]:
    """Return the words in `text` that fail the GUARD checks."""
    if checker_config is None:
        return []
    return [w for w in WORD_RE.findall(text)
            if not check_word(w, checker_config, g2p_cache).passed]


def _ban_variants(words, tokenizer, bad_words_ids: list[list[int]]) -> None:
    """Append token sequences for `words` (bare and space-prefixed) to bad_words_ids."""
    for w in words:
        for variant in (w, " " + w):
            tids = tokenizer.encode(variant, add_special_tokens=False)
            if tids and tids not in bad_words_ids:
                bad_words_ids.append(tids)


async def _generate_sse(req: GenerateSingleRequest):
    """SSE generator: yields status events, then one final result (or error) event."""

    def emit(msg: str) -> str:
        return _sse({"status": msg})

    try:
        # --- Resolve constraints to word lists ---
        yield emit("Resolving constraints...")
        resolved = await resolve_constraints(req.constraints)
        yield emit(f"Fetched word lists ({len(resolved)} constraints)")

        # --- Prepare BAN/BOOST and the Reranker ---
        tokenizer = model.get_tokenizer()
        bad_words_ids, allow_sets, boost_lists = prepare_generation(resolved, tokenizer)
        reranker = (
            Reranker(tokenizer=tokenizer, allow_sets=allow_sets, boost_lists=boost_lists)
            if (allow_sets or boost_lists) else None
        )
        processors = LogitsProcessorList([reranker]) if reranker is not None else None

        # --- Build CheckerConfig for GUARD ---
        # resolve_constraints returns one entry per constraint, in order, so zip pairs
        # each BoundConstraint with its own complement allow set.
        checks = []
        bound_norms = []
        for c, rc in zip(req.constraints, resolved):
            if isinstance(c, ExcludeConstraint):
                checks.append(PhonemeExcludeCheck(excluded=set(c.phonemes)))
            elif isinstance(c, BoundConstraint):
                checks.append(VocabOnlyCheck(allowed_words=set(rc.words)))
                bound_norms.append(c.norm)
        norms_data = get_word_norms() or {}
        vocab_data = get_vocab_memberships() or {}
        stop_words = {w for w, m in vocab_data.items() if any("stop" in s for s in m)}
        checker_config = CheckerConfig(
            checks=checks, norm_lookup=norms_data,
            vocab_lookup=vocab_data, stop_words=stop_words,
        ) if checks else None
        g2p_cache = G2PCache()

        # --- Best-of-N drafts, each with a BAN-and-retry loop ---
        max_retries = 4  # re-generations per draft, each banning the previous violations
        n_drafts = 2
        t0 = time.time()
        best = None  # (gen_ids, text, violations) of the draft with the fewest violations
        for draft_idx in range(n_drafts):
            draft_bad = list(bad_words_ids)  # retries extend a per-draft copy
            for attempt in range(max_retries + 1):
                yield emit(f"Generating draft {draft_idx + 1} (attempt {attempt + 1})...")
                # Run the blocking model call in a worker thread so status events and
                # other requests (e.g. /api/server/status polling) are not stalled.
                # Needs the bad_words_ids parameter from Task 11, Step 2.
                gen_ids, text, _ = await asyncio.to_thread(
                    model.generate_single,
                    req.prompt,
                    logits_processor=processors,
                    bad_words_ids=draft_bad or None,
                )
                yield emit("Checking compliance...")
                violations = _find_violations(text, checker_config, g2p_cache)
                if best is None or len(violations) < len(best[2]):
                    best = (gen_ids, text, violations)
                if not violations:
                    yield emit(f"Draft {draft_idx + 1}: compliant")
                    break
                if attempt < max_retries:
                    yield emit(f"Draft {draft_idx + 1}: {len(violations)} violations, retrying...")
                    _ban_variants(violations, tokenizer, draft_bad)
            if not best[2]:
                break  # with no secondary quality score, a compliant draft ends the search
        gen_time_ms = (time.time() - t0) * 1000
        best_ids, best_text, _ = best  # falls back to the least-violating draft

        # --- Per-word compliance details ---
        word_violations = []
        word_compliance = []
        all_words = WORD_RE.findall(best_text)
        if checker_config is not None:
            for w in all_words:
                result = check_word(w, checker_config, g2p_cache)
                word_norms_entry = norms_data.get(w.lower(), {})
                relevant_values = {n: word_norms_entry.get(n) for n in bound_norms}
                if result.passed:
                    word_compliance.append(WordComplianceDetail(
                        word=w, passed=True, values=relevant_values,
                    ))
                else:
                    details = [v.details for v in result.violations]
                    word_violations.append(WordViolation(word=w, details=details))
                    word_compliance.append(WordComplianceDetail(
                        word=w, passed=False, values=relevant_values, violations=details,
                    ))

        # --- Boost coverage, reported as percentages ---
        boost_coverage = []
        if boost_lists and all_words:
            for bl in boost_lists:
                hit_words = [w for w in all_words if w.lower() in bl.words]
                boost_coverage.append(BoostCoverage(
                    label=bl.label,
                    target_rate=round(bl.target_rate * 100, 1),
                    actual_rate=round(len(hit_words) / len(all_words) * 100, 1),
                    hit_words=hit_words,
                    total_words=len(all_words),
                ))

        # --- Enrich tokens ---
        gen_tokens = model.enrich_tokens(best_ids, {}, [], tokenizer)
        response = SingleGenerationResponse(
            tokens=gen_tokens, text=best_text,
            gen_time_ms=gen_time_ms,
            compliant=not word_violations,
            violation_count=len(word_violations),
            violation_words=[wv.word for wv in word_violations],
            word_violations=word_violations,
            word_compliance=word_compliance,
            boost_coverage=boost_coverage,
        )
        yield _sse({"result": response.model_dump()})
    except Exception as e:
        log.exception("Generation failed")
        yield _sse({"error": str(e)})


@router.post("/generate-single")
async def generate_single(req: GenerateSingleRequest):
    if not model.is_ready():
        raise HTTPException(503, "Model not ready")
    return StreamingResponse(
        _generate_sse(req),
        media_type="text/event-stream",
        # Keep proxies (e.g. nginx) from caching or buffering the event stream.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
Note: This is a structural sketch. The route must pass bad_words_ids through to model.generate_single(), which gains that parameter in Task 11, Step 2, and generate_with_checking still needs its refactor. The key pattern is an SSE stream wrapping the four-layer pipeline that emits a status event at each stage.
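The sketch below models the wire format this route streams. It assumes single-line `data:` events, which holds here because `json.dumps` escapes newlines inside strings. `format_event` and `parse_stream` are illustrative names, not part of the codebase. The parser imitates a client-side line buffer: a network chunk may end mid-line, so the trailing partial line is held back until the next chunk arrives.

```python
import json


def format_event(payload: dict) -> str:
    """Frame one SSE event: a single data line followed by a blank line."""
    return f"data: {json.dumps(payload)}\n\n"


def parse_stream(chunks):
    """Reassemble JSON events from arbitrarily split chunks."""
    buffer, events = "", []
    for chunk in chunks:
        buffer += chunk
        *lines, buffer = buffer.split("\n")  # last element may be a partial line
        for line in lines:
            if line.startswith("data: ") and line[6:].strip():
                events.append(json.loads(line[6:]))
    return events


stream = (format_event({"status": "Resolving constraints..."})
          + format_event({"result": {"text": "a cat"}}))
# Split mid-event to simulate network chunking.
chunks = [stream[:10], stream[10:45], stream[45:]]
print(parse_stream(chunks))
```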
- [ ] Step 2: Remove session/profile routes from main.py
Update packages/generation/server/main.py to remove session/profile initialization and routes. Remove imports of ProfileStore, SessionStore, GovernorCache. Remove lookup loading (no longer needed). Keep model loading and word_norms loading.
- [ ] Step 3: Delete dead server files
rm packages/generation/server/sessions.py
rm packages/generation/server/profiles.py
rm packages/generation/server/routes/sessions.py
rm packages/generation/server/routes/profiles.py
rm packages/generation/server/tests/test_sessions.py
rm packages/generation/server/tests/test_profiles.py
- [ ] Step 4: Verify server starts
Run: cd packages/generation && PHONOLEX_API_URL=http://localhost:8788 uv run uvicorn server.main:app --host 0.0.0.0 --port 3003
Expected: Server starts without import errors. May fail on model loading (expected if no GPU) but should not crash on import.
- [ ] Step 5: Commit
git add -u packages/generation/server/
git commit -m "feat(generation): SSE streaming, word-list pipeline, remove sessions/profiles"
Task 8: Frontend SSE Client¶
Files:
- Modify: packages/web/frontend/src/lib/generationApi.ts
- Modify: packages/web/frontend/src/components/tools/GovernedGenerationTool/index.tsx
- [ ] Step 1: Rewrite generationApi.ts for SSE
import { useState, useEffect, useRef } from 'react';
import type { Constraint, SingleGenerationResponse, ServerStatus } from '../types/governance';
import { freshRequestId, getRequestId, logError } from './logger';

const GENERATION_API_URL = import.meta.env.VITE_GENERATION_API_URL || 'http://localhost:8000';

export interface GenerationCallbacks {
  onStatus: (message: string) => void;
  onResult: (response: SingleGenerationResponse) => void;
  onError: (error: string) => void;
}

interface SseEvent {
  status?: string;
  result?: SingleGenerationResponse;
  error?: string;
}

/**
 * POST to /api/generate-single and dispatch each SSE event to a callback.
 * Exactly one of onResult / onError is called per request, so callers can
 * always clear their "in progress" state.
 */
export async function generateContent(
  prompt: string,
  constraints: Constraint[],
  callbacks: GenerationCallbacks,
): Promise<void> {
  const rid = freshRequestId();
  const res = await fetch(`${GENERATION_API_URL}/api/generate-single`, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'X-Request-ID': rid,
    },
    body: JSON.stringify({ prompt, constraints }),
  }).catch((err: unknown) => {
    // Network failure or CORS rejection: fetch rejects instead of returning a response.
    callbacks.onError(`Generation server unreachable: ${String(err)}`);
    return null;
  });
  if (!res) return;

  if (!res.ok) {
    const detail = await res.text().catch(() => res.statusText);
    logError('Generation API request failed', {
      method: 'POST', url: '/api/generate-single', status: res.status, detail,
    });
    callbacks.onError(`Generation failed (${res.status}): ${detail}`);
    return;
  }

  const reader = res.body?.getReader();
  if (!reader) {
    callbacks.onError('No response body');
    return;
  }

  const decoder = new TextDecoder();
  let buffer = '';
  let finished = false;

  const handleLine = (line: string) => {
    if (!line.startsWith('data: ')) return;
    const payload = line.slice(6).trim();
    if (!payload) return;
    let event: SseEvent | null = null;
    try {
      event = JSON.parse(payload) as SseEvent;
    } catch {
      // Skip malformed events
    }
    if (!event) return;
    if (event.status) {
      callbacks.onStatus(event.status);
    } else if (event.result) {
      finished = true;
      callbacks.onResult(event.result);
    } else if (event.error) {
      finished = true;
      callbacks.onError(event.error);
    }
  };

  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      // A chunk may end mid-line: keep the trailing partial line for the next read.
      const lines = buffer.split('\n');
      buffer = lines.pop() ?? '';
      lines.forEach(handleLine);
    }
    buffer += decoder.decode(); // flush bytes still held by the streaming decoder
    if (buffer) handleLine(buffer);
  } catch (err) {
    if (!finished) {
      finished = true;
      callbacks.onError(`Stream interrupted: ${String(err)}`);
    }
  }

  if (!finished) {
    callbacks.onError('Generation stream ended without a result');
  }
}

export async function fetchServerStatus(): Promise<ServerStatus> {
  const res = await fetch(`${GENERATION_API_URL}/api/server/status`, {
    headers: { 'X-Request-ID': getRequestId() },
  });
  if (!res.ok) throw new Error(`Server status check failed (${res.status})`);
  return res.json();
}

export function useServerStatus(pollIntervalMs = 5000): ServerStatus | null {
  const [status, setStatus] = useState<ServerStatus | null>(null);
  const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null);

  useEffect(() => {
    let cancelled = false;
    async function poll() {
      try {
        const s = await fetchServerStatus();
        if (!cancelled) setStatus(s);
      } catch {
        if (!cancelled) setStatus(null);
      }
    }
    poll();
    intervalRef.current = setInterval(poll, pollIntervalMs);
    return () => {
      cancelled = true;
      if (intervalRef.current) clearInterval(intervalRef.current);
    };
  }, [pollIntervalMs]);

  return status;
}
- [ ] Step 2: Update GovernedGenerationTool index.tsx
Add a statusMessage state and pass onStatus callback to generateContent. Display the status line below the Generate button:
const [statusMessage, setStatusMessage] = useState<string | null>(null);
// In the generate handler:
setStatusMessage('Starting...');
await generateContent(prompt, compiled, {
onStatus: (msg) => setStatusMessage(msg),
onResult: (response) => {
setStatusMessage(null);
// ... create GenerationResult from response and add to feed
},
onError: (err) => {
setStatusMessage(null);
// ... show error
},
});
Render the status line:
{statusMessage && (
<Typography variant="body2" color="text.secondary" sx={{ mt: 1, fontStyle: 'italic' }}>
{statusMessage}
</Typography>
)}
- [ ] Step 3: Update OutputCard.tsx for generalized coverage
Replace includeCoverage rendering with boostCoverage — same visual pattern but using BoostCoverage.label instead of hardcoded phoneme display.
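The generate-single route computes BoostCoverage with `round(bl.target_rate * 100, 1)`, so OutputCard receives both rates as percentages (0 to 100), not fractions. The function below mirrors that server-side computation. It is a standalone sketch, and `boost_coverage` is an illustrative name.

```python
import re


def boost_coverage(text: str, boost_words: set, target_rate: float) -> dict:
    """Share of words in `text` that belong to a boost list, as percentages."""
    words = re.findall(r"[a-zA-Z]+", text)
    hits = [w for w in words if w.lower() in boost_words]
    return {
        "target_rate": round(target_rate * 100, 1),
        "actual_rate": round(len(hits) / len(words) * 100, 1) if words else 0.0,
        "hit_words": hits,
        "total_words": len(words),
    }


print(boost_coverage("The cat saw a kite.", {"cat", "kite"}, 0.2))
```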
- [ ] Step 4: Run frontend build
Run: cd packages/web/frontend && npm run build
Expected: No TypeScript errors
- [ ] Step 5: Commit
git add packages/web/frontend/src/lib/generationApi.ts packages/web/frontend/src/components/tools/GovernedGenerationTool/
git commit -m "feat(frontend): SSE streaming, status line, generalized coverage display"
Task 9: Build Bigram Transition Matrix¶
Files:
- Create: packages/generation/scripts/build_bigram_matrix.py
- [ ] Step 1: Write the builder script
"""Build a BPE bigram transition matrix from a text corpus.

Tokenizes a corpus with the T5Gemma tokenizer and counts token bigrams within
each line. Outputs a sparse matrix: for each token, the top-k most likely next
tokens and their conditional probabilities P(next | token).

Usage:
    python build_bigram_matrix.py --corpus data/corpus.txt --output data/bigram_matrix.npz
"""
from __future__ import annotations

import argparse
import logging
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
from transformers import AutoTokenizer

log = logging.getLogger(__name__)

MODEL_NAME = "google/t5gemma-9b-2b-ul2-it"
TOP_K = 100  # transitions kept per token


def build_from_corpus(corpus_path: str, tokenizer, top_k: int = TOP_K) -> dict[int, dict[int, float]]:
    """Count token bigrams and keep the top-k continuations of each token.

    Probabilities are normalized by the token's total bigram count, so the kept
    entries are true conditional probabilities and may sum to less than 1.
    """
    log.info("Tokenizing corpus: %s", corpus_path)
    bigram_counts: dict[int, Counter] = defaultdict(Counter)
    with open(corpus_path, encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            ids = tokenizer.encode(line, add_special_tokens=False)
            for a, b in zip(ids, ids[1:]):
                bigram_counts[a][b] += 1
            if (i + 1) % 100_000 == 0:
                log.info("Processed %d lines", i + 1)

    log.info("Computing conditional probabilities for %d tokens", len(bigram_counts))
    matrix: dict[int, dict[int, float]] = {}
    for token_id, counts in bigram_counts.items():
        total = sum(counts.values())
        matrix[token_id] = {
            next_id: count / total for next_id, count in counts.most_common(top_k)
        }
    return matrix


def save_matrix(matrix: dict[int, dict[int, float]], output_path: str) -> None:
    """Save as three parallel compressed arrays (a COO-style sparse representation)."""
    token_ids, next_ids, probs = [], [], []
    for tid, transitions in matrix.items():
        for next_tid, prob in transitions.items():
            token_ids.append(tid)
            next_ids.append(next_tid)
            probs.append(prob)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        output_path,
        token_ids=np.array(token_ids, dtype=np.int32),
        next_ids=np.array(next_ids, dtype=np.int32),
        probs=np.array(probs, dtype=np.float32),
    )
    log.info("Saved matrix to %s (%d entries)", output_path, len(token_ids))


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description="Build BPE bigram matrix")
    parser.add_argument("--corpus", type=str, required=True, help="Path to text corpus file")
    parser.add_argument("--output", type=str, default="data/bigram_matrix.npz")
    parser.add_argument("--top-k", type=int, default=TOP_K)
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    matrix = build_from_corpus(args.corpus, tokenizer, top_k=args.top_k)
    save_matrix(matrix, args.output)


if __name__ == "__main__":
    main()
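The following restates the script's normalization on a toy token stream in pure Python. `top_k_transitions` is an illustrative name. Each kept probability is divided by the token's total bigram count, not by the sum of the kept top-k counts. The truncated mass is therefore dropped rather than redistributed, which keeps the later dead-end scores true probabilities.

```python
from collections import Counter


def top_k_transitions(token_stream, top_k):
    """P(next | token) from bigram counts, keeping the top_k continuations per token."""
    counts = {}
    for a, b in zip(token_stream, token_stream[1:]):
        counts.setdefault(a, Counter())[b] += 1
    return {
        a: {b: n / sum(c.values()) for b, n in c.most_common(top_k)}
        for a, c in counts.items()
    }


# Token 10 is followed by 11 three times, by 12 once and by 13 once.
stream = [10, 11, 10, 11, 10, 12, 10, 11, 10, 13]
result = top_k_transitions(stream, top_k=2)
print(result)
```

With `top_k=2`, token 10 keeps P(11|10) = 0.6 and P(12|10) = 0.2. The 0.2 belonging to token 13 is simply absent.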
- [ ] Step 2: Add data directory to gitignore
Check that packages/generation/data/ is in .gitignore. If not, add packages/generation/data/*.npz.
- [ ] Step 3: Commit
git add packages/generation/scripts/build_bigram_matrix.py
git commit -m "feat(generation): bigram transition matrix builder script"
Task 10: Bigram Dead-End Filter (Lookahead Layer 3)¶
Files:
- Create: packages/governors/src/phonolex_governors/generation/lookahead.py
- Create: packages/governors/tests/test_lookahead.py
- [ ] Step 1: Write the failing tests
Create packages/governors/tests/test_lookahead.py:
import pytest
import numpy as np
import torch

from phonolex_governors.generation.lookahead import BigramDeadEndFilter, load_bigram_matrix


@pytest.fixture
def tiny_matrix(tmp_path):
    """Create a tiny bigram matrix for testing."""
    # Token 10 → Token 11 (p=0.8), Token 12 (p=0.2)
    # Token 13 → Token 14 (p=0.9), Token 15 (p=0.1)
    token_ids = np.array([10, 10, 13, 13], dtype=np.int32)
    next_ids = np.array([11, 12, 14, 15], dtype=np.int32)
    probs = np.array([0.8, 0.2, 0.9, 0.1], dtype=np.float32)
    path = tmp_path / "test_bigram.npz"
    np.savez_compressed(path, token_ids=token_ids, next_ids=next_ids, probs=probs)
    return path


@pytest.fixture
def mock_tokenizer():
    from unittest.mock import MagicMock
    tok = MagicMock()
    tok.decode = MagicMock(side_effect=lambda ids, **kw: {
        (10,): "▁dre", (11,): "am", (12,): "ss",
        (13,): "▁ca", (14,): "t", (15,): "r",
        (10, 11): "dream", (10, 12): "dress",
        (13, 14): "cat", (13, 15): "car",
    }.get(tuple(ids), "unk"))
    return tok


def test_load_bigram_matrix(tiny_matrix):
    matrix = load_bigram_matrix(str(tiny_matrix))
    assert 10 in matrix
    assert 11 in matrix[10]
    assert abs(matrix[10][11] - 0.8) < 0.01


def test_dead_end_score(tiny_matrix, mock_tokenizer):
    matrix = load_bigram_matrix(str(tiny_matrix))
    ban_words = {"dream"}  # Token 10→11 = "dream" is banned
    filt = BigramDeadEndFilter(matrix, mock_tokenizer, ban_words, vocab_size=20)
    # Token 10 should have dead_end_score = P(11|10) = 0.8 (since dream is banned)
    scores = filt.dead_end_scores
    assert scores[10] > 0.5


def test_filter_penalizes_dead_ends(tiny_matrix, mock_tokenizer):
    matrix = load_bigram_matrix(str(tiny_matrix))
    ban_words = {"dream"}
    filt = BigramDeadEndFilter(matrix, mock_tokenizer, ban_words,
                               vocab_size=20, penalty=10.0)
    logits = torch.zeros(1, 20)
    logits[0, 10] = 5.0
    logits[0, 13] = 5.0
    input_ids = torch.tensor([[0]])
    result = filt(input_ids, logits)
    # Token 10 should be penalized (high dead-end), token 13 should not
    assert result[0, 10].item() < result[0, 13].item()
- [ ] Step 2: Run test to verify it fails
Run: cd packages/governors && uv run python -m pytest tests/test_lookahead.py -v
Expected: FAIL — BigramDeadEndFilter, load_bigram_matrix not defined
- [ ] Step 3: Implement lookahead.py
Create packages/governors/src/phonolex_governors/generation/lookahead.py:
"""Lookahead mechanisms for constrained generation escalation.

Layer 3: BigramDeadEndFilter, precomputed P(B|A) scored against ban lists.
Layer 4: TargetedRolloutProcessor, short greedy rollouts for context-aware checking.

Based on:
- FUDGE (Yang & Klein, 2021): prefix-level constraint prediction
- NeuroLogic A*esque Decoding (Lu et al., 2022): lookahead heuristics
"""
from __future__ import annotations

import logging
import re
from typing import Optional

import numpy as np
import torch
from transformers import LogitsProcessor
from transformers.modeling_outputs import BaseModelOutput

log = logging.getLogger(__name__)

_WORD_RE = re.compile(r"[a-zA-Z]+")
_TRAILING_WORD_RE = re.compile(r"[a-zA-Z]+$")


def load_bigram_matrix(path: str) -> dict[int, dict[int, float]]:
    """Load a sparse bigram matrix written by scripts/build_bigram_matrix.py."""
    matrix: dict[int, dict[int, float]] = {}
    with np.load(path) as data:
        triples = zip(data["token_ids"].tolist(), data["next_ids"].tolist(),
                      data["probs"].tolist())
        for tid, nid, p in triples:
            matrix.setdefault(tid, {})[nid] = p
    return matrix


class BigramDeadEndFilter(LogitsProcessor):
    """Penalize tokens whose probable continuations form banned words.

    Computes dead_end_score[A] = Σ P(B|A) over all B where decode(A+B) ∈ ban_set,
    then applies logits[:, A] -= penalty * dead_end_score[A] at every step.
    Based on FUDGE (Yang & Klein, 2021).
    """

    def __init__(
        self,
        bigram_matrix: dict[int, dict[int, float]],
        tokenizer,
        ban_words: set[str],
        vocab_size: int,
        penalty: float = 10.0,
    ):
        self.penalty = penalty
        self.dead_end_scores = self._compute_scores(
            bigram_matrix, tokenizer, ban_words, vocab_size,
        )

    def _compute_scores(
        self,
        matrix: dict[int, dict[int, float]],
        tokenizer,
        ban_words: set[str],
        vocab_size: int,
    ) -> torch.Tensor:
        """Precompute per-token dead-end scores (one decode call per stored bigram)."""
        scores = torch.zeros(vocab_size)
        ban_lower = {w.lower() for w in ban_words}
        for tid, transitions in matrix.items():
            if tid >= vocab_size:
                continue
            dead_prob = 0.0
            for next_tid, prob in transitions.items():
                if next_tid >= vocab_size:
                    continue
                # Decode A+B together: when A starts a word and B completes it, the
                # pair decodes to that whole word (e.g. "▁dre" + "am" -> "dream").
                word = tokenizer.decode([tid, next_tid], skip_special_tokens=True).strip().lower()
                if word in ban_lower:
                    dead_prob += prob
            scores[tid] = dead_prob
        nonzero = int((scores > 0).sum().item())
        log.info("Dead-end filter: %d/%d tokens have nonzero scores", nonzero, vocab_size)
        return scores

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        dead = self.dead_end_scores.to(device=scores.device, dtype=scores.dtype)
        v = min(scores.shape[-1], dead.shape[0])
        # Broadcast over the batch dimension so every sequence is penalized, not just row 0.
        scores[:, :v] -= self.penalty * dead[:v]
        return scores


class TargetedRolloutProcessor(LogitsProcessor):
    """Context-aware lookahead via short greedy rollouts.

    For each top-k candidate, greedily extends the sequence by `rollout_depth`
    tokens (one batched forward pass per step) and checks whether the words it
    completes hit the ban set or fall outside the allow set. Candidates whose
    rollouts produce a violation are penalized.
    Based on NeuroLogic A*esque Decoding (Lu et al., 2022).

    In an encoder-decoder model such as T5Gemma, the `input_ids` a LogitsProcessor
    receives are decoder tokens, so the prompt's encoder input must be passed as
    `encoder_input_ids`. Leave it as None for decoder-only models.
    """

    def __init__(
        self,
        model,
        tokenizer,
        ban_words: set[str],
        allow_words: Optional[set[str]] = None,
        rollout_depth: int = 3,
        top_k: int = 50,
        penalty: float = 50.0,
        encoder_input_ids: Optional[torch.LongTensor] = None,
        context_tokens: int = 8,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.ban_words = {w.lower() for w in ban_words}
        self.allow_words = {w.lower() for w in allow_words} if allow_words else None
        self.rollout_depth = rollout_depth
        self.top_k = top_k
        self.penalty = penalty
        self.encoder_input_ids = encoder_input_ids
        self.context_tokens = context_tokens
        self._encoder_hidden: Optional[torch.Tensor] = None  # cached prompt encoding

    def _is_violation(self, word: str) -> bool:
        w = word.lower().strip()
        if not w or not w.isalpha() or len(w) < 2:
            return False
        if w in self.ban_words:
            return True
        return self.allow_words is not None and w not in self.allow_words

    def _next_logits(self, seqs: torch.LongTensor) -> torch.Tensor:
        """Next-token logits for a batch of (decoder) sequences."""
        if self.encoder_input_ids is None:  # decoder-only model
            return self.model(input_ids=seqs, use_cache=False).logits[:, -1, :]
        if self._encoder_hidden is None:
            # The prompt is fixed for the whole generation, so encode it only once.
            enc = self.encoder_input_ids.to(seqs.device).reshape(1, -1)
            self._encoder_hidden = self.model.get_encoder()(input_ids=enc).last_hidden_state
        hidden = self._encoder_hidden.expand(seqs.shape[0], -1, -1)
        out = self.model(
            encoder_outputs=BaseModelOutput(last_hidden_state=hidden),
            decoder_input_ids=seqs,
            use_cache=False,
        )
        return out.logits[:, -1, :]

    @torch.no_grad()
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        k = min(self.top_k, scores.shape[-1])
        start = max(0, input_ids.shape[1] - self.context_tokens)
        for b in range(input_ids.shape[0]):
            _, top_ids = torch.topk(scores[b], k)
            # Roll out all k candidates side by side: shape (k, seq_len + 1 + depth).
            seqs = torch.cat([
                input_ids[b].unsqueeze(0).expand(k, -1),
                top_ids.to(input_ids.device).unsqueeze(1),
            ], dim=1)
            for _ in range(self.rollout_depth):
                next_ids = self._next_logits(seqs).argmax(dim=-1, keepdim=True)
                seqs = torch.cat([seqs, next_ids], dim=1)
            # Decode a few prior tokens too, so a candidate completing a word begun
            # earlier ("dre" + "am") is checked as the whole word. Cutting the
            # context's trailing partial word (assumes decoding is prefix-stable)
            # keeps already-finished words from being re-checked.
            context = self.tokenizer.decode(input_ids[b, start:], skip_special_tokens=True)
            keep = len(_TRAILING_WORD_RE.sub("", context))
            for i in range(k):
                text = self.tokenizer.decode(seqs[i, start:], skip_special_tokens=True)[keep:]
                words = _WORD_RE.findall(text)
                if words and text[-1:].isalpha():
                    words = words[:-1]  # the rollout may have stopped mid-word
                if any(self._is_violation(w) for w in words):
                    scores[b, top_ids[i]] -= self.penalty
        return scores
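The dead-end formula in the BigramDeadEndFilter docstring can be checked by hand on the fixture from tests/test_lookahead.py. This pure-Python restatement uses dicts instead of tensors, and `dead_end_scores` and `penalize` are illustrative names. Token 10 leads to the banned "dream" with probability 0.8, so with penalty 10.0 its logit drops from 5.0 to -3.0. Token 13 has no banned continuation and keeps 5.0.

```python
def dead_end_scores(matrix, decode_pair, ban_words):
    """dead_end_score[A] = sum of P(B|A) over B where decode(A, B) is a banned word."""
    ban = {w.lower() for w in ban_words}
    return {
        a: sum(p for b, p in nxt.items() if decode_pair(a, b).strip().lower() in ban)
        for a, nxt in matrix.items()
    }


def penalize(logits, scores, penalty):
    """logits[A] -= penalty * dead_end_score[A]; tokens without a score are untouched."""
    return {tok: x - penalty * scores.get(tok, 0.0) for tok, x in logits.items()}


# Same toy data as tests/test_lookahead.py.
matrix = {10: {11: 0.8, 12: 0.2}, 13: {14: 0.9, 15: 0.1}}
pairs = {(10, 11): "dream", (10, 12): "dress", (13, 14): "cat", (13, 15): "car"}
scores = dead_end_scores(matrix, lambda a, b: pairs[(a, b)], {"dream"})
out = penalize({10: 5.0, 13: 5.0}, scores, penalty=10.0)
print(scores, out)
```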
- [ ] Step 4: Run test to verify it passes
Run: cd packages/governors && uv run python -m pytest tests/test_lookahead.py -v
Expected: PASS
- [ ] Step 5: Commit
git add packages/governors/src/phonolex_governors/generation/lookahead.py packages/governors/tests/test_lookahead.py
git commit -m "feat(lookahead): bigram dead-end filter + targeted rollout processors"
Task 11: Integration — Wire Lookahead into Generation Pipeline¶
Files:
- Modify: packages/generation/server/routes/generate.py
- Modify: packages/generation/server/model.py
- [ ] Step 1: Add lookahead escalation to the SSE pipeline
In the generation loop in routes/generate.py, after the retry budget is exhausted, add escalation:
# Module level in routes/generate.py: load the bigram matrix once, on first use.
from functools import lru_cache
from pathlib import Path

from phonolex_governors.generation.lookahead import (
    BigramDeadEndFilter, TargetedRolloutProcessor, load_bigram_matrix,
)

BIGRAM_PATH = Path(__file__).resolve().parents[2] / "data" / "bigram_matrix.npz"


@lru_cache(maxsize=1)
def _bigram_matrix():
    """Return the cached bigram matrix, or None if Task 9's builder has not been run."""
    if not BIGRAM_PATH.exists():
        return None
    return load_bigram_matrix(str(BIGRAM_PATH))


# Inside the draft loop, once max_retries is reached:
if attempt == max_retries and violations:
    all_ban_words = {
        w for rc in resolved if rc.mode == "ban" and rc.strategy == "direct" for w in rc.words
    }
    # --- Layer 3: Bigram dead-end filter ---
    matrix = _bigram_matrix()
    if matrix is not None:
        yield emit("Activating bigram dead-end filter...")
        # model.get_vocab_size() is assumed to be added to model.py in this refactor.
        dead_end_filter = BigramDeadEndFilter(
            matrix, tokenizer, all_ban_words, vocab_size=model.get_vocab_size(),
        )
        yield emit("Generating with lookahead...")
        # ... generate with dead_end_filter added to the LogitsProcessorList, then re-check
    # --- Layer 4: Targeted rollout (only if Layer 3 still leaves violations) ---
    # still_failing (the Layer 3 re-check result) and model_ref (the loaded HF model)
    # are placeholders to be defined during the refactor.
    if still_failing and model_ref is not None:
        yield emit("Activating targeted rollout...")
        rollout = TargetedRolloutProcessor(
            model=model_ref, tokenizer=tokenizer,
            ban_words=all_ban_words,
        )
        # ... generate with rollout added to the LogitsProcessorList
The exact wiring depends on how model.generate_single() is refactored to accept additional LogitsProcessors and bad_words_ids. The key pattern: escalation layers are lazy — only instantiated when earlier layers fail.
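The lazy-escalation rule can be expressed without HuggingFace. This sketch assumes generation and checking are passed in as plain callables. `generate_with_escalation` and its parameters are hypothetical names. Layers 3 and 4 are represented by zero-argument factories, so a layer is built only when the previous attempt still has violations.

```python
def generate_with_escalation(generate, find_violations, layer_factories):
    """Generate, then add one lookahead layer at a time until the text is compliant.

    generate(processors) -> text; find_violations(text) -> list of offending words;
    layer_factories: zero-arg callables, each building one extra processor.
    Factories after the first compliant draft are never called.
    """
    processors = []
    text = generate(processors)
    violations = find_violations(text)
    for make_layer in layer_factories:
        if not violations:
            break
        processors.append(make_layer())
        text = generate(processors)
        violations = find_violations(text)
    return text, violations


built = []


def make(name):
    def factory():
        built.append(name)
        return name
    return factory


text, v = generate_with_escalation(
    generate=lambda procs: "a red cat" if not procs else "a big cat",
    find_violations=lambda t: [w for w in t.split() if "r" in w],  # toy "exclude /r/"
    layer_factories=[make("bigram"), make("rollout")],
)
print(text, v, built)  # a big cat [] ['bigram']
```

The rollout factory is never called, because the bigram layer alone produced a compliant draft.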
- [ ] Step 2: Update model.py to accept bad_words_ids
Modify generate_single() to accept an optional bad_words_ids parameter:
def generate_single(
prompt: str,
logits_processor: LogitsProcessorList | None = None,
bad_words_ids: list[list[int]] | None = None,
) -> tuple[list[int], str, float]:
Add bad_words_ids to gen_kwargs if provided.
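The following sketches the "add only if provided" rule; `build_gen_kwargs` is a hypothetical helper. The truthiness checks keep an empty ban list or an empty processor list out of the kwargs entirely. This avoids handing `generate()` an empty `bad_words_ids`, which recent transformers versions reject because `NoBadWordsLogitsProcessor` requires a non-empty list.

```python
def build_gen_kwargs(base: dict, logits_processor=None, bad_words_ids=None) -> dict:
    """Merge optional constraint inputs into generate() kwargs, omitting unset or empty ones."""
    kwargs = dict(base)
    if logits_processor:
        kwargs["logits_processor"] = logits_processor
    if bad_words_ids:
        kwargs["bad_words_ids"] = bad_words_ids
    return kwargs


print(build_gen_kwargs({"max_new_tokens": 128}, bad_words_ids=[[5, 6]]))
```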
- [ ] Step 3: Test the full pipeline manually
Start all three servers (Workers API, frontend, generation server). Set a phoneme exclusion constraint in the UI, generate, and verify:
1. Status messages appear below the Generate button.
2. Compliance highlighting works on the output.
3. Word-list resolution happens (check the generation server logs).
- [ ] Step 4: Commit
git add packages/generation/server/routes/generate.py packages/generation/server/model.py
git commit -m "feat(generation): wire lookahead escalation into SSE pipeline"
Task 12: Update main.py and Clean Up Dead Code
Files:
- Modify: packages/generation/server/main.py
- Delete: obsolete test files (see Step 3)
- Modify: packages/generation/server/tests/conftest.py
- [ ] Step 1: Simplify main.py
Remove GovernorCache initialization, lookup loading, and session/profile store setup. Keep model loading and word_norms loading. Update route initialization to match, and stop registering the deleted session and profile routers.
- [ ] Step 2: Update conftest.py
Remove the lookup and small_lookup fixtures. They are no longer needed because the token-level governor tests are gone. Add any new fixtures the word-list-based tests need.
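As a hypothetical example of a word-list fixture, tests can build resolved constraints directly instead of loading a lookup. FakeResolvedConstraint is a stand-in with only the three fields the Layer 3 code reads (mode, strategy, words). The real ResolvedConstraint in schemas.py may have more fields and different types, and the "boost" mode value is an assumption. make_resolved_constraints is kept as a plain function so it can be exercised directly. In conftest.py it would be decorated with @pytest.fixture.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(frozen=True)
class FakeResolvedConstraint:
    """Hypothetical stand-in for schemas.ResolvedConstraint in tests."""

    mode: str  # "ban" or "boost" (assumed values)
    strategy: str  # only "direct" is used in this sketch
    words: frozenset[str] = field(default_factory=frozenset)


def make_resolved_constraints() -> list[FakeResolvedConstraint]:
    """Body of a word-list fixture; decorate with @pytest.fixture in conftest.py."""
    return [
        FakeResolvedConstraint("ban", "direct", frozenset({"cat", "kit"})),
        FakeResolvedConstraint("boost", "direct", frozenset({"sun", "sand"})),
    ]


def collect_ban_words(resolved) -> set[str]:
    """Same rule as the Layer 3 code: directly banned words only."""
    banned: set[str] = set()
    for rc in resolved:
        if rc.mode == "ban" and rc.strategy == "direct":
            banned.update(rc.words)
    return banned
```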
- [ ] Step 3: Delete dead test files
rm packages/generation/server/tests/test_enrichment.py # if it tests old compliance
rm packages/generation/server/tests/test_compliance.py # if it tests old checker path
Review each file before deleting it, and keep any tests that exercise GUARD-relevant behavior.
- [ ] Step 4: Run full test suite
Run: cd packages/generation && uv run python -m pytest server/tests/ -v
Expected: All remaining tests pass
- [ ] Step 5: Commit
git add -u packages/generation/
git commit -m "chore: clean up dead code — remove sessions, profiles, lookup fixtures"
Task 13: Frontend Build Verification
Files:
- Modify: various frontend files, as needed to fix type errors
- [ ] Step 1: Run TypeScript type check
Run: cd packages/web/frontend && npx tsc --noEmit
Expected: No errors. If tsc reports errors caused by the governance.ts type changes (removed types that are still referenced elsewhere), fix them.
- [ ] Step 2: Run frontend tests
Run: cd packages/web/frontend && npx vitest run
Expected: All tests pass
- [ ] Step 3: Run Workers API tests
Run: cd packages/web/workers && npm test
Expected: All tests pass, including the new word-list tests
- [ ] Step 4: Commit any fixes
git add -u packages/web/
git commit -m "fix: resolve type errors from governance.ts simplification"