fix(training): clarify mixed-precision optimizer-state setup #2248
yuhezhang-ai wants to merge 14 commits into
`build_lr_scheduler` computed `total_steps` via `len(step_scheduler.dataloader)`, which raises `NotImplementedError` for IterableDataset-backed dataloaders (e.g. `NanogptDataset`, `MegatronPretraining`). LLM pre-training recipes that stream tokens cannot use the LR scheduler today.

Use `step_scheduler.epoch_len` instead, which is already in optimizer-step units (microbatches // grad_acc_steps) and is set to `None` by the step scheduler exactly when the dataloader has no `__len__`. In that branch, fall back to `step_scheduler.max_steps` and raise a clear error if the user has set neither.

No behavior change for map-style datasets: the new path resolves to the same value (epochs * epoch_len, optionally capped at max_steps) without the redundant `// grad_acc_steps` that the step scheduler already applied.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
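The step-count resolution described in this commit can be sketched as a small pure function. This is a minimal illustration under stated assumptions, not the repository's `build_lr_scheduler`: the name `resolve_total_steps` and its parameters are hypothetical, `epoch_len` is assumed to already be in optimizer-step units, and `None` stands for a dataloader without `__len__`.

```python
from typing import Optional


def resolve_total_steps(
    epoch_len: Optional[int],
    num_epochs: int,
    max_steps: Optional[int],
) -> int:
    """Hypothetical helper: pick the LR scheduler horizon in optimizer steps."""
    if epoch_len is None:
        # Iterable/streaming dataset: the only usable bound is max_steps.
        if max_steps is None:
            raise ValueError(
                "Dataloader has no length; set max_steps to build the LR scheduler."
            )
        return max_steps
    # Map-style dataset: epochs * epoch_len, optionally capped at max_steps.
    total = num_epochs * epoch_len
    if max_steps is not None:
        total = min(total, max_steps)
    return total


# Map-style dataset, no cap: 3 epochs of 100 optimizer steps each.
steps_map = resolve_total_steps(epoch_len=100, num_epochs=3, max_steps=None)
# Map-style dataset capped by max_steps.
steps_capped = resolve_total_steps(epoch_len=100, num_epochs=3, max_steps=250)
# Streaming dataset: falls back to max_steps.
steps_stream = resolve_total_steps(epoch_len=None, num_epochs=3, max_steps=500)
```

Note that no division by the gradient-accumulation factor appears here, because the sketch assumes `epoch_len` has already been converted to optimizer steps.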
`FSDP2Config.__post_init__` now constructs the default `MixedPrecisionPolicy` with `reduce_dtype=torch.float32` instead of `torch.bfloat16`. This matches Megatron-LM and Lingua: forward / backward stay in bf16 (fast), but the gradient reduce-scatter / all-reduce runs in fp32 so accumulation error does not compound at large DP world sizes. `param_dtype` and `output_dtype` are unchanged at `bfloat16`.

This is a behavior change. Users who relied on the old all-bf16 reduction can opt back in by setting `mp_policy:` explicitly; the recipe to do so is recorded in `docs/breaking-changes.md` for 0.5.0.

The companion docs commit adds a full mixed-precision guide (`docs/guides/mixed-precision.md`) plus a worked example in `examples/llm_pretrain/llama3_70b_pretrain.yaml`.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
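To make the accumulation-error argument concrete, here is a rough, torch-free sketch that emulates bf16 and fp32 rounding with `struct` and sums one identical gradient value across 512 hypothetical ranks. It is a simplification under stated assumptions: real collectives use ring or tree reduction orders rather than a sequential sum, and the helper names (`round_bf16`, `round_fp32`, `accumulate`) and the chosen values are illustrative only, so the size of the error here should not be read as a measurement of any real run.

```python
import struct


def round_fp32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_bf16(x: float) -> float:
    """Round a finite Python float to the nearest bfloat16 (ties to even).

    bfloat16 keeps the upper 16 bits of the float32 bit pattern, so the
    rounding is applied directly to that pattern.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even bias
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(values, round_fn):
    """Sequentially sum values, rounding the running total after each add."""
    total = 0.0
    for v in values:
        total = round_fn(total + v)
    return total


world_size = 512                 # assumed data-parallel world size
grad = round_bf16(0.01)          # each rank's gradient, produced in bf16
grads = [grad] * world_size

exact = grad * world_size                 # reference sum in float64
sum_fp32 = accumulate(grads, round_fp32)  # fp32 accumulator
sum_bf16 = accumulate(grads, round_bf16)  # bf16 accumulator

print(f"exact={exact} fp32={sum_fp32} bf16={sum_bf16}")
```

In this toy setup the bf16 accumulator stops growing once a single contribution falls below half a unit in the last place of the running total, while the fp32 accumulator tracks the reference sum.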
New guide `docs/guides/mixed-precision.md` walks through the two canonical FSDP2 precision patterns and the trap that catches users.

- Pattern A (recommended for pre-training): `model.torch_dtype: float32` for fp32 master weights, paired with the new default `mp_policy` (bf16 compute, fp32 gradient reduction).
- Pattern B (memory-constrained): all-bf16 storage, with the documented caveat that bf16 storage forces `torch.optim.AdamW` EMAs into bf16 too, where the 7-bit mantissa quantizes EMA updates -- producing the sinusoidal `grad_norm` and periodic loss artefacts observed on Llama-3.2-1B pre-training in issue #1679. The fused TE optimizer is recommended as an alternative for users who need bf16 storage with fp32 optimizer state.

The guide is wired into the `docs/index.md` Development toctree and the For Practitioners card grid. `FSDP2Config.mp_policy`'s docstring gains a forward-reference to the guide.

`examples/llm_pretrain/llama3_70b_pretrain.yaml` is realigned with the recommended pattern: `model.torch_dtype` flips from `bf16` to `float32` (this was the only tracked pre-training recipe pinning `torch_dtype: bf16` for the model -- exactly the configuration that seeds the #1679 footgun for users who copy from examples), and the canonical `mp_policy:` block is included explicitly so the pattern is discoverable from a real recipe rather than only from the guide.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
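The EMA quantization caveat for Pattern B can be illustrated without torch. The sketch below emulates bf16 rounding with `struct` and applies a second-moment update in two rounded steps (scale by `beta2`, then add the new squared gradient). It is an assumption-laden toy, not the `torch.optim.AdamW` implementation: the helper names, the constant squared gradient of 0.25, and the starting value of 1.0 are all hypothetical, and it shows only the stalling mechanism, not the sinusoidal `grad_norm` pattern itself.

```python
import struct


def round_bf16(x: float) -> float:
    """Round a finite Python float to the nearest bfloat16 (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even bias
    bits &= 0xFFFF0000                   # keep only the upper 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def run_second_moment_ema(steps, beta2, grad_sq, v0, round_fn):
    """Run v <- beta2 * v + (1 - beta2) * grad_sq with rounding after each op."""
    v = round_fn(v0)
    for _ in range(steps):
        v = round_fn(v * beta2)                    # decay the running average
        v = round_fn(v + (1.0 - beta2) * grad_sq)  # blend in the new sample
    return v


beta2 = 0.999    # common AdamW default for the second-moment EMA
grad_sq = 0.25   # assumed: squared gradient settles at 0.25
v0 = 1.0         # assumed: EMA value left over from earlier, larger gradients

# Python float64 stands in for a high-precision optimizer state.
v_ref = run_second_moment_ema(1000, beta2, grad_sq, v0, lambda x: x)
# bf16 state: a 0.1% decay is below half a bf16 ulp, so v rounds back to 1.0.
v_bf16 = run_second_moment_ema(1000, beta2, grad_sq, v0, round_bf16)

print(f"reference EMA after 1000 steps: {v_ref:.4f}")
print(f"bf16 EMA after 1000 steps:      {v_bf16}")
```

With `beta2 = 0.999` each step changes the state by about 0.1%, which is smaller than half the spacing between adjacent bf16 values near 1.0, so in this toy the bf16 state never moves while the reference decays toward 0.25.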
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
c1ad3c0 to 80a5257
/ok to test 80a5257
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
…ision-config
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
# Conflicts:
# nemo_automodel/recipes/diffusion/train.py
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
/ok to test fe987fd
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
jgerh
left a comment
Completed the tech pubs review of the .md files. Applied formatting‑style updates, added copyedits, refined several sentences for clarity, and provided new link suggestions for the index. Let us know what you think.
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
jgerh
left a comment
One minor edit for punctuation
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com> Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
0fa6c4a to 6e1e8eb
/ok to test 6e1e8eb
Hi @akoumpa, I just realized we now have the active Fern docs tree under fern/versions/nightly/pages/, while docs/ seems to be legacy/reference? In my branch, I added docs/guides/mixed-precision-training.md and some other docs changes. Should I mirror/port these changes into fern/versions/nightly/pages/guides/ and add it to fern/versions/nightly.yml?
What does this PR do ?
Adds guidance and recipe support for fp32 optimizer-state training while keeping bf16 forward/backward compute. The main outcome is a clearer mixed-precision contract:
`model.torch_dtype` controls resident parameter and checkpoint dtype, while FSDP2 `MixedPrecisionPolicy` controls compute/reduction/output dtype.
Motivation
Long pre-training runs can be fragile when `torch.optim.AdamW` is paired with bf16 resident parameters, because AdamW initializes its EMA state from the parameter dtype. That can leave `exp_avg`/`exp_avg_sq` in bf16 even when FSDP2 compute and reduction use mixed precision. The recommended configs now make the optimizer-state choice explicit:
- … `model.torch_dtype: float32` when the PyTorch optimizer path is the validated lower-memory or more stable path for that model.
Main changes
- … `docs/guides/mixed-precision-training.md` and links it from the docs index and fine-tuning guide.
- … `reduce_dtype` default change in `docs/breaking-changes.md`.
- … `_target_` configs, then updates Flux, Qwen Image, and Wan examples to use torch AdamW with fp32 model storage and bf16 compute.
Minor fixes found during testing
- … `IterableDataset` dataloaders by using `step_scheduler.epoch_len` when available and `max_steps` for iterable datasets instead of calling `len(dataloader)`.
- … `foreach` optimizer kwarg into TE FusedAdam when TP is enabled.
Validation
666 passed, 37 skipped.
Before your PR is "Ready for review"
Pre checks: