Skip to content

Fix Unigram trainer prune loss to use per-piece alternative count#2070

Open
hunter-heidenreich wants to merge 2 commits into
huggingface:mainfrom
hunter-heidenreich:fix/unigram-prune-per-piece-alternatives
Open

Fix Unigram trainer prune loss to use per-piece alternative count#2070
hunter-heidenreich wants to merge 2 commits into
huggingface:mainfrom
hunter-heidenreich:fix/unigram-prune-per-piece-alternatives

Conversation

@hunter-heidenreich
Copy link
Copy Markdown

What

UnigramTrainer::prune_sentence_pieces computes the prune-loss term logsum_alt with alternatives.len() (the total number of pieces) where it should use alternatives[id].len() (the number of alternative segmentations for the one
piece being scored).

- let logsum_alt = (sum + freq[id] * (alternatives.len() - 1) as f64).ln();
+ let logsum_alt = (sum + freq[id] * (alternatives[id].len() - 1) as f64).ln();

alternatives is the per-piece vector Vec<Vec<usize>>, so alternatives.len() is just pieces.len() (the whole vocabulary size) rather than the handful of pieces id would be re-segmented into if removed.

Closes #2069. (Same defect previously reported in #1536, closed without a reply.)

Why it matters

logsum_alt is the log of the corpus token total after removing piece id and reassigning its frequency to its alternatives. That total should grow by freq[id] * (number_of_alternatives - 1). Using the total piece count instead multiplies the growth term by the entire vocabulary size, inflating the piece's loss. Because the inflation scales with freq[id], it systematically over-values high-frequency pieces in the keep/drop ranking.

The giveaway is the comment, ported verbatim from SentencePiece but with the [i] dropped, and the code followed the comment.

Parity with SentencePiece

This restores the reference behavior. SentencePiece src/unigram_model_trainer.cc, Trainer::PruneSentencePieces:

// new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size()
const float logsum_alt = std::log(
    static_cast<double>(sum + freq[i] * (alternatives[i].size() - 1)));

Does it change real tokenizers? Yes... bounded, and only when retraining

This affects only new train() runs; existing serialized tokenizer.json files and all inference/encoding are byte-for-byte unaffected. The divergence is concentrated where pruning does most of the work (small vocab_size) and disappears once the vocab is large enough that the EM threshold and finalize trim the affected margin anyway.

Training Unigram on Pride and Prejudice (Project Gutenberg 1342, words lowercased), shrinking_factor=0.9, single-threaded:

vocab_size pieces changed (each way) Jaccard
1000 14 0.972
2000 9 0.991
4000 0 (identical) 1.000

At vocab_size=1000 the fix keeps more complete words (arrangement, benefit, explanation, inferior, leisure, original, similar); the buggy formula instead keeps short high-frequency pieces and fragments (into, act, best, friends, ten, tion, ize). The result is deterministic.

Reproduction script (train on main vs this branch, then diff)
#!/usr/bin/env python3
"""Reproduce the prune-loss fix's effect on a real corpus.

Run once on `main` (buggy) and once on this PR branch (fixed) — rebuild the
bindings (`maturin develop --release`) between checkouts — then --compare:

    python repro_pp.py --train --vocab-size 1000 --shrink 0.9 --out buggy.json   # on main
    python repro_pp.py --train --vocab-size 1000 --shrink 0.9 --out fixed.json   # on PR branch
    python repro_pp.py --compare fixed.json buggy.json
"""
import argparse, json, re, sys
from pathlib import Path

CACHE = Path(__file__).with_name("pp.words.json")
PG_URL = "https://www.gutenberg.org/cache/epub/1342/pg1342.txt"

def load_corpus():
    if CACHE.exists():
        return json.loads(CACHE.read_text())
    import requests
    text = requests.get(PG_URL, headers={"User-Agent": "repro"}, timeout=60).text
    start, end = text.find("*** START OF"), text.find("*** END OF")
    if start != -1: text = text[text.find("\n", start) + 1:]
    if end != -1:   text = text[:end]
    words = re.findall(r"[a-z]+", text.lower())
    CACHE.write_text(json.dumps(words))
    return words

def train(words, vocab_size, shrink):
    from tokenizers import Tokenizer
    from tokenizers.models import Unigram
    from tokenizers.trainers import UnigramTrainer
    tok = Tokenizer(Unigram())
    tok.train_from_iterator(words, UnigramTrainer(
        vocab_size=vocab_size, special_tokens=[], shrinking_factor=shrink,
        max_piece_length=128, show_progress=False))
    return sorted(tok.get_vocab().keys())

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--train", action="store_true")
    ap.add_argument("--vocab-size", type=int, default=1000)
    ap.add_argument("--shrink", type=float, default=0.9)
    ap.add_argument("--out")
    ap.add_argument("--compare", nargs=2, metavar=("A", "B"))
    a = ap.parse_args()
    if a.compare:
        A = set(json.loads(Path(a.compare[0]).read_text()))
        B = set(json.loads(Path(a.compare[1]).read_text()))
        j = len(A & B) / len(A | B) if (A | B) else 1.0
        print(f"|A|={len(A)} |B|={len(B)} differ={len(A ^ B)} jaccard={j:.4f}")
        print(f"only in {a.compare[0]}: {sorted(A - B)}")
        print(f"only in {a.compare[1]}: {sorted(B - A)}")
        return
    if a.train:
        v = train(load_corpus(), a.vocab_size, a.shrink)
        if a.out: Path(a.out).write_text(json.dumps(v))
        print(f"vocab_size={a.vocab_size} shrink={a.shrink}: {len(v)} pieces", file=sys.stderr)
        return
    ap.error("specify --train or --compare")

if __name__ == "__main__":
    main()

Set RAYON_RS_NUM_THREADS=1 for a deterministic single-threaded run. The --compare differ=N value counts both directions (so differ=28 ⇒ 14 each way).

Tests

prune_sentence_pieces had no coverage. The existing training tests use a tiny corpus that breaks out of the EM loop before pruning runs. This adds:

  • test_prune_sentence_pieces_keeps_costly_alternative: a white-box unit test that locks the per-piece keep decision. It fails on the old line (keeps the cheap-to-replace piece) and passes on the fix, independent of the corpus above.
  • test_do_train_runs_prune: the first end-to-end test that actually exercises the EM + prune loop (coverage for the previously-untested path).

cargo fmt --check and cargo clippy --all-targets --all-features -- -D warnings are clean; models::unigram and the Python Unigram binding tests pass.

hunter-heidenreich and others added 2 commits May 24, 2026 10:29
prune_sentence_pieces computed logsum_alt with alternatives.len() (the
total piece count) instead of alternatives[id].len() (the alternatives
for the piece being scored), diverging from SentencePiece's
PruneSentencePieces. This inflates the loss for high-frequency pieces
and changes which pieces survive pruning.

Add the first tests for prune_sentence_pieces (previously uncovered):
a regression test locking the per-piece keep decision, and a smoke test
exercising the EM + prune loop end to end.

Refs huggingface#2069

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment-only follow-up to the per-piece alternatives fix (no logic change):

- Cite SentencePiece's Trainer::PruneSentencePieces at the fix site and note
  why the term must use alternatives[id], not alternatives.len(), to guard
  against silent re-drift of the ported comment.
- Tidy the drifted comment: (alternatives[i] - 1) -> (alternatives[i].size() - 1).
- Explain the slot arithmetic in test_prune_sentence_pieces_keeps_costly_alternative
  (why exactly one slot is contested, and what m/n padding is for).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Unigram trainer: prune loss uses alternatives.len() instead of alternatives[id].len()

1 participant