Skip to content

fix(training): clarify mixed-precision optimizer-state setup#2248

Open
yuhezhang-ai wants to merge 14 commits into
mainfrom
yuhez/feat/mixed-precision-config
Open

fix(training): clarify mixed-precision optimizer-state setup#2248
yuhezhang-ai wants to merge 14 commits into
mainfrom
yuhez/feat/mixed-precision-config

Conversation

@yuhezhang-ai
Copy link
Copy Markdown
Contributor

@yuhezhang-ai yuhezhang-ai commented May 15, 2026

What does this PR do ?

Adds guidance and recipe support for fp32 optimizer-state training while keeping bf16 forward/backward compute. The main outcome is a clearer mixed-precision contract: model.torch_dtype controls resident parameter and checkpoint dtype, while FSDP2 MixedPrecisionPolicy controls compute/reduction/output dtype.

Motivation

Long pre-training runs can be fragile when torch.optim.AdamW is paired with bf16 resident parameters, because AdamW initializes its EMA state from the parameter dtype. That can leave exp_avg / exp_avg_sq in bf16 even when FSDP2 compute and reduction use mixed precision. The recommended configs now make the optimizer-state choice explicit:

  • Use TE FusedAdam with bf16 model storage when the model has a validated TE path and bf16 training checkpoints are desired.
  • Use torch AdamW with model.torch_dtype: float32 when the PyTorch optimizer path is the validated lower-memory or more stable path for that model.

Main changes

  • Adds docs/guides/mixed-precision-training.md and links it from the docs index and fine-tuning guide.
  • Documents the FSDP2 reduce_dtype default change in docs/breaking-changes.md.
  • Updates LLM pre-training config examples to the recommended precision setup:
    • Llama 3 70B uses TE FusedAdam with bf16 model storage.
    • Moonlight 16B uses torch AdamW with fp32 model storage and bf16 compute.
    • GPT-style small examples use TE because memory is small and the path is valid.
    • DeepSeek keeps the existing bf16/TE path with comments noting that it has not been validated yet due DeepEP dispatch failures.
  • Adds diffusion pre-training support for split storage/compute dtype and optimizer _target_ configs, then updates Flux, Qwen Image, and Wan examples to use torch AdamW with fp32 model storage and bf16 compute.
  • Fixes PP stage-shape metadata so split fp32 storage + bf16 compute uses the FSDP output dtype for pipeline metadata.
  • Added a runtime warning for full-parameter training with torch.optim.Adam/AdamW on trainable bf16 parameters, guiding users toward fp32 model params with FSDP mixed precision or TE FusedAdam master weights.

Minor fixes found during testing

  • Keeps LR scheduler construction compatible with IterableDataset dataloaders by using step_scheduler.epoch_len when available and max_steps for iterable datasets instead of calling len(dataloader).
  • Avoids injecting the torch-only foreach optimizer kwarg into TE FusedAdam when TP is enabled.
  • Hardens Megatron preprocessing by skipping malformed or empty text fragments, finalizing every output builder, and propagating failed worker processes as non-zero exits.
  • Adds focused unit coverage for diffusion optimizer construction, PP dtype metadata behavior, and scheduler behavior with iterable datasets.

Validation

  • User ran the unit-test suite: 666 passed, 37 skipped.
  • Manual Slurm experiments:
    • Llama 70B validated on 32 GPUs with the TE path.
    • Moonlight 16B validated with torch AdamW/fp32 storage using lower memory than TE.
    • DeepSeek still fails in DeepEP dispatch before memory comparison, so its recipe is left unchanged except for comments.
    • Diffusion TE paths OOMed in experiments; torch AdamW/fp32 storage was the validated path for Flux/Qwen/Wan.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

`build_lr_scheduler` computed `total_steps` via `len(step_scheduler.dataloader)`,
which raises `NotImplementedError` for IterableDataset-backed dataloaders
(e.g. `NanogptDataset`, `MegatronPretraining`). LLM pre-training recipes
that stream tokens cannot use the LR scheduler today.

