Pointer-Generator Canonical Decomposer Implementation Plan

For agentic workers: REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (- [ ]) syntax for tracking.

Goal: Replace the split-point + edit-codebook decomposer with a character-level seq2seq pointer-generator that eliminates the mono/edit Pareto tension.

Architecture: BiLSTM encoder (2-layer, 128-dim bidirectional → 256-dim output), unidirectional LSTM decoder (1-layer, 256-dim) with Bahdanau attention and pointer-generator gate. Output is a character sequence with + boundary tokens. ~1.2M parameters.
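
Once Task 4 lands, the parameter budget can be sanity-checked directly. A minimal sketch, assuming the Task 4 defaults and an illustrative vocab_size of 60 (lowercase English plus specials):

# Sketch only: vocab_size=60 is an illustrative guess, not a spec value.
from phonolex_tokenizer.seq2seq.model import PointerGeneratorNet

model = PointerGeneratorNet(vocab_size=60)
total = sum(p.numel() for p in model.parameters())
print(f"{total:,} parameters")  # compare against the estimate above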

Tech Stack: PyTorch, existing CharVocab from phonolex_tokenizer.model.features, existing MorphyNet loader + negatives from phonolex_tokenizer.data, existing eval metrics from phonolex_tokenizer.eval.decomposer_metrics.

Spec: docs/superpowers/specs/2026-04-14-pointer-generator-canonical-decomposer-design.md


Task 1: Seq2Seq Data Loader

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/data/seq2seq_loader.py
- Test: packages/tokenizer/tests/test_seq2seq_loader.py

This loader converts DecompositionExample and MonoExample into character-level input/output sequences for the seq2seq model.

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_loader.py
"""Tests for seq2seq data loader."""

from phonolex_tokenizer.data.canonical_loader import DecompositionExample
from phonolex_tokenizer.data.negatives import MonoExample
from phonolex_tokenizer.data.seq2seq_loader import (
    Seq2SeqExample,
    to_seq2seq_example,
    to_seq2seq_examples,
)
from phonolex_tokenizer.model.schema import MorphLabel


class TestToSeq2SeqExample:
    def test_suffix_surface_faithful(self):
        ex = DecompositionExample(
            word="kindness", base="kind", affix="ness",
            label=MorphLabel.SUFFIX, is_allomorphic=False,
        )
        result = to_seq2seq_example(ex)
        assert result.input_chars == list("kindness")
        assert result.output_chars == list("kind") + ["+"] + list("ness")
        assert result.is_allomorphic is False

    def test_suffix_allomorphic(self):
        ex = DecompositionExample(
            word="happily", base="happy", affix="ly",
            label=MorphLabel.SUFFIX, is_allomorphic=True,
        )
        result = to_seq2seq_example(ex)
        assert result.input_chars == list("happily")
        assert result.output_chars == list("happy") + ["+"] + list("ly")
        assert result.is_allomorphic is True

    def test_prefix(self):
        ex = DecompositionExample(
            word="unhappy", base="happy", affix="un",
            label=MorphLabel.PREFIX, is_allomorphic=False,
        )
        result = to_seq2seq_example(ex)
        assert result.input_chars == list("unhappy")
        assert result.output_chars == list("un") + ["+"] + list("happy")

    def test_mono_example(self):
        mono = MonoExample(word="butter")
        result = to_seq2seq_example(mono)
        assert result.input_chars == list("butter")
        assert result.output_chars == list("butter")
        assert result.is_allomorphic is False

    def test_batch_conversion(self):
        positives = [
            DecompositionExample(
                word="kindness", base="kind", affix="ness",
                label=MorphLabel.SUFFIX, is_allomorphic=False,
            ),
        ]
        negatives = [MonoExample(word="butter")]
        results = to_seq2seq_examples(positives, negatives)
        assert len(results) == 2
        assert "+" in results[0].output_chars
        assert "+" not in results[1].output_chars
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_loader.py -v
Expected: FAIL with ModuleNotFoundError: No module named 'phonolex_tokenizer.data.seq2seq_loader'

- [ ] Step 3: Write the implementation
# packages/tokenizer/src/phonolex_tokenizer/data/seq2seq_loader.py
"""Seq2seq data loader for the pointer-generator decomposer.

Converts DecompositionExample and MonoExample into character-level
input/output sequences where morpheme boundaries are marked with '+'.

Examples:
    Suffix:     "kindness" → "k i n d + n e s s"
    Allomorphic: "happily" → "h a p p y + l y"
    Prefix:     "unhappy"  → "u n + h a p p y"
    Mono:       "butter"   → "b u t t e r"
"""

from __future__ import annotations

from dataclasses import dataclass

from phonolex_tokenizer.data.canonical_loader import DecompositionExample
from phonolex_tokenizer.data.negatives import MonoExample
from phonolex_tokenizer.model.schema import MorphLabel

BOUNDARY = "+"


@dataclass(frozen=True)
class Seq2SeqExample:
    """A single input/output pair for seq2seq training."""

    word: str
    input_chars: list[str]
    output_chars: list[str]
    is_allomorphic: bool


def to_seq2seq_example(
    ex: DecompositionExample | MonoExample,
) -> Seq2SeqExample:
    """Convert a decomposition or mono example to a seq2seq pair."""
    if isinstance(ex, MonoExample):
        chars = list(ex.word)
        return Seq2SeqExample(
            word=ex.word,
            input_chars=chars,
            output_chars=chars.copy(),
            is_allomorphic=False,
        )

    input_chars = list(ex.word)

    if ex.label == MorphLabel.PREFIX:
        output_chars = list(ex.affix) + [BOUNDARY] + list(ex.base)
    else:
        output_chars = list(ex.base) + [BOUNDARY] + list(ex.affix)

    return Seq2SeqExample(
        word=ex.word,
        input_chars=input_chars,
        output_chars=output_chars,
        is_allomorphic=ex.is_allomorphic,
    )


def to_seq2seq_examples(
    positives: list[DecompositionExample],
    negatives: list[MonoExample],
) -> list[Seq2SeqExample]:
    """Convert all examples to seq2seq pairs."""
    results: list[Seq2SeqExample] = []
    for ex in positives:
        results.append(to_seq2seq_example(ex))
    for ex in negatives:
        results.append(to_seq2seq_example(ex))
    return results
- [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_loader.py -v
Expected: All 5 tests PASS.

- [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/data/seq2seq_loader.py packages/tokenizer/tests/test_seq2seq_loader.py
git commit -m "feat(tokenizer): seq2seq data loader for pointer-generator decomposer"

Task 2: Seq2Seq CharVocab Extension

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/seq2seq/__init__.py
- Create: packages/tokenizer/src/phonolex_tokenizer/seq2seq/vocab.py
- Test: packages/tokenizer/tests/test_seq2seq_vocab.py

The decoder needs a vocabulary that includes the + boundary token and BOS/EOS. Rather than modifying the existing CharVocab in place, introduce a standalone Seq2SeqVocab shared by the encoder and decoder.
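
A concrete illustration of the index layout the tests below assume (specials at indices 0-3, then characters in sorted order), mirroring the Step 3 implementation:

# Illustration only: layout for a toy corpus.
from phonolex_tokenizer.seq2seq.vocab import Seq2SeqVocab

vocab = Seq2SeqVocab.from_chars([list("ab")], [list("a") + ["+"] + list("b")])
# {"<pad>": 0, "<unk>": 1, "<bos>": 2, "<eos>": 3, "+": 4, "a": 5, "b": 6}
assert vocab.size == 7 and vocab.boundary_idx == 4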

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_vocab.py
"""Tests for seq2seq vocabulary."""

from phonolex_tokenizer.seq2seq.vocab import Seq2SeqVocab


class TestSeq2SeqVocab:
    def test_build_from_examples(self):
        input_chars = [list("happy"), list("butter")]
        output_chars = [list("happy"), list("butter")]
        vocab = Seq2SeqVocab.from_chars(input_chars, output_chars)

        # Special tokens exist
        assert vocab.pad_idx == 0
        assert vocab.bos_idx is not None
        assert vocab.eos_idx is not None
        assert vocab.boundary_idx is not None

        # All characters are in vocab
        for c in "happybutter":
            assert vocab.char_to_idx(c) != vocab.unk_idx

    def test_boundary_token_in_vocab(self):
        vocab = Seq2SeqVocab.from_chars(
            [list("ab")],
            [list("a") + ["+"] + list("b")],
        )
        assert vocab.boundary_idx != vocab.unk_idx
        assert vocab.idx_to_char(vocab.boundary_idx) == "+"

    def test_encode_decode_roundtrip(self):
        vocab = Seq2SeqVocab.from_chars(
            [list("kind")],
            [list("kind")],
        )
        encoded = vocab.encode_input("kind")
        decoded = vocab.decode_output(encoded)
        assert decoded == "kind"

    def test_encode_output_with_boundary(self):
        vocab = Seq2SeqVocab.from_chars(
            [list("kindness")],
            [list("kind") + ["+"] + list("ness")],
        )
        encoded = vocab.encode_output(list("kind") + ["+"] + list("ness"))
        # Should end with EOS
        assert encoded[-1] == vocab.eos_idx
        # Should contain boundary
        assert vocab.boundary_idx in encoded

    def test_pad_batch(self):
        vocab = Seq2SeqVocab.from_chars([list("ab"), list("cd")], [list("ab"), list("cd")])
        seqs = [vocab.encode_input("ab"), vocab.encode_input("a")]
        padded, lengths = vocab.pad_batch(seqs)
        assert padded.shape == (2, 2)
        assert lengths.tolist() == [2, 1]
        assert padded[1, 1].item() == vocab.pad_idx

    def test_vocab_size(self):
        vocab = Seq2SeqVocab.from_chars(
            [list("abc")],
            [list("a") + ["+"] + list("bc")],
        )
        # 4 specials (<pad>, <unk>, <bos>, <eos>) + 4 chars ('a', 'b', 'c', '+') = 8
        assert vocab.size == 8
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_vocab.py -v
Expected: FAIL with ModuleNotFoundError

- [ ] Step 3: Create the seq2seq package and write the implementation
# packages/tokenizer/src/phonolex_tokenizer/seq2seq/__init__.py  (create as an empty file)
# packages/tokenizer/src/phonolex_tokenizer/seq2seq/vocab.py
"""Vocabulary for the seq2seq pointer-generator decomposer.

A single vocabulary is shared by the encoder and decoder: it contains the
characters of both input and output sequences plus the '+' boundary token,
with the pad/unk/bos/eos special tokens at indices 0-3.
"""

from __future__ import annotations

import json
from pathlib import Path

import torch


class Seq2SeqVocab:
    """Character-level vocabulary with boundary token for seq2seq decomposer.

    Index layout: 0=<pad>, 1=<unk>, 2=<bos>, 3=<eos>, 4..N=characters (sorted).
    The '+' boundary token is a regular character in the vocabulary.
    """

    _SPECIALS = ["<pad>", "<unk>", "<bos>", "<eos>"]

    def __init__(self, char_to_idx: dict[str, int]) -> None:
        self._c2i = char_to_idx
        self._i2c = {v: k for k, v in char_to_idx.items()}

    @classmethod
    def from_chars(
        cls,
        input_seqs: list[list[str]],
        output_seqs: list[list[str]],
    ) -> Seq2SeqVocab:
        """Build vocabulary from input and output character sequences."""
        chars: set[str] = set()
        for seq in input_seqs:
            chars.update(seq)
        for seq in output_seqs:
            chars.update(seq)
        # Always include the boundary token so `boundary_idx` is defined
        # even when the data contains no decomposed outputs.
        chars.add("+")

        char_to_idx: dict[str, int] = {}
        for i, special in enumerate(cls._SPECIALS):
            char_to_idx[special] = i
        for c in sorted(chars):
            if c not in char_to_idx:
                char_to_idx[c] = len(char_to_idx)

        return cls(char_to_idx)

    @property
    def size(self) -> int:
        return len(self._c2i)

    @property
    def pad_idx(self) -> int:
        return self._c2i["<pad>"]

    @property
    def unk_idx(self) -> int:
        return self._c2i["<unk>"]

    @property
    def bos_idx(self) -> int:
        return self._c2i["<bos>"]

    @property
    def eos_idx(self) -> int:
        return self._c2i["<eos>"]

    @property
    def boundary_idx(self) -> int:
        return self._c2i["+"]

    def char_to_idx(self, c: str) -> int:
        return self._c2i.get(c, self.unk_idx)

    def idx_to_char(self, idx: int) -> str:
        return self._i2c.get(idx, "")

    def encode_input(self, word: str) -> list[int]:
        """Encode a word as a list of character indices."""
        return [self._c2i.get(c, self.unk_idx) for c in word]

    def encode_output(self, chars: list[str]) -> list[int]:
        """Encode output character list, appending EOS."""
        ids = [self._c2i.get(c, self.unk_idx) for c in chars]
        ids.append(self.eos_idx)
        return ids

    def decode_output(self, ids: list[int]) -> str:
        """Decode output indices back to a string (strips specials)."""
        skip = {self.pad_idx, self.unk_idx, self.bos_idx, self.eos_idx}
        return "".join(self._i2c.get(i, "") for i in ids if i not in skip)

    def pad_batch(
        self, sequences: list[list[int]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad sequences to equal length. Returns (padded, lengths)."""
        lengths = [len(s) for s in sequences]
        max_len = max(lengths)
        padded = [s + [self.pad_idx] * (max_len - len(s)) for s in sequences]
        return (
            torch.tensor(padded, dtype=torch.long),
            torch.tensor(lengths, dtype=torch.long),
        )

    def save(self, path: str | Path) -> None:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._c2i, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str | Path) -> Seq2SeqVocab:
        with open(path, encoding="utf-8") as f:
            char_to_idx = json.load(f)
        return cls(char_to_idx)
- [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_vocab.py -v
Expected: All 6 tests PASS.

- [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/seq2seq/__init__.py packages/tokenizer/src/phonolex_tokenizer/seq2seq/vocab.py packages/tokenizer/tests/test_seq2seq_vocab.py
git commit -m "feat(tokenizer): seq2seq vocabulary with boundary token support"

Task 3: Attention Module

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/seq2seq/attention.py
- Test: packages/tokenizer/tests/test_seq2seq_attention.py

Bahdanau (additive) attention over encoder hidden states.
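
For reference, the scoring and context computation the module below implements, with $s_t$ the decoder state and $h_i$ the encoder outputs:

$$e_{t,i} = v^\top \tanh(W_h h_i + W_s s_t), \qquad \alpha_t = \operatorname{softmax}(e_t), \qquad c_t = \sum_i \alpha_{t,i}\, h_i$$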

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_attention.py
"""Tests for Bahdanau attention module."""

import torch
from phonolex_tokenizer.seq2seq.attention import BahdanauAttention


class TestBahdanauAttention:
    def test_output_shapes(self):
        attn = BahdanauAttention(enc_dim=256, dec_dim=256)
        # encoder_outputs: (batch=2, src_len=8, 256)
        encoder_outputs = torch.randn(2, 8, 256)
        # decoder_hidden: (batch=2, 256)
        decoder_hidden = torch.randn(2, 256)

        context, weights = attn(decoder_hidden, encoder_outputs)
        assert context.shape == (2, 256)
        assert weights.shape == (2, 8)

    def test_weights_sum_to_one(self):
        attn = BahdanauAttention(enc_dim=256, dec_dim=256)
        encoder_outputs = torch.randn(1, 5, 256)
        decoder_hidden = torch.randn(1, 256)

        _, weights = attn(decoder_hidden, encoder_outputs)
        assert torch.allclose(weights.sum(dim=-1), torch.tensor([1.0]), atol=1e-5)

    def test_masking(self):
        attn = BahdanauAttention(enc_dim=256, dec_dim=256)
        encoder_outputs = torch.randn(2, 6, 256)
        decoder_hidden = torch.randn(2, 256)
        # Second example only has 3 valid positions
        mask = torch.tensor([
            [False, False, False, False, False, False],
            [False, False, False, True, True, True],
        ])

        _, weights = attn(decoder_hidden, encoder_outputs, mask=mask)
        # Masked positions should have ~0 weight
        assert weights[1, 3:].sum().item() < 1e-5
        # Unmasked should still sum to 1
        assert torch.allclose(weights[1, :3].sum(), torch.tensor(1.0), atol=1e-5)
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_attention.py -v
Expected: FAIL with ModuleNotFoundError

- [ ] Step 3: Write the implementation
# packages/tokenizer/src/phonolex_tokenizer/seq2seq/attention.py
"""Bahdanau (additive) attention for the pointer-generator decomposer."""

from __future__ import annotations

import torch
import torch.nn as nn


class BahdanauAttention(nn.Module):
    """Additive attention: score(s, h) = v^T tanh(W_h h + W_s s).

    Args:
        enc_dim: Encoder hidden dimension.
        dec_dim: Decoder hidden dimension.
    """

    def __init__(self, enc_dim: int, dec_dim: int) -> None:
        super().__init__()
        self.W_h = nn.Linear(enc_dim, dec_dim, bias=False)
        self.W_s = nn.Linear(dec_dim, dec_dim, bias=True)
        self.v = nn.Linear(dec_dim, 1, bias=False)

    def forward(
        self,
        decoder_hidden: torch.Tensor,
        encoder_outputs: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute attention context and weights.

        Args:
            decoder_hidden: (batch, dec_dim)
            encoder_outputs: (batch, src_len, enc_dim)
            mask: (batch, src_len) — True means MASKED (ignored).

        Returns:
            context: (batch, enc_dim)
            weights: (batch, src_len) — attention distribution.
        """
        # (batch, src_len, dec_dim)
        encoder_proj = self.W_h(encoder_outputs)
        # (batch, 1, dec_dim)
        decoder_proj = self.W_s(decoder_hidden).unsqueeze(1)
        # (batch, src_len, 1) → (batch, src_len)
        energy = self.v(torch.tanh(encoder_proj + decoder_proj)).squeeze(-1)

        if mask is not None:
            energy = energy.masked_fill(mask, float("-inf"))

        weights = torch.softmax(energy, dim=-1)
        # (batch, 1, src_len) @ (batch, src_len, enc_dim) → (batch, 1, enc_dim)
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, weights
- [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_attention.py -v
Expected: All 3 tests PASS.

- [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/seq2seq/attention.py packages/tokenizer/tests/test_seq2seq_attention.py
git commit -m "feat(tokenizer): Bahdanau attention module for pointer-generator"

Task 4: Pointer-Generator Model

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/seq2seq/model.py
- Test: packages/tokenizer/tests/test_seq2seq_model.py

The core model: BiLSTM encoder, LSTM decoder with attention, pointer-generator gate.
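
At each decoder step the output distribution mixes generating from the vocabulary with copying from the input via the attention weights $\alpha_t$ (this is what forward() below computes):

$$P_t(w) = p_{\text{gen}}\, P_{\text{vocab}}(w) + (1 - p_{\text{gen}}) \sum_{i\,:\, x_i = w} \alpha_{t,i}, \qquad p_{\text{gen}} = \sigma\big(w_g^\top [\,s_t;\ c_t;\ e(y_{t-1})\,] + b_g\big)$$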

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_model.py
"""Tests for PointerGeneratorNet forward pass and output shapes."""

import torch
import pytest
from phonolex_tokenizer.seq2seq.model import PointerGeneratorNet


@pytest.fixture
def model():
    return PointerGeneratorNet(
        vocab_size=30,
        embed_dim=16,
        enc_hidden_dim=32,
        dec_hidden_dim=64,  # 32 * 2 (bidirectional)
        num_enc_layers=2,
        dropout=0.0,
    )


class TestPointerGeneratorNet:
    def test_encode_shapes(self, model):
        src = torch.randint(1, 30, (2, 8))
        src_lengths = torch.tensor([8, 6])
        encoder_outputs, final_hidden = model.encode(src, src_lengths)
        assert encoder_outputs.shape == (2, 8, 64)
        assert final_hidden[0].shape == (1, 2, 64)  # (layers, batch, dec_dim)
        assert final_hidden[1].shape == (1, 2, 64)

    def test_decode_step_shapes(self, model):
        encoder_outputs = torch.randn(2, 8, 64)
        dec_input = torch.randint(0, 30, (2,))
        h = torch.zeros(1, 2, 64)
        c = torch.zeros(1, 2, 64)

        vocab_dist, attn_weights, p_gen, (h_new, c_new) = model.decode_step(
            dec_input, (h, c), encoder_outputs,
        )
        assert vocab_dist.shape == (2, 30)
        assert attn_weights.shape == (2, 8)
        assert p_gen.shape == (2, 1)
        assert h_new.shape == (1, 2, 64)

    def test_forward_shapes(self, model):
        src = torch.randint(1, 30, (2, 8))
        src_lengths = torch.tensor([8, 6])
        tgt = torch.randint(0, 30, (2, 10))  # decoder targets

        log_probs = model(src, src_lengths, tgt)
        # (batch, tgt_len, vocab_size)
        assert log_probs.shape == (2, 10, 30)

    def test_p_gen_range(self, model):
        encoder_outputs = torch.randn(1, 5, 64)
        dec_input = torch.randint(0, 30, (1,))
        h = torch.zeros(1, 1, 64)
        c = torch.zeros(1, 1, 64)

        _, _, p_gen, _ = model.decode_step(
            dec_input, (h, c), encoder_outputs,
        )
        assert 0.0 <= p_gen.item() <= 1.0
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_model.py -v
Expected: FAIL with ModuleNotFoundError

- [ ] Step 3: Write the implementation
# packages/tokenizer/src/phonolex_tokenizer/seq2seq/model.py
"""Pointer-generator network for canonical morphological decomposition.

BiLSTM encoder + LSTM decoder with Bahdanau attention and a pointer-generator
gate that decides at each step whether to copy a character from the input
(pointer mode) or generate from the vocabulary (generator mode).
"""

from __future__ import annotations

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from phonolex_tokenizer.seq2seq.attention import BahdanauAttention


class PointerGeneratorNet(nn.Module):
    """Pointer-generator seq2seq model for morphological decomposition.

    Args:
        vocab_size: Character vocabulary size (shared encoder/decoder).
        embed_dim: Character embedding dimension.
        enc_hidden_dim: Encoder LSTM hidden size (per direction).
        dec_hidden_dim: Decoder LSTM hidden size. Must equal enc_hidden_dim * 2.
        num_enc_layers: Number of encoder BiLSTM layers.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 64,
        enc_hidden_dim: int = 128,
        dec_hidden_dim: int = 256,
        num_enc_layers: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.dec_hidden_dim = dec_hidden_dim

        # Shared embedding
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # Encoder: bidirectional LSTM
        self.encoder = nn.LSTM(
            embed_dim,
            enc_hidden_dim,
            num_layers=num_enc_layers,
            bidirectional=True,
            dropout=dropout if num_enc_layers > 1 else 0.0,
            batch_first=True,
        )
        self.enc_hidden_dim = enc_hidden_dim
        self.num_enc_layers = num_enc_layers

        # Project encoder final hidden to decoder initial hidden
        self.h_proj = nn.Linear(enc_hidden_dim * 2, dec_hidden_dim)
        self.c_proj = nn.Linear(enc_hidden_dim * 2, dec_hidden_dim)

        # Decoder: unidirectional LSTM
        # Input: embedding(y_{t-1}) + context
        self.decoder = nn.LSTM(
            embed_dim + dec_hidden_dim,
            dec_hidden_dim,
            num_layers=1,
            batch_first=False,
        )

        # Attention
        self.attention = BahdanauAttention(
            enc_dim=dec_hidden_dim,
            dec_dim=dec_hidden_dim,
        )

        # Output projection: decoder hidden + context → vocab distribution
        self.out_proj = nn.Linear(dec_hidden_dim + dec_hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

        # Pointer-generator gate
        self.p_gen_linear = nn.Linear(dec_hidden_dim + dec_hidden_dim + embed_dim, 1)

    def encode(
        self,
        src: torch.Tensor,
        src_lengths: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Encode source sequence.

        Args:
            src: (batch, src_len) character indices.
            src_lengths: (batch,) valid lengths.

        Returns:
            encoder_outputs: (batch, src_len, dec_hidden_dim)
            final_hidden: ((1, batch, dec_hidden_dim), (1, batch, dec_hidden_dim))
        """
        embedded = self.dropout(self.embedding(src))

        # Pack, run LSTM, unpack
        packed = pack_padded_sequence(
            embedded, src_lengths.cpu(), batch_first=True, enforce_sorted=False,
        )
        packed_out, (h_n, c_n) = self.encoder(packed)
        encoder_outputs, _ = pad_packed_sequence(packed_out, batch_first=True)

        # h_n: (num_layers * 2, batch, enc_hidden)
        # Combine final forward and backward hidden states from last layer
        # Forward: h_n[-2], Backward: h_n[-1]
        h_fwd = h_n[-2]  # (batch, enc_hidden)
        h_bwd = h_n[-1]  # (batch, enc_hidden)
        h_combined = torch.cat([h_fwd, h_bwd], dim=-1)  # (batch, dec_hidden)

        c_fwd = c_n[-2]
        c_bwd = c_n[-1]
        c_combined = torch.cat([c_fwd, c_bwd], dim=-1)

        # Project to decoder dimension
        h0 = torch.tanh(self.h_proj(h_combined)).unsqueeze(0)  # (1, batch, dec_hidden)
        c0 = torch.tanh(self.c_proj(c_combined)).unsqueeze(0)

        return encoder_outputs, (h0, c0)

    def decode_step(
        self,
        dec_input: torch.Tensor,
        hidden: tuple[torch.Tensor, torch.Tensor],
        encoder_outputs: torch.Tensor,
        src_mask: torch.Tensor | None = None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        tuple[torch.Tensor, torch.Tensor],
    ]:
        """Run one decoder step.

        Args:
            dec_input: (batch,) current input token indices.
            hidden: ((1, batch, dec_dim), (1, batch, dec_dim)) LSTM state.
            encoder_outputs: (batch, src_len, dec_dim).
            src_mask: (batch, src_len) True = masked.

        Returns:
            vocab_dist: (batch, vocab_size) generation distribution (pre-mixing).
            attn_weights: (batch, src_len) attention weights.
            p_gen: (batch, 1) pointer-generator gate value.
            new_hidden: updated LSTM state.
        """
        embedded = self.dropout(self.embedding(dec_input))  # (batch, embed_dim)

        # Attention over encoder outputs using previous hidden state
        h_squeezed = hidden[0].squeeze(0)  # (batch, dec_dim)
        context, attn_weights = self.attention(
            h_squeezed, encoder_outputs, mask=src_mask,
        )

        # Decoder LSTM input: [embedded; context]
        lstm_input = torch.cat([embedded, context], dim=-1)  # (batch, embed+dec_dim)
        lstm_input = lstm_input.unsqueeze(0)  # (1, batch, embed+dec_dim)
        lstm_out, new_hidden = self.decoder(lstm_input, hidden)
        lstm_out = lstm_out.squeeze(0)  # (batch, dec_dim)

        # Vocab distribution
        vocab_logits = self.out_proj(
            self.dropout(torch.cat([lstm_out, context], dim=-1))
        )
        vocab_dist = torch.softmax(vocab_logits, dim=-1)

        # Pointer-generator gate
        p_gen = torch.sigmoid(
            self.p_gen_linear(torch.cat([lstm_out, context, embedded], dim=-1))
        )

        return vocab_dist, attn_weights, p_gen, new_hidden

    def forward(
        self,
        src: torch.Tensor,
        src_lengths: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Teacher-forced forward pass.

        Args:
            src: (batch, src_len) source character indices.
            src_lengths: (batch,) source lengths.
            tgt: (batch, tgt_len) target character indices (used as decoder input).
            src_mask: (batch, src_len) True = masked.

        Returns:
            log_probs: (batch, tgt_len, vocab_size) — log probabilities of the
                mixed pointer-generator distribution at each step.
        """
        batch_size, tgt_len = tgt.shape

        encoder_outputs, hidden = self.encode(src, src_lengths)
        src_len = encoder_outputs.size(1)

        all_log_probs: list[torch.Tensor] = []

        for t in range(tgt_len):
            dec_input = tgt[:, t]
            vocab_dist, attn_weights, p_gen, hidden = self.decode_step(
                dec_input, hidden, encoder_outputs, src_mask,
            )

            # Mix pointer and generator distributions
            # p_gen * vocab_dist + (1 - p_gen) * copy_dist
            gen_dist = p_gen * vocab_dist  # (batch, vocab_size)

            # Scatter attention weights into vocabulary-sized tensor
            copy_dist = torch.zeros_like(vocab_dist)  # (batch, vocab_size)
            copy_dist.scatter_add_(1, src[:, :src_len], (1 - p_gen) * attn_weights)

            mixed = gen_dist + copy_dist
            # Clamp for numerical stability before log
            log_probs = torch.log(mixed.clamp(min=1e-12))
            all_log_probs.append(log_probs)

        return torch.stack(all_log_probs, dim=1)  # (batch, tgt_len, vocab_size)
- [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_model.py -v
Expected: All 4 tests PASS.

- [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/seq2seq/model.py packages/tokenizer/tests/test_seq2seq_model.py
git commit -m "feat(tokenizer): pointer-generator network for canonical decomposition"

Task 5: Greedy Decoder and Output Parser

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/seq2seq/decode.py
- Test: packages/tokenizer/tests/test_seq2seq_decode.py

Greedy decoding at inference time, plus parsing the output character sequence back into morphemes.
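
A hedged sketch of how the two pieces compose at inference time (assumes a trained model and the Task 2 vocab are already in scope):

# Sketch only: `model` and `vocab` are assumed trained/built already.
from phonolex_tokenizer.seq2seq.decode import greedy_decode, parse_morphemes

src, lengths = vocab.pad_batch([vocab.encode_input("kindness")])
decoded = greedy_decode(model, vocab, src, lengths)  # e.g. ["kind+ness"]
morphemes, labels = parse_morphemes(decoded[0])      # (["kind", "ness"], [ROOT, SUFFIX])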

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_decode.py
"""Tests for greedy decoding and output parsing."""

from phonolex_tokenizer.seq2seq.decode import parse_morphemes
from phonolex_tokenizer.model.schema import MorphLabel


class TestParseMorphemes:
    def test_monomorphemic(self):
        morphemes, labels = parse_morphemes("butter")
        assert morphemes == ["butter"]
        assert labels == [MorphLabel.ROOT]

    def test_suffix(self):
        morphemes, labels = parse_morphemes("kind+ness")
        assert morphemes == ["kind", "ness"]
        assert labels == [MorphLabel.ROOT, MorphLabel.SUFFIX]

    def test_prefix(self):
        morphemes, labels = parse_morphemes("un+happy")
        assert morphemes == ["un", "happy"]
        assert labels == [MorphLabel.PREFIX, MorphLabel.ROOT]

    def test_three_morphemes(self):
        morphemes, labels = parse_morphemes("un+happy+ly")
        assert morphemes == ["un", "happy", "ly"]
        assert labels == [MorphLabel.PREFIX, MorphLabel.ROOT, MorphLabel.SUFFIX]

    def test_empty_segments_ignored(self):
        morphemes, labels = parse_morphemes("kind++ness")
        assert morphemes == ["kind", "ness"]

    def test_known_prefix_detection(self):
        morphemes, labels = parse_morphemes("re+build")
        assert labels == [MorphLabel.PREFIX, MorphLabel.ROOT]

    def test_two_suffixes(self):
        morphemes, labels = parse_morphemes("kind+ness+es")
        assert morphemes == ["kind", "ness", "es"]
        assert labels == [MorphLabel.ROOT, MorphLabel.SUFFIX, MorphLabel.SUFFIX]
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_decode.py -v
Expected: FAIL with ModuleNotFoundError

- [ ] Step 3: Write the implementation
# packages/tokenizer/src/phonolex_tokenizer/seq2seq/decode.py
"""Greedy decoding and output parsing for the pointer-generator decomposer."""

from __future__ import annotations

import torch

from phonolex_tokenizer.model.schema import MorphLabel
from phonolex_tokenizer.seq2seq.model import PointerGeneratorNet
from phonolex_tokenizer.seq2seq.vocab import Seq2SeqVocab

# Prefixes recognized for label assignment
_KNOWN_PREFIXES: frozenset[str] = frozenset([
    "un", "re", "pre", "dis", "mis", "in", "im", "ir", "il",
    "non", "anti", "de", "over", "under", "out", "sub", "super",
    "inter", "trans", "pro", "post", "counter", "semi", "micro",
    "macro", "ultra", "hyper", "hypo", "extra", "co", "multi",
    "bi", "tri", "mono", "poly", "neo", "pseudo",
])


def parse_morphemes(decoded_str: str) -> tuple[list[str], list[MorphLabel]]:
    """Parse a decoded output string into morphemes and labels.

    Splits on '+' boundaries, assigns labels based on position and
    known prefix list. Empty segments from consecutive '+' are dropped.

    Args:
        decoded_str: Decoded character string, e.g. "un+happy+ly".

    Returns:
        (morphemes, labels) — parallel lists.
    """
    segments = [s for s in decoded_str.split("+") if s]

    if len(segments) <= 1:
        word = segments[0] if segments else ""
        return [word], [MorphLabel.ROOT]

    # Find the root: longest segment, or first non-prefix segment
    labels: list[MorphLabel] = []
    root_idx = _find_root_index(segments)

    for i, seg in enumerate(segments):
        if i < root_idx:
            labels.append(MorphLabel.PREFIX)
        elif i == root_idx:
            labels.append(MorphLabel.ROOT)
        else:
            labels.append(MorphLabel.SUFFIX)

    return segments, labels


def _find_root_index(segments: list[str]) -> int:
    """Find the root morpheme index. The root is the first segment
    that is not a known prefix."""
    for i, seg in enumerate(segments):
        if seg not in _KNOWN_PREFIXES:
            return i
    # All segments are known prefixes — last one is the root
    return len(segments) - 1


@torch.no_grad()
def greedy_decode(
    model: PointerGeneratorNet,
    vocab: Seq2SeqVocab,
    src: torch.Tensor,
    src_lengths: torch.Tensor,
    max_len: int = 40,
) -> list[str]:
    """Greedy decode a batch of source sequences.

    Args:
        model: Trained PointerGeneratorNet.
        vocab: Seq2SeqVocab used during training.
        src: (batch, src_len) source character indices.
        src_lengths: (batch,) valid lengths.
        max_len: Maximum output length per example.

    Returns:
        List of decoded strings (one per batch element).
    """
    model.eval()
    batch_size = src.size(0)
    device = src.device
    src_len = src.size(1)

    encoder_outputs, hidden = model.encode(src, src_lengths)

    # Build source padding mask
    src_mask = torch.arange(src_len, device=device).unsqueeze(0) >= src_lengths.unsqueeze(1)

    # Start with BOS token
    dec_input = torch.full((batch_size,), vocab.bos_idx, dtype=torch.long, device=device)

    outputs: list[list[int]] = [[] for _ in range(batch_size)]
    finished = [False] * batch_size

    for _ in range(max_len):
        vocab_dist, attn_weights, p_gen, hidden = model.decode_step(
            dec_input, hidden, encoder_outputs, src_mask,
        )

        # Mix pointer and generator
        gen_dist = p_gen * vocab_dist
        copy_dist = torch.zeros_like(vocab_dist)
        copy_dist.scatter_add_(1, src[:, :src_len], (1 - p_gen) * attn_weights)
        mixed = gen_dist + copy_dist

        # Greedy: argmax
        next_tokens = mixed.argmax(dim=-1)  # (batch,)

        for i in range(batch_size):
            tok = next_tokens[i].item()
            if finished[i]:
                continue
            if tok == vocab.eos_idx:
                finished[i] = True
            else:
                outputs[i].append(tok)

        if all(finished):
            break

        dec_input = next_tokens

    return [vocab.decode_output(ids) for ids in outputs]
- [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_decode.py -v
Expected: All 7 tests PASS.

- [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/seq2seq/decode.py packages/tokenizer/tests/test_seq2seq_decode.py
git commit -m "feat(tokenizer): greedy decoder and morpheme parser for pointer-generator"

Task 6: Seq2Seq Dataset (PyTorch)

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/seq2seq/dataset.py
- Test: packages/tokenizer/tests/test_seq2seq_dataset.py

PyTorch Dataset that produces padded batches for training.
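
The static collate function is designed to plug straight into a DataLoader, which is how Task 7's train_epoch will consume it:

# Sketch only: `dataset` comes from the tests below.
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset, batch_size=64, shuffle=True, collate_fn=Seq2SeqDataset.collate,
)
batch = next(iter(loader))  # keys: src, src_lengths, tgt_input, tgt_output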

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_dataset.py
"""Tests for seq2seq PyTorch dataset."""

import pytest
from phonolex_tokenizer.data.seq2seq_loader import Seq2SeqExample
from phonolex_tokenizer.seq2seq.dataset import Seq2SeqDataset
from phonolex_tokenizer.seq2seq.vocab import Seq2SeqVocab


@pytest.fixture
def examples():
    return [
        Seq2SeqExample(
            word="kindness",
            input_chars=list("kindness"),
            output_chars=list("kind") + ["+"] + list("ness"),
            is_allomorphic=False,
        ),
        Seq2SeqExample(
            word="butter",
            input_chars=list("butter"),
            output_chars=list("butter"),
            is_allomorphic=False,
        ),
    ]


@pytest.fixture
def vocab(examples):
    return Seq2SeqVocab.from_chars(
        [ex.input_chars for ex in examples],
        [ex.output_chars for ex in examples],
    )


class TestSeq2SeqDataset:
    def test_length(self, examples, vocab):
        ds = Seq2SeqDataset(examples, vocab)
        assert len(ds) == 2

    def test_item_keys(self, examples, vocab):
        ds = Seq2SeqDataset(examples, vocab)
        item = ds[0]
        assert "src" in item
        assert "tgt_input" in item
        assert "tgt_output" in item

    def test_tgt_input_starts_with_bos(self, examples, vocab):
        ds = Seq2SeqDataset(examples, vocab)
        item = ds[0]
        assert item["tgt_input"][0] == vocab.bos_idx

    def test_tgt_output_ends_with_eos(self, examples, vocab):
        ds = Seq2SeqDataset(examples, vocab)
        item = ds[0]
        assert item["tgt_output"][-1] == vocab.eos_idx

    def test_collate_batch(self, examples, vocab):
        ds = Seq2SeqDataset(examples, vocab)
        batch = Seq2SeqDataset.collate([ds[0], ds[1]])
        assert batch["src"].shape[0] == 2
        assert batch["tgt_input"].shape[0] == 2
        assert batch["tgt_output"].shape[0] == 2
        assert batch["src_lengths"].shape == (2,)
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_dataset.py -v
Expected: FAIL with ModuleNotFoundError

- [ ] Step 3: Write the implementation
# packages/tokenizer/src/phonolex_tokenizer/seq2seq/dataset.py
"""PyTorch Dataset for seq2seq pointer-generator training."""

from __future__ import annotations

import torch
from torch.utils.data import Dataset

from phonolex_tokenizer.data.seq2seq_loader import Seq2SeqExample
from phonolex_tokenizer.seq2seq.vocab import Seq2SeqVocab


class Seq2SeqDataset(Dataset):
    """Dataset of character-level seq2seq pairs.

    Each item returns:
        src: list[int] — encoded input characters.
        tgt_input: list[int] — [BOS] + encoded output chars (decoder input).
        tgt_output: list[int] — encoded output chars + [EOS] (decoder target).
    """

    def __init__(
        self,
        examples: list[Seq2SeqExample],
        vocab: Seq2SeqVocab,
    ) -> None:
        self.examples = examples
        self.vocab = vocab

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> dict[str, list[int]]:
        ex = self.examples[idx]
        src = self.vocab.encode_input(ex.word)
        out_ids = [self.vocab.char_to_idx(c) for c in ex.output_chars]

        tgt_input = [self.vocab.bos_idx] + out_ids
        tgt_output = out_ids + [self.vocab.eos_idx]

        return {
            "src": src,
            "tgt_input": tgt_input,
            "tgt_output": tgt_output,
        }

    @staticmethod
    def collate(batch: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
        """Collate a list of items into padded tensors."""
        src_seqs = [item["src"] for item in batch]
        tgt_in_seqs = [item["tgt_input"] for item in batch]
        tgt_out_seqs = [item["tgt_output"] for item in batch]

        src_lengths = torch.tensor([len(s) for s in src_seqs], dtype=torch.long)

        src_padded = _pad(src_seqs, pad_value=0)
        tgt_in_padded = _pad(tgt_in_seqs, pad_value=0)
        tgt_out_padded = _pad(tgt_out_seqs, pad_value=0)

        return {
            "src": src_padded,
            "src_lengths": src_lengths,
            "tgt_input": tgt_in_padded,
            "tgt_output": tgt_out_padded,
        }


def _pad(sequences: list[list[int]], pad_value: int = 0) -> torch.Tensor:
    max_len = max(len(s) for s in sequences)
    padded = [s + [pad_value] * (max_len - len(s)) for s in sequences]
    return torch.tensor(padded, dtype=torch.long)
- [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_dataset.py -v
Expected: All 5 tests PASS.

- [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/seq2seq/dataset.py packages/tokenizer/tests/test_seq2seq_dataset.py
git commit -m "feat(tokenizer): PyTorch dataset for seq2seq pointer-generator training"

Task 7: Seq2Seq Decomposer Wrapper

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/seq2seq/decomposer.py
- Test: packages/tokenizer/tests/test_seq2seq_decomposer.py

High-level wrapper that provides decompose_batch(words) → list[MorphTree] — the same interface as the existing Decomposer, plus training logic.
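
A hedged sketch of the intended call pattern (Task 9 will formalize the real training loop; the epoch count, learning rate, and save path here are placeholders):

# Sketch only: hyperparameters and paths are illustrative.
import torch
from phonolex_tokenizer.seq2seq.decomposer import Seq2SeqDecomposer

decomposer = Seq2SeqDecomposer.build(positives, negatives)
optimizer = torch.optim.Adam(decomposer.get_parameters(), lr=1e-3)
for epoch in range(10):  # placeholder epoch count
    loss = decomposer.train_epoch(positives, negatives, optimizer)
trees = decomposer.decompose_batch(["kindness", "butter"])
decomposer.save("artifacts/seq2seq_decomposer")  # placeholder path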

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_decomposer.py
"""Tests for the Seq2SeqDecomposer wrapper."""

import tempfile
import torch
import pytest

from phonolex_tokenizer.data.canonical_loader import DecompositionExample
from phonolex_tokenizer.data.negatives import MonoExample
from phonolex_tokenizer.decomposer.schema import Mono, Split
from phonolex_tokenizer.model.schema import MorphLabel
from phonolex_tokenizer.seq2seq.decomposer import Seq2SeqDecomposer


@pytest.fixture
def small_positives():
    return [
        DecompositionExample("kindness", "kind", "ness", MorphLabel.SUFFIX, False),
        DecompositionExample("unhappy", "happy", "un", MorphLabel.PREFIX, False),
        DecompositionExample("happily", "happy", "ly", MorphLabel.SUFFIX, True),
    ]


@pytest.fixture
def small_negatives():
    return [MonoExample("butter"), MonoExample("hammer")]


@pytest.fixture
def decomposer(small_positives, small_negatives):
    return Seq2SeqDecomposer.build(
        small_positives, small_negatives,
        embed_dim=16, enc_hidden_dim=16,
    )


class TestSeq2SeqDecomposer:
    def test_build(self, decomposer):
        assert decomposer is not None
        params = sum(p.numel() for p in decomposer.get_parameters())
        assert params > 0

    def test_train_epoch_returns_loss(self, decomposer, small_positives, small_negatives):
        optimizer = torch.optim.Adam(decomposer.get_parameters(), lr=0.001)
        loss = decomposer.train_epoch(
            small_positives, small_negatives, optimizer, batch_size=2,
        )
        assert isinstance(loss, float)
        assert loss > 0

    def test_decompose_batch_returns_morph_trees(self, decomposer):
        results = decomposer.decompose_batch(["butter", "kindness"])
        assert len(results) == 2
        for r in results:
            assert isinstance(r, (Mono, Split))

    def test_single_char_word(self, decomposer):
        results = decomposer.decompose_batch(["a"])
        assert len(results) == 1
        assert isinstance(results[0], Mono)

    def test_save_load_roundtrip(self, decomposer):
        with tempfile.TemporaryDirectory() as tmpdir:
            decomposer.save(tmpdir)
            loaded = Seq2SeqDecomposer.load(tmpdir, device=torch.device("cpu"))
            results = loaded.decompose_batch(["butter"])
            assert len(results) == 1
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_decomposer.py -v
Expected: FAIL with ModuleNotFoundError

- [ ] Step 3: Write the implementation
# packages/tokenizer/src/phonolex_tokenizer/seq2seq/decomposer.py
"""Seq2SeqDecomposer — train/infer/save/load wrapper for the pointer-generator.

Provides the same decompose_batch(words) → list[MorphTree] interface as the
existing Decomposer, enabling drop-in replacement for evaluation and governor
integration.
"""

from __future__ import annotations

import json
import random
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from phonolex_tokenizer.data.canonical_loader import DecompositionExample
from phonolex_tokenizer.data.negatives import MonoExample
from phonolex_tokenizer.data.seq2seq_loader import to_seq2seq_examples
from phonolex_tokenizer.decomposer.schema import Mono, MorphTree, Split
from phonolex_tokenizer.model.schema import MorphLabel
from phonolex_tokenizer.seq2seq.dataset import Seq2SeqDataset
from phonolex_tokenizer.seq2seq.decode import greedy_decode, parse_morphemes
from phonolex_tokenizer.seq2seq.model import PointerGeneratorNet
from phonolex_tokenizer.seq2seq.vocab import Seq2SeqVocab


class Seq2SeqDecomposer:
    """Pointer-generator canonical morphological decomposer."""

    def __init__(
        self,
        vocab: Seq2SeqVocab,
        model: PointerGeneratorNet,
        device: torch.device,
    ) -> None:
        self.vocab = vocab
        self.model = model
        self._device = device

    @staticmethod
    def _detect_device() -> torch.device:
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @classmethod
    def build(
        cls,
        positives: list[DecompositionExample],
        negatives: list[MonoExample],
        embed_dim: int = 64,
        enc_hidden_dim: int = 128,
        num_enc_layers: int = 2,
        dropout: float = 0.1,
        device: torch.device | None = None,
    ) -> Seq2SeqDecomposer:
        """Build a Seq2SeqDecomposer from training data."""
        if device is None:
            device = cls._detect_device()

        examples = to_seq2seq_examples(positives, negatives)
        vocab = Seq2SeqVocab.from_chars(
            [ex.input_chars for ex in examples],
            [ex.output_chars for ex in examples],
        )

        dec_hidden_dim = enc_hidden_dim * 2
        model = PointerGeneratorNet(
            vocab_size=vocab.size,
            embed_dim=embed_dim,
            enc_hidden_dim=enc_hidden_dim,
            dec_hidden_dim=dec_hidden_dim,
            num_enc_layers=num_enc_layers,
            dropout=dropout,
        ).to(device)

        return cls(vocab=vocab, model=model, device=device)

    def get_parameters(self) -> list[nn.Parameter]:
        return list(self.model.parameters())

    def train_epoch(
        self,
        positives: list[DecompositionExample],
        negatives: list[MonoExample],
        optimizer: torch.optim.Optimizer,
        batch_size: int = 64,
        max_grad_norm: float = 5.0,
    ) -> float:
        """Train for one epoch. Returns average loss."""
        examples = to_seq2seq_examples(positives, negatives)
        random.shuffle(examples)

        dataset = Seq2SeqDataset(examples, self.vocab)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=Seq2SeqDataset.collate,
        )

        self.model.train()
        total_loss = 0.0
        n_batches = 0

        for batch in loader:
            src = batch["src"].to(self._device)
            src_lengths = batch["src_lengths"].to(self._device)
            tgt_input = batch["tgt_input"].to(self._device)
            tgt_output = batch["tgt_output"].to(self._device)

            # Mask pad positions so training-time attention matches the
            # masked attention used at inference in greedy_decode.
            src_mask = (
                torch.arange(src.size(1), device=self._device).unsqueeze(0)
                >= src_lengths.unsqueeze(1)
            )
            log_probs = self.model(src, src_lengths, tgt_input, src_mask)

            # NLL loss — ignore padding (idx 0)
            log_probs_flat = log_probs.reshape(-1, log_probs.size(-1))
            targets_flat = tgt_output.reshape(-1)
            loss = nn.functional.nll_loss(
                log_probs_flat, targets_flat, ignore_index=self.vocab.pad_idx,
            )

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.get_parameters(), max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        return total_loss / max(n_batches, 1)

    def decompose_batch(self, words: list[str]) -> list[MorphTree]:
        """Decompose a batch of words into MorphTree structures."""
        if not words:
            return []

        # Handle trivially short words
        trivial_idx: dict[int, MorphTree] = {}
        nontrivial_words: list[str] = []
        nontrivial_positions: list[int] = []

        for i, word in enumerate(words):
            if len(word) < 2:
                trivial_idx[i] = Mono(word=word)
            else:
                nontrivial_words.append(word)
                nontrivial_positions.append(i)

        if not nontrivial_words:
            return [trivial_idx[i] for i in range(len(words))]

        # Encode and decode
        src_ids = [self.vocab.encode_input(w) for w in nontrivial_words]
        src_padded, src_lengths = self.vocab.pad_batch(src_ids)
        src_padded = src_padded.to(self._device)
        src_lengths = src_lengths.to(self._device)

        decoded_strs = greedy_decode(
            self.model, self.vocab,
            src_padded, src_lengths,
        )

        # Parse into MorphTrees
        nt_results: list[MorphTree] = []
        for word, decoded in zip(nontrivial_words, decoded_strs):
            morphemes, labels = parse_morphemes(decoded)

            if len(morphemes) <= 1:
                nt_results.append(Mono(word=word))
            elif len(morphemes) == 2:
                # Binary split — matches existing Split interface
                if labels[0] == MorphLabel.PREFIX:
                    nt_results.append(Split(
                        word=word, base=morphemes[1], affix=morphemes[0],
                        label=MorphLabel.PREFIX,
                    ))
                else:
                    label = MorphLabel.SUFFIX
                    nt_results.append(Split(
                        word=word, base=morphemes[0], affix=morphemes[1],
                        label=label,
                    ))
            else:
                # 3+ morphemes — nest as right-branching splits
                tree = _build_nested_tree(word, morphemes, labels)
                nt_results.append(tree)

        # Merge back in original order
        results: list[MorphTree] = []
        nt_iter = iter(nt_results)
        for i in range(len(words)):
            if i in trivial_idx:
                results.append(trivial_idx[i])
            else:
                results.append(next(nt_iter))
        return results

    def save(self, path: str | Path) -> None:
        """Save model, vocab, and config to a directory."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        self.vocab.save(path / "vocab.json")

        torch.save(
            {
                "model": self.model.state_dict(),
                "config": {
                    "vocab_size": self.vocab.size,
                    "embed_dim": self.model.embedding.embedding_dim,
                    "enc_hidden_dim": self.model.enc_hidden_dim,
                    "dec_hidden_dim": self.model.dec_hidden_dim,
                    "num_enc_layers": self.model.num_enc_layers,
                },
            },
            path / "model.pt",
        )

    @classmethod
    def load(
        cls,
        path: str | Path,
        device: torch.device | None = None,
    ) -> Seq2SeqDecomposer:
        """Load a saved Seq2SeqDecomposer."""
        if device is None:
            device = cls._detect_device()

        path = Path(path)
        vocab = Seq2SeqVocab.load(path / "vocab.json")

        checkpoint = torch.load(
            path / "model.pt", weights_only=True, map_location=device,
        )
        cfg = checkpoint["config"]

        model = PointerGeneratorNet(
            vocab_size=cfg["vocab_size"],
            embed_dim=cfg["embed_dim"],
            enc_hidden_dim=cfg["enc_hidden_dim"],
            dec_hidden_dim=cfg["dec_hidden_dim"],
            num_enc_layers=cfg["num_enc_layers"],
            dropout=0.0,
        ).to(device)
        model.load_state_dict(checkpoint["model"])

        return cls(vocab=vocab, model=model, device=device)


def _build_nested_tree(
    word: str,
    morphemes: list[str],
    labels: list[MorphLabel],
) -> MorphTree:
    """Build a nested MorphTree from a flat list of 3+ morphemes.

    Strategy: find the root, attach prefixes left-to-right and suffixes
    right-to-left as nested Split nodes.
    """
    root_idx = next(i for i, l in enumerate(labels) if l == MorphLabel.ROOT)

    # Start with the root
    tree: MorphTree = Mono(word=morphemes[root_idx])

    # Attach suffixes (right of root), innermost first
    for i in range(root_idx + 1, len(morphemes)):
        base_word = "".join(morphemes[root_idx : i + 1])
        tree = Split(
            word=base_word,
            base=morphemes[root_idx] if isinstance(tree, Mono) else tree.word,
            affix=morphemes[i],
            label=MorphLabel.SUFFIX,
            base_tree=tree,
        )

    # Attach prefixes (left of root), innermost first
    for i in range(root_idx - 1, -1, -1):
        current_word = "".join(morphemes[i:])
        tree = Split(
            word=current_word,
            base=tree.word,
            affix=morphemes[i],
            label=MorphLabel.PREFIX,
            base_tree=tree,
        )

    # Fix the top-level word to the original surface form
    if isinstance(tree, Split):
        tree = Split(
            word=word,
            base=tree.base,
            affix=tree.affix,
            label=tree.label,
            base_tree=tree.base_tree,
        )

    return tree
- [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_decomposer.py -v
Expected: All 5 tests PASS.

- [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/seq2seq/decomposer.py packages/tokenizer/tests/test_seq2seq_decomposer.py
git commit -m "feat(tokenizer): Seq2SeqDecomposer wrapper with train/infer/save/load"

Task 8: Seq2Seq Eval Benchmark Adapter

Files:
- Create: packages/tokenizer/src/phonolex_tokenizer/eval/seq2seq_benchmark.py
- Test: packages/tokenizer/tests/test_seq2seq_benchmark.py

Adapter that runs the existing 5 metrics against a Seq2SeqDecomposer, reusing the metric functions from decomposer_metrics.py.
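
Usage is a single call; a sketch (the decomposer and held-out splits come from earlier tasks):

# Sketch only: inputs are assumed to exist from Tasks 1 and 7.
from phonolex_tokenizer.eval.seq2seq_benchmark import run_seq2seq_benchmark

results = run_seq2seq_benchmark(decomposer, test_positives, test_negatives)
print(results["decomposition_accuracy"], results["hard_negative_precision"])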

- [ ] Step 1: Write the failing tests
# packages/tokenizer/tests/test_seq2seq_benchmark.py
"""Tests for seq2seq benchmark adapter."""

import pytest

from phonolex_tokenizer.data.canonical_loader import DecompositionExample
from phonolex_tokenizer.data.negatives import MonoExample
from phonolex_tokenizer.eval.seq2seq_benchmark import run_seq2seq_benchmark
from phonolex_tokenizer.model.schema import MorphLabel
from phonolex_tokenizer.seq2seq.decomposer import Seq2SeqDecomposer


@pytest.fixture
def small_decomposer():
    positives = [
        DecompositionExample("kindness", "kind", "ness", MorphLabel.SUFFIX, False),
    ]
    negatives = [MonoExample("butter")]
    return Seq2SeqDecomposer.build(
        positives, negatives, embed_dim=16, enc_hidden_dim=16,
    )


class TestSeq2SeqBenchmark:
    def test_returns_all_metrics(self, small_decomposer):
        positives = [
            DecompositionExample("kindness", "kind", "ness", MorphLabel.SUFFIX, False),
        ]
        negatives = [MonoExample("butter")]
        results = run_seq2seq_benchmark(
            small_decomposer, positives, negatives,
        )
        assert "decomposition_accuracy" in results
        assert "canonical_form_accuracy" in results
        assert "hard_negative_precision" in results
        assert "allomorphic_accuracy" in results
        assert "surface_faithful_accuracy" in results
        assert "count" in results

    def test_metrics_are_floats(self, small_decomposer):
        positives = [
            DecompositionExample("kindness", "kind", "ness", MorphLabel.SUFFIX, False),
        ]
        negatives = [MonoExample("butter")]
        results = run_seq2seq_benchmark(
            small_decomposer, positives, negatives,
        )
        for key in ["decomposition_accuracy", "canonical_form_accuracy",
                     "hard_negative_precision"]:
            assert isinstance(results[key], float)

    def test_empty_inputs(self, small_decomposer):
        results = run_seq2seq_benchmark(small_decomposer, [], [])
        assert results["count"] == 0
- [ ] Step 2: Run tests to verify they fail

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_benchmark.py -v
Expected: FAIL with ModuleNotFoundError

- [ ] Step 3: Write the implementation
# packages/tokenizer/src/phonolex_tokenizer/eval/seq2seq_benchmark.py
"""Benchmark adapter for the Seq2SeqDecomposer.

Provides the same interface as run_decomposer_benchmark but accepts a
Seq2SeqDecomposer. Reuses metric functions from decomposer_metrics.
"""

from __future__ import annotations

import logging

from phonolex_tokenizer.data.canonical_loader import DecompositionExample
from phonolex_tokenizer.data.negatives import MonoExample
from phonolex_tokenizer.decomposer.schema import Mono, MorphTree, Split
from phonolex_tokenizer.eval.decomposer_metrics import (
    canonical_form_accuracy,
    decomposition_accuracy,
    hard_negative_precision,
)
from phonolex_tokenizer.seq2seq.decomposer import Seq2SeqDecomposer

logger = logging.getLogger(__name__)


def run_seq2seq_benchmark(
    decomposer: Seq2SeqDecomposer,
    test_positives: list[DecompositionExample],
    test_negatives: list[MonoExample],
) -> dict:
    """Evaluate a Seq2SeqDecomposer against labelled test sets.

    Same 5 metrics as run_decomposer_benchmark:
    - decomposition_accuracy
    - canonical_form_accuracy
    - hard_negative_precision
    - allomorphic_accuracy
    - surface_faithful_accuracy
    """
    positive_golds: list[MorphTree] = []
    positive_words: list[str] = []
    for ex in test_positives:
        gold = Split(word=ex.word, base=ex.base, affix=ex.affix, label=ex.label)
        positive_golds.append(gold)
        positive_words.append(ex.word)

    negative_golds: list[MorphTree] = []
    negative_words: list[str] = []
    for ex in test_negatives:
        negative_golds.append(Mono(word=ex.word))
        negative_words.append(ex.word)

    all_words = positive_words + negative_words
    all_golds: list[MorphTree] = positive_golds + negative_golds

    if not all_words:
        return {
            "decomposition_accuracy": 1.0,
            "canonical_form_accuracy": 1.0,
            "hard_negative_precision": 1.0,
            "allomorphic_accuracy": 1.0,
            "surface_faithful_accuracy": 1.0,
            "count": 0,
        }

    all_preds = decomposer.decompose_batch(all_words)

    dec_acc = decomposition_accuracy(all_golds, all_preds)
    can_acc = canonical_form_accuracy(all_golds, all_preds)
    hnp = hard_negative_precision(all_golds, all_preds)

    pos_preds = all_preds[: len(test_positives)]

    allomorphic_golds = [
        g for g, ex in zip(positive_golds, test_positives) if ex.is_allomorphic
    ]
    allomorphic_preds = [
        p for p, ex in zip(pos_preds, test_positives) if ex.is_allomorphic
    ]
    # Vacuously 1.0 when no allomorphic examples are present, matching the
    # empty-input convention above and avoiding a divide-by-zero in the metric.
    allo_acc = (
        decomposition_accuracy(allomorphic_golds, allomorphic_preds)
        if allomorphic_golds
        else 1.0
    )

    surface_golds = [
        g for g, ex in zip(positive_golds, test_positives) if not ex.is_allomorphic
    ]
    surface_preds = [
        p for p, ex in zip(pos_preds, test_positives) if not ex.is_allomorphic
    ]
    # Same guard for the surface-faithful subset.
    surf_acc = (
        decomposition_accuracy(surface_golds, surface_preds)
        if surface_golds
        else 1.0
    )

    return {
        "decomposition_accuracy": dec_acc,
        "canonical_form_accuracy": can_acc,
        "hard_negative_precision": hnp,
        "allomorphic_accuracy": allo_acc,
        "surface_faithful_accuracy": surf_acc,
        "count": len(all_words),
    }
  • [ ] Step 4: Run tests to verify they pass

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_benchmark.py -v Expected: All 3 tests PASS.

  • [ ] Step 5: Commit
git add packages/tokenizer/src/phonolex_tokenizer/eval/seq2seq_benchmark.py packages/tokenizer/tests/test_seq2seq_benchmark.py
git commit -m "feat(tokenizer): eval benchmark adapter for seq2seq decomposer"

Task 9: Training Script

Files: - Create: packages/tokenizer/scripts/train_seq2seq.py

Production training entry point. Mirrors the existing train_decomposer.py pattern.
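
One detail worth noting before the script: split membership comes from a stable hash of each word, so train/dev/test assignment never shifts across runs or machines (Python's built-in hash() is salted per process, so it would). A minimal illustration:

# Deterministic 80/10/10 split by hashing the word itself, as in the
# script below.
import hashlib

def _split_key(word: str) -> int:
    return int(hashlib.md5(word.encode()).hexdigest(), 16) % 100

assert _split_key("kindness") == _split_key("kindness")  # stable across runs
bucket = "train" if _split_key("kindness") < 80 else "dev or test"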

  • [ ] Step 1: Write the training script
# packages/tokenizer/scripts/train_seq2seq.py
"""Train the pointer-generator canonical morphological decomposer."""

import argparse
import hashlib
import logging
import time
from pathlib import Path

import torch

from phonolex_tokenizer.data.canonical_loader import load_morphynet_canonical
from phonolex_tokenizer.data.negatives import find_opaque_words, find_root_only_words
from phonolex_tokenizer.eval.seq2seq_benchmark import run_seq2seq_benchmark
from phonolex_tokenizer.seq2seq.decomposer import Seq2SeqDecomposer

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)


def _split_key(word: str) -> int:
    return int(hashlib.md5(word.encode()).hexdigest(), 16) % 100


def main():
    parser = argparse.ArgumentParser(
        description="Train pointer-generator canonical decomposer",
    )
    parser.add_argument(
        "--data-dir", type=Path,
        default=Path("packages/tokenizer/data/morphynet"),
    )
    parser.add_argument(
        "--output-dir", type=Path,
        default=Path("packages/tokenizer/models/seq2seq"),
    )
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=1e-5)
    parser.add_argument("--embed-dim", type=int, default=64)
    parser.add_argument("--enc-hidden-dim", type=int, default=128)
    parser.add_argument("--num-enc-layers", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--eval-every", type=int, default=5)
    args = parser.parse_args()

    # Load data
    logger.info("Loading MorphyNet canonical data from %s", args.data_dir)
    positives = load_morphynet_canonical(args.data_dir)
    logger.info("Loaded %d decomposition examples", len(positives))

    # Build negatives
    targets = {ex.word for ex in positives}
    sources = {ex.base for ex in positives}
    try:
        from phonolex_data.loaders.cmu import load_cmu_dict

        cmu = load_cmu_dict()
        opaque_vocab = set(cmu.keys())
    except Exception:
        # Fall back to the words we already have if the CMU dict is unavailable.
        opaque_vocab = sources | targets
    opaque = find_opaque_words(targets, opaque_vocab)
    roots = find_root_only_words(sources, targets)
    negatives = opaque + roots
    logger.info(
        "Built %d negative examples (%d opaque, %d roots)",
        len(negatives), len(opaque), len(roots),
    )

    # Split: train (hash 0-79), dev (80-89), test (90-99)
    train_pos = [ex for ex in positives if _split_key(ex.word) < 80]
    dev_pos = [ex for ex in positives if 80 <= _split_key(ex.word) < 90]
    test_pos = [ex for ex in positives if _split_key(ex.word) >= 90]

    train_neg = [ex for ex in negatives if _split_key(ex.word) < 80]
    dev_neg = [ex for ex in negatives if 80 <= _split_key(ex.word) < 90]
    test_neg = [ex for ex in negatives if _split_key(ex.word) >= 90]

    logger.info(
        "Split: train=%d/%d, dev=%d/%d, test=%d/%d (pos/neg)",
        len(train_pos), len(train_neg),
        len(dev_pos), len(dev_neg),
        len(test_pos), len(test_neg),
    )

    # Build model
    decomposer = Seq2SeqDecomposer.build(
        train_pos, train_neg,
        embed_dim=args.embed_dim,
        enc_hidden_dim=args.enc_hidden_dim,
        num_enc_layers=args.num_enc_layers,
        dropout=args.dropout,
    )
    params = sum(p.numel() for p in decomposer.get_parameters())
    logger.info("Model built: %d parameters", params)

    # Optimizer + scheduler
    optimizer = torch.optim.AdamW(
        decomposer.get_parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=1e-5,
    )

    # Train
    # Start below zero so the first dev eval always checkpoints; otherwise a
    # 0.0 dev accuracy would leave no saved model for the final test load.
    best_dev_score = -1.0
    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        loss = decomposer.train_epoch(
            train_pos, train_neg, optimizer, batch_size=args.batch_size,
        )
        scheduler.step()
        elapsed = time.time() - t0
        lr = scheduler.get_last_lr()[0]
        logger.info(
            "Epoch %d/%d — loss: %.4f, lr: %.6f, time: %.0fs",
            epoch, args.epochs, loss, lr, elapsed,
        )

        # Evaluate on dev set periodically
        if epoch % args.eval_every == 0 or epoch == args.epochs:
            dev_results = run_seq2seq_benchmark(decomposer, dev_pos, dev_neg)
            logger.info(
                "  dev — decomp: %.4f, canonical: %.4f, hard_neg: %.4f, "
                "allo: %.4f, surface: %.4f",
                dev_results["decomposition_accuracy"],
                dev_results["canonical_form_accuracy"],
                dev_results["hard_negative_precision"],
                dev_results["allomorphic_accuracy"],
                dev_results["surface_faithful_accuracy"],
            )
            if dev_results["decomposition_accuracy"] > best_dev_score:
                best_dev_score = dev_results["decomposition_accuracy"]
                decomposer.save(args.output_dir)
                logger.info("  New best dev score: %.4f — saved", best_dev_score)

    # Final test eval
    logger.info("Loading best model for final test evaluation")
    best_decomposer = Seq2SeqDecomposer.load(
        args.output_dir, device=decomposer._device,
    )
    test_results = run_seq2seq_benchmark(best_decomposer, test_pos, test_neg)
    logger.info("Test results:")
    for k, v in test_results.items():
        logger.info("  %s: %s", k, v)


if __name__ == "__main__":
    main()
  • [ ] Step 2: Verify the script parses without errors

Run: cd packages/tokenizer && uv run python -c "import ast; ast.parse(open('scripts/train_seq2seq.py').read()); print('OK')" Expected: OK

  • [ ] Step 3: Commit
git add packages/tokenizer/scripts/train_seq2seq.py
git commit -m "feat(tokenizer): training script for pointer-generator decomposer"

Task 10: Evaluation Script

Files: - Create: packages/tokenizer/scripts/eval_seq2seq.py

Standalone evaluation entry point that outputs JSON (for autoresearch compatibility).
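
Downstream consumers can treat --json mode as a one-line machine-readable report. A hypothetical consumer sketch (the invocation and key names mirror the script below; none of this is a plan deliverable):

# Hypothetical consumer: run the evaluator and parse its JSON line.
# logging.basicConfig writes to stderr, so stdout holds only the JSON.
import json
import subprocess

proc = subprocess.run(
    ["uv", "run", "python", "scripts/eval_seq2seq.py", "--json"],
    cwd="packages/tokenizer", capture_output=True, text=True, check=True,
)
metrics = json.loads(proc.stdout.strip().splitlines()[-1])
print(metrics["score"], metrics["count"])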

  • [ ] Step 1: Write the evaluation script
# packages/tokenizer/scripts/eval_seq2seq.py
"""Evaluate the pointer-generator canonical morphological decomposer.

Outputs JSON with the same schema as eval_decomposer.py for autoresearch
compatibility.
"""

import argparse
import hashlib
import json
import logging
from pathlib import Path

import torch

from phonolex_tokenizer.data.canonical_loader import load_morphynet_canonical
from phonolex_tokenizer.data.negatives import find_opaque_words, find_root_only_words
from phonolex_tokenizer.eval.seq2seq_benchmark import run_seq2seq_benchmark
from phonolex_tokenizer.seq2seq.decomposer import Seq2SeqDecomposer

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)


def _split_key(word: str) -> int:
    return int(hashlib.md5(word.encode()).hexdigest(), 16) % 100


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate pointer-generator canonical decomposer",
    )
    parser.add_argument(
        "--model-dir", type=Path,
        default=Path("packages/tokenizer/models/seq2seq"),
    )
    parser.add_argument(
        "--data-dir", type=Path,
        default=Path("packages/tokenizer/data/morphynet"),
    )
    parser.add_argument("--json", action="store_true", help="Output JSON only")
    args = parser.parse_args()

    logger.info("Loading model from %s", args.model_dir)
    decomposer = Seq2SeqDecomposer.load(args.model_dir, device=torch.device("cpu"))

    positives = load_morphynet_canonical(args.data_dir)
    targets = {ex.word for ex in positives}
    sources = {ex.base for ex in positives}

    try:
        from phonolex_data.loaders.cmu import load_cmu_dict

        cmu = load_cmu_dict()
        opaque_vocab = set(cmu.keys())
    except Exception:
        # Fall back to the words we already have if the CMU dict is unavailable.
        opaque_vocab = sources | targets
    opaque = find_opaque_words(targets, opaque_vocab)
    roots = find_root_only_words(sources, targets)
    negatives = opaque + roots

    test_positives = [ex for ex in positives if _split_key(ex.word) >= 90]
    test_negatives = [ex for ex in negatives if _split_key(ex.word) >= 90]

    logger.info("Test set: %d positives, %d negatives", len(test_positives), len(test_negatives))

    benchmark = run_seq2seq_benchmark(decomposer, test_positives, test_negatives)

    output = {
        "decomposition_accuracy": round(benchmark["decomposition_accuracy"], 4),
        "canonical_form_accuracy": round(benchmark["canonical_form_accuracy"], 4),
        "hard_negative_precision": round(benchmark["hard_negative_precision"], 4),
        "allomorphic_accuracy": round(benchmark.get("allomorphic_accuracy", 0.0), 4),
        "surface_faithful_accuracy": round(benchmark.get("surface_faithful_accuracy", 0.0), 4),
        "count": benchmark["count"],
        "score": round(benchmark["decomposition_accuracy"], 4),
    }

    if args.json:
        print(json.dumps(output))
    else:
        logger.info("Results:")
        for k, v in output.items():
            logger.info("  %s: %s", k, v)

    return output


if __name__ == "__main__":
    main()
  • [ ] Step 2: Verify the script parses without errors

Run: cd packages/tokenizer && uv run python -c "import ast; ast.parse(open('scripts/eval_seq2seq.py').read()); print('OK')" Expected: OK

  • [ ] Step 3: Commit
git add packages/tokenizer/scripts/eval_seq2seq.py
git commit -m "feat(tokenizer): evaluation script for pointer-generator decomposer"

Task 11: Full Integration Test

Files: - Create: packages/tokenizer/tests/test_seq2seq_integration.py

End-to-end test: build from small data, train 2 epochs, evaluate, save/load, verify metrics exist.

  • [ ] Step 1: Write the integration test
# packages/tokenizer/tests/test_seq2seq_integration.py
"""End-to-end integration test for the seq2seq pointer-generator decomposer."""

import tempfile

import pytest
import torch

from phonolex_tokenizer.data.canonical_loader import DecompositionExample
from phonolex_tokenizer.data.negatives import MonoExample
from phonolex_tokenizer.eval.seq2seq_benchmark import run_seq2seq_benchmark
from phonolex_tokenizer.model.schema import MorphLabel
from phonolex_tokenizer.seq2seq.decomposer import Seq2SeqDecomposer


@pytest.fixture
def training_data():
    positives = [
        DecompositionExample("kindness", "kind", "ness", MorphLabel.SUFFIX, False),
        DecompositionExample("happily", "happy", "ly", MorphLabel.SUFFIX, True),
        DecompositionExample("unhappy", "happy", "un", MorphLabel.PREFIX, False),
        DecompositionExample("teacher", "teach", "er", MorphLabel.SUFFIX, False),
        DecompositionExample("reading", "read", "ing", MorphLabel.SUFFIX, False),
        DecompositionExample("unkind", "kind", "un", MorphLabel.PREFIX, False),
        DecompositionExample("darkness", "dark", "ness", MorphLabel.SUFFIX, False),
        DecompositionExample("rewrite", "write", "re", MorphLabel.PREFIX, False),
    ]
    negatives = [
        MonoExample("butter"), MonoExample("hammer"), MonoExample("carpet"),
        MonoExample("winter"), MonoExample("sister"),
    ]
    return positives, negatives


class TestSeq2SeqIntegration:
    def test_train_eval_save_load(self, training_data):
        positives, negatives = training_data

        # Build
        decomposer = Seq2SeqDecomposer.build(
            positives, negatives,
            embed_dim=16, enc_hidden_dim=16, dropout=0.0,
        )
        params = sum(p.numel() for p in decomposer.get_parameters())
        assert params > 0

        # Train 2 epochs
        optimizer = torch.optim.Adam(decomposer.get_parameters(), lr=0.01)
        loss1 = decomposer.train_epoch(positives, negatives, optimizer, batch_size=4)
        loss2 = decomposer.train_epoch(positives, negatives, optimizer, batch_size=4)
        assert loss2 <= loss1 * 1.5  # loss should not explode

        # Eval
        results = run_seq2seq_benchmark(decomposer, positives[:3], negatives[:2])
        assert "decomposition_accuracy" in results
        assert 0.0 <= results["decomposition_accuracy"] <= 1.0
        assert results["count"] == 5

        # Save/load roundtrip
        with tempfile.TemporaryDirectory() as tmpdir:
            decomposer.save(tmpdir)
            loaded = Seq2SeqDecomposer.load(tmpdir, device=torch.device("cpu"))
            results2 = run_seq2seq_benchmark(loaded, positives[:3], negatives[:2])
            assert results2["decomposition_accuracy"] == results["decomposition_accuracy"]

    def test_loss_decreases_over_epochs(self, training_data):
        positives, negatives = training_data

        decomposer = Seq2SeqDecomposer.build(
            positives, negatives,
            embed_dim=16, enc_hidden_dim=16, dropout=0.0,
        )
        optimizer = torch.optim.Adam(decomposer.get_parameters(), lr=0.01)

        losses = []
        for _ in range(5):
            loss = decomposer.train_epoch(positives, negatives, optimizer, batch_size=4)
            losses.append(loss)

        # Loss at epoch 5 should be lower than epoch 1
        assert losses[-1] < losses[0]
  • [ ] Step 2: Run the integration test

Run: cd packages/tokenizer && uv run python -m pytest tests/test_seq2seq_integration.py -v Expected: All 2 tests PASS.

  • [ ] Step 3: Run the full test suite to check for regressions

Run: cd packages/tokenizer && uv run python -m pytest tests/ -v Expected: All existing tests still pass, plus all new seq2seq tests pass.

  • [ ] Step 4: Commit
git add packages/tokenizer/tests/test_seq2seq_integration.py
git commit -m "test(tokenizer): end-to-end integration tests for seq2seq decomposer"

Task 12: Run Experiment 0 — Pointer-Generator Baseline

Files: - None created (uses existing scripts)

Run the baseline 20-epoch training on MPS and record results.
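
Note that the script has no --device flag; device placement is assumed to live inside Seq2SeqDecomposer. To confirm the MPS backend is actually available before the run, a quick probe using the standard PyTorch API:

# Quick check that the Apple-GPU (MPS) backend is available on this machine.
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)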

  • [ ] Step 1: Run baseline training

Run:

cd packages/tokenizer && uv run python scripts/train_seq2seq.py \
    --epochs 20 \
    --batch-size 64 \
    --lr 0.001 \
    --embed-dim 64 \
    --enc-hidden-dim 128 \
    --num-enc-layers 2 \
    --dropout 0.1 \
    --eval-every 5

Monitor for: loss decreasing, dev metrics reported at epochs 5/10/15/20, total wall time.

  • [ ] Step 2: Run evaluation on test set

Run:

cd packages/tokenizer && uv run python scripts/eval_seq2seq.py --json

Record the JSON output. Compare decomposition_accuracy to the 81.5% baseline.

  • [ ] Step 3: Record results

Save the JSON output and training logs. If decomposition_accuracy > 0.815, proceed to experiments 1-5. If not, analyze failure modes before continuing.


Summary

| Task | What | Files | Tests |
|------|------|-------|-------|
| 1 | Seq2seq data loader | data/seq2seq_loader.py | 5 |
| 2 | Seq2seq vocabulary | seq2seq/vocab.py | 6 |
| 3 | Attention module | seq2seq/attention.py | 3 |
| 4 | Pointer-generator model | seq2seq/model.py | 4 |
| 5 | Greedy decoder + parser | seq2seq/decode.py | 7 |
| 6 | PyTorch dataset | seq2seq/dataset.py | 5 |
| 7 | Decomposer wrapper | seq2seq/decomposer.py | 5 |
| 8 | Eval benchmark adapter | eval/seq2seq_benchmark.py | 3 |
| 9 | Training script | scripts/train_seq2seq.py | |
| 10 | Evaluation script | scripts/eval_seq2seq.py | |
| 11 | Integration tests | tests/test_seq2seq_integration.py | 2 |
| 12 | Run experiment 0 | | |
| Total | | 10 new files | 40 tests |