Add lm head vocab histogram animation #814
Open
klei22 wants to merge 5 commits into
Conversation
Pull request overview
This PR adds a new training-time visualization for tracking the evolution of `lm_head` per-token vector magnitudes: TensorBoard histogram logging during training plus an interactive HTML export for post-hoc inspection.
Changes:
- Add TensorBoard histogram logging of `lm_head` vocab-vector L2 magnitudes at a configurable interval.
- Capture `lm_head` magnitude "snapshots" during training and export them as an interactive Plotly-based HTML report.
- Add CLI flags and a demo script + docs to make the feature easy to run.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `train.py` | Implements histogram logging, snapshot capture, and HTML export hooks in the training loop. |
| `train_args.py` | Adds new CLI flags to enable/configure histogram logging and HTML export. |
| `demos/README.md` | Documents the new demo and how to view the TensorBoard/HTML outputs. |
| `demos/lm_head_vocab_histogram_demo.sh` | Provides a runnable example training command enabling the new logging/export. |
Comments suppressed due to low confidence (1)
train.py:1519
- Same issue as above: using `.get(...)` on `self.model.transformer` (an `nn.ModuleDict`) is likely to fail at runtime in multicontext mode. Use a membership check + `[]` indexing (or another ModuleDict-safe accessor) instead.
```python
lm_head = getattr(self.model, "lm_head", None)
if self.args.training_mode == "multicontext" and hasattr(self.model, "transformer"):
    try:
        dataset_idx = self.args.multicontext_datasets.index(target_dataset)
    except ValueError:
        dataset_idx = 0
    lm_head = self.model.transformer.get(f"lm_head_{dataset_idx}", lm_head)
if lm_head is None or not hasattr(lm_head, "weight"):
```
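Since `nn.ModuleDict` has no `.get()` method, the ModuleDict-safe pattern the review asks for is a membership check plus `[]` indexing. A minimal sketch of that pattern, using a stand-in class so the example runs without PyTorch installed (`StubModuleDict` and `resolve_lm_head` are illustrative names, not the PR's code):

```python
class StubModuleDict:
    """Stand-in for torch.nn.ModuleDict: supports `in` and `[]`, but no .get()."""

    def __init__(self, modules):
        self._modules = dict(modules)

    def __contains__(self, key):
        return key in self._modules

    def __getitem__(self, key):
        return self._modules[key]


def resolve_lm_head(transformer, dataset_idx, fallback):
    # Membership check + [] indexing instead of .get(), which ModuleDict lacks.
    key = f"lm_head_{dataset_idx}"
    if key in transformer:
        return transformer[key]
    return fallback


transformer = StubModuleDict({"lm_head_0": "head-0"})
assert resolve_lm_head(transformer, 0, "shared-head") == "head-0"
assert resolve_lm_head(transformer, 1, "shared-head") == "shared-head"
```

The same two-line pattern drops into the training-loop code in place of the `.get(...)` call.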
Comment on lines +1394 to +1397

```python
    or self.iter_num % self.args.log_lm_head_vocab_hist_interval != 0
):
    return
```
```python
        dataset_idx = self.args.multicontext_datasets.index(target_dataset)
    except ValueError:
        dataset_idx = 0
    lm_head = self.model.transformer.get(f"lm_head_{dataset_idx}", lm_head)
```
Comment on lines +1460 to +1461

```html
<script>
const snapshots = {payload};
```
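A common pitfall when inlining JSON into a `<script>` block is a literal `</script>` sequence inside the serialized data, which terminates the script element early. A hedged sketch of one way to build a safe `{payload}` string (the `make_payload` helper is illustrative, not the PR's code):

```python
import json


def make_payload(snapshots):
    """Serialize snapshots for inline embedding inside a <script> tag.

    Rewriting "</" as the equivalent JSON escape "<\\/" keeps the payload
    valid JSON while preventing a literal "</script>" in token text from
    closing the script element early.
    """
    return json.dumps(snapshots).replace("</", "<\\/")


snapshots = [{"iter_num": 0, "vocab": [{"id": 0, "token_display": "</script>"}]}]
payload = make_payload(snapshots)
html = f"<script>\nconst snapshots = {payload};\n</script>"

assert "</script>" not in payload          # safe to inline
assert json.loads(payload) == snapshots    # round-trips unchanged
```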
```python
out_path = self.args.lm_head_vocab_hist_html_path or os.path.join(
    self.args.out_dir, "lm_head_vocab_histogram.html"
)
os.makedirs(os.path.dirname(out_path), exist_ok=True)
```
Comment on lines +1521 to +1532

```python
with torch.no_grad():
    magnitudes = lm_head.weight.detach().norm(dim=1).float().cpu().tolist()
vocab_data = []
for i, m in enumerate(magnitudes):
    token_raw, token_display = self._get_vocab_label_parts(i)
    vocab_data.append({"id": i, "magnitude": float(m), "token_raw": token_raw, "token_display": token_display})
self.lm_head_hist_snapshots.append({
    "iter_num": int(self.iter_num),
    "tokens_trained": float(tokens_trained),
    "dataset": target_dataset,
    "vocab": vocab_data,
})
```
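For readers unfamiliar with the tensor call: `lm_head.weight.norm(dim=1)` is a row-wise L2 norm, one magnitude per vocabulary entry. A dependency-free Python sketch of the same computation and snapshot shape (illustrative only, assuming rows of the weight matrix correspond to tokens as in the PR):

```python
import math


def row_l2_magnitudes(weight):
    """Per-token L2 norm of an lm_head weight matrix (one row per vocab entry).

    Pure-Python equivalent of lm_head.weight.norm(dim=1) for illustration.
    """
    return [math.sqrt(sum(x * x for x in row)) for row in weight]


# Toy 3-token vocab with 2-dim embeddings.
weight = [[3.0, 4.0], [0.0, 0.0], [1.0, 1.0]]
mags = row_l2_magnitudes(weight)  # token 0 is the 3-4-5 triangle, token 1 is zero

snapshot = {
    "iter_num": 0,
    "vocab": [{"id": i, "magnitude": m} for i, m in enumerate(mags)],
}
assert snapshot["vocab"][0]["magnitude"] == 5.0
```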
Comment on lines 2454 to 2457

```python
# End of training actions
if self.iter_num > self.args.max_iters:
    self._export_lm_head_vocab_histogram_html()
    print(self.best_val_loss, self.best_iter, self.best_tokens)
```
Comment on lines +1454 to +1456

```python
logging_group.add_argument('--log_lm_head_vocab_hist', default=False, action=argparse.BooleanOptionalAction, help='Log TensorBoard histogram of per-token lm_head vector magnitudes over training')
logging_group.add_argument('--log_lm_head_vocab_hist_interval', default=100, type=int, help='Training-step interval for logging lm_head vocab magnitude histogram')
logging_group.add_argument('--export_lm_head_vocab_hist_html', default=False, action=argparse.BooleanOptionalAction, help='Export an interactive HTML report of final lm_head vocab-vector magnitudes')
```
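These flags use `argparse.BooleanOptionalAction` (Python 3.9+), which automatically registers a `--no-...` negation for each boolean flag. A standalone sketch reusing the flag names from the diff (the parser itself is illustrative, not the PR's `train_args.py`):

```python
import argparse

parser = argparse.ArgumentParser()
logging_group = parser.add_argument_group("logging")
# BooleanOptionalAction also registers --no-log_lm_head_vocab_hist automatically.
logging_group.add_argument('--log_lm_head_vocab_hist', default=False,
                           action=argparse.BooleanOptionalAction)
logging_group.add_argument('--log_lm_head_vocab_hist_interval', default=100, type=int)

on = parser.parse_args(['--log_lm_head_vocab_hist',
                        '--log_lm_head_vocab_hist_interval', '250'])
assert on.log_lm_head_vocab_hist is True
assert on.log_lm_head_vocab_hist_interval == 250

off = parser.parse_args(['--no-log_lm_head_vocab_hist'])
assert off.log_lm_head_vocab_hist is False
assert off.log_lm_head_vocab_hist_interval == 100  # default preserved
```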
This pull request introduces a new feature for visualizing the evolution of `lm_head` vocabulary vector magnitudes during model training, making it easier to analyze and debug model behavior. It adds both TensorBoard histogram logging and an interactive HTML export, along with the necessary configuration options and demo scripts. The changes are grouped as follows:

New Visualization Features:
- Log `lm_head` vector magnitudes as a histogram in TensorBoard during training, controlled by new command-line arguments. This enables users to track how the output layer's token representations change over time. [1] [2] [3] [4]
- Export an interactive HTML report of final `lm_head` vocab vector magnitudes, with sortable bars and hover labels for token ids and text. [1] [2]

Configuration and Usability:
- New command-line arguments enable and configure the `lm_head` vocab histogram, including logging interval and HTML output path.
- A new demo script (`demos/lm_head_vocab_histogram_demo.sh`) and updated documentation guide users in running and visualizing the new feature. [1] [2]

Internal Implementation:
- Helper methods capture `lm_head` vector magnitudes and associated token labels, supporting both TensorBoard and HTML visualization. [1] [2] [3] [4]

These changes make it much easier to monitor and analyze how the model's output token representations evolve during training, both interactively and post-hoc.