Use `step_scheduler.epoch_len` instead, which is already in optimizer-step
units (microbatches // grad_acc_steps) and is set to `None` by the step
scheduler exactly when the dataloader has no `__len__`. In that branch,
fall back to `step_scheduler.max_steps` and raise a clear error if the
user has set neither.

No behavior change for map-style datasets: the new path resolves to the
same value (epochs * epoch_len, optionally capped at max_steps) without
the redundant `// grad_acc_steps` that the step scheduler already applied.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
`FSDP2Config.__post_init__` now constructs the default
`MixedPrecisionPolicy` with `reduce_dtype=torch.float32` instead of
`torch.bfloat16`. This matches Megatron-LM and Lingua: forward / backward
stay in bf16 (fast), but the gradient reduce-scatter / all-reduce runs
in fp32 so accumulation error does not compound at large DP world sizes.
`param_dtype` and `output_dtype` are unchanged at `bfloat16`.

This is a behavior change. Users who relied on the old all-bf16
reduction can opt back in by setting `mp_policy:` explicitly; the
recipe to do so is recorded in `docs/breaking-changes.md` for 0.5.0.

The companion docs commit adds a full mixed-precision guide
(`docs/guides/mixed-precision.md`) plus a worked example in
`examples/llm_pretrain/llama3_70b_pretrain.yaml`.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
New guide `docs/guides/mixed-precision.md` walks through the two
canonical FSDP2 precision patterns and the trap that catches users.

- Pattern A (recommended for pre-training): `model.torch_dtype: float32`
  for fp32 master weights, paired with the new default `mp_policy`
  (bf16 compute, fp32 gradient reduction).
- Pattern B (memory-constrained): all-bf16 storage, with the documented
  caveat that bf16 storage forces `torch.optim.AdamW` EMAs into bf16
  too, where the 7-bit mantissa quantizes EMA updates -- producing the
  sinusoidal `grad_norm` and periodic loss artefacts observed on
  Llama-3.2-1B pre-training in issue #1679. The fused TE optimizer is
  recommended as an alternative for users who need bf16 storage with
  fp32 optimizer state.

The guide is wired into the `docs/index.md` Development toctree and
the For Practitioners card grid. `FSDP2Config.mp_policy`'s docstring
gains a forward-reference to the guide.

`examples/llm_pretrain/llama3_70b_pretrain.yaml` is realigned with
the recommended pattern: `model.torch_dtype` flips from `bf16` to
`float32` (this was the only tracked pre-training recipe pinning
`torch_dtype: bf16` for the model -- exactly the configuration that
seeds the #1679 footgun for users who copy from examples), and the
canonical `mp_policy:` block is included explicitly so the pattern is
discoverable from a real recipe rather than only from the guide.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai yuhezhang-ai force-pushed the yuhez/feat/mixed-precision-config branch from c1ad3c0 to 80a5257 Compare May 15, 2026 21:42
@NVIDIA-NeMo NVIDIA-NeMo deleted a comment from copy-pr-bot Bot May 15, 2026
@yuhezhang-ai
Copy link
Copy Markdown
Contributor Author

/ok to test 80a5257

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
…ision-config

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>

# Conflicts:
#	nemo_automodel/recipes/diffusion/train.py
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai
Copy link
Copy Markdown
Contributor Author

/ok to test fe987fd

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai yuhezhang-ai requested a review from a team as a code owner May 20, 2026 19:16
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Copy link
Copy Markdown
Contributor

@jgerh jgerh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completed the tech pubs review of the .md files. Applied formatting‑style updates, added copyedits, refined several sentences for clarity, and provided new link suggestions for the index. Let us know what you think.

Comment thread docs/guides/llm/finetune.md Outdated
Comment thread docs/guides/llm/finetune.md Outdated
Comment thread docs/guides/llm/finetune.md Outdated
Comment thread docs/guides/llm/finetune.md Outdated
Comment thread docs/guides/llm/finetune.md Outdated
Comment thread docs/breaking-changes.md Outdated
Comment thread docs/breaking-changes.md Outdated
Comment thread docs/breaking-changes.md Outdated
Comment thread docs/index.md Outdated
Comment thread docs/index.md Outdated
Copy link
Copy Markdown
Contributor

@jgerh jgerh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completed the tech pubs review of the .md files. Applied formatting‑style updates, added copyedits, refined several sentences for clarity, and provided new link suggestions for the index. Let us know what you think.

Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
Copy link
Copy Markdown
Contributor

@jgerh jgerh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor edit for punctuation

Comment thread docs/guides/mixed-precision-training.md Outdated
Co-authored-by: jgerh <163925524+jgerh@users.noreply.github.com>
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai yuhezhang-ai force-pushed the yuhez/feat/mixed-precision-config branch from 0fa6c4a to 6e1e8eb Compare May 20, 2026 21:41
@yuhezhang-ai
Copy link
Copy Markdown
Contributor Author

/ok to test 6e1e8eb

Copy link
Copy Markdown
Contributor Author

Hi @akoumpa, I just realized we now have the active Fern docs tree under fern/versions/nightly/pages/, while docs/ seems to be legacy/reference? In my branch, I added docs/guides/mixed-precision-training.md and some other docs changes. Should I mirror/port these changes into fern/versions/nightly/pages/guides/ and add it to fern/versions/nightly.yml?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants