
NNX migration prep (5/N): enable NNX by default #3526

Draft

ecnal-cienet wants to merge 7 commits into main from feat/nnx-set-defaults-true

Conversation

ecnal-cienet (Collaborator) commented on Mar 31, 2026

NNX Migration Route Map

  1. ✅ Add NNX scaffolding: pure_nnx flag, init_state_fn, TrainStateNNX, NNX utils. Linen workflow unchanged. (PR #3427)
  2. ✅ NNX sharding utilities: get_abstract_state_nnx, get_named_sharding_nnx, set_named_sharding_nnx, get_partition_spec_nnx, get_mesh_from_config. (PR #3470)
  3. ✅ NNX fully supported end-to-end: TrainStateNNX, model creation, gradient accumulation, checkpointing, and training loop dispatch. (PR #3500)
  4. ✅ NNX sharding diagnostics and bidirectional Linen↔NNX checkpoint conversion utilities. (PR #3525)
  5. 🔄 [This PR] Enable NNX by default; fix unit test failures.
  6. ❌ Remove Linen-specific code paths and NNX compatibility flags.

Description

Note: This is the fifth in a series of NNX migration PRs. This PR flips all three NNX flags to True in base.yml, making NNX the default training path, and fixes the unit test failures that surface as a result.

Config change

src/maxtext/configs/base.yml — three flags flipped to True:

enable_nnx: True
pure_nnx_decoder: True
pure_nnx: True
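
For orientation, here is a minimal sketch of the kind of dispatch these flags drive at the training entry point; the helper names create_nnx_state and create_linen_state are hypothetical placeholders, and only the pure_nnx flag name comes from this PR.

```python
# Hypothetical sketch of a pure_nnx dispatch; create_nnx_state and
# create_linen_state are placeholder names, not the actual MaxText API.
def create_train_state(config, mesh, rng):
  if config.pure_nnx:
    # NNX path: builds a TrainStateNNX via the init_state_fn wiring.
    return create_nnx_state(config, mesh, rng)
  # Legacy Linen path, kept until step 6 of the route map removes it.
  return create_linen_state(config, mesh, rng)
```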

Unit test fixes

| File | Fix |
| --- | --- |
| src/maxtext/layers/nnx_decoders.py | Add multimodal_input=None to NNXDecoder.__call__ and unpack it into individual fields. Transformer.__call__ passes a unified MultimodalInput object, but NNXDecoder previously accepted only the individual fields. |
| src/maxtext/utils/muon_utils.py | Return mapped_state (an nnx.State with the correct structure) directly instead of converting it to a flat nested dict with a "params" wrapper; the old conversion broke the attribute access expected by tests. |
| src/maxtext/trainers/post_train/distillation/distillation_utils.py | Guard the optimizer_state restore on whether it exists in the checkpoint; PeftTrainer.save() only saves model_params, so restoring optimizer_state unconditionally raised a KeyError (see the second sketch after this table). |
| tests/integration/gradient_accumulation_test.py | Switch test_sft_grad_accumulate_same_loss from the deprecated SFT loop to the NNX-native train_sft.py; the deprecated loop always passed nextrng as a third positional argument, mismatching the two-element NNX in_shardings. |
| tests/unit/sharding_compare_test.py | Filter abstract_state.model leaves to floating-point only before asserting dtype == float32; the NNX model state includes RNG state variables (uint32/key) that are not weight parameters (see the first sketch after this table). |
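
The sharding_compare_test.py fix boils down to filtering pytree leaves by dtype before the assertion. A minimal sketch, assuming the model state is a JAX pytree whose leaves expose .dtype (the function name is illustrative, not the test's actual code):

```python
import jax
import jax.numpy as jnp

def assert_float_leaves_are_f32(model_state):
  """Assert every floating-point leaf is float32, skipping non-float state.

  NNX model state carries uint32 counters and PRNG-key leaves that are not
  weight parameters, so a blanket dtype == float32 check would fail on them.
  """
  for leaf in jax.tree_util.tree_leaves(model_state):
    if jnp.issubdtype(leaf.dtype, jnp.floating):
      assert leaf.dtype == jnp.float32, f"unexpected dtype: {leaf.dtype}"
```

The distillation_utils.py guard is similarly small. A sketch assuming the checkpoint loader returns a mapping; load_checkpoint and optimizer are placeholders, and only the model_params / optimizer_state key names come from the table above:

```python
restored = load_checkpoint(checkpoint_dir)  # placeholder loader
params = restored["model_params"]           # PeftTrainer.save() always writes this
# PeftTrainer.save() does not write optimizer_state, so guard the access
# instead of indexing unconditionally (which raised a KeyError).
opt_state = restored.get("optimizer_state")
if opt_state is None:
  opt_state = optimizer.init(params)  # re-initialize; optax-style placeholder
```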

Tests

python3 -m pytest tests/unit/ tests/integration/ -v

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch 3 times, most recently from 7536d99 to 9b7b0d9 on April 2, 2026
xibinliu and others added 4 commits on April 6, 2026
- pure_nnx: a flag to choose the pure NNX logic when NNX and Linen models
  co-exist.
- init_state_fn: a function to initialize the model state for training.
  It will be set to a different function for NNX and Linen.
- Add utils to manipulate the NNX shardings with the abstract state of a
  model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  - also add unit tests for this function

Note:
flax v0.12 raises DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed, this is just
    an alias to 'Variable'. Please use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change the code for these warnings.
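
As a rough illustration of the helper extracted in this commit, here is a sketch of how a get_mesh_from_config() might build a Mesh; the config field names (mesh_axes, ici_parallelism) are assumptions, not necessarily the actual MaxText config keys:

```python
import jax
import numpy as np
from jax.sharding import Mesh

def get_mesh_from_config(config):
  # Assumed config fields: config.mesh_axes, e.g. ("data", "fsdp", "tensor"),
  # and config.ici_parallelism, per-axis device counts whose product must
  # equal jax.device_count().
  devices = np.asarray(jax.devices()).reshape(config.ici_parallelism)
  return Mesh(devices, axis_names=config.mesh_axes)
```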
- Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests
- Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils
- Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py
- Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch
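
A guess at the mechanics behind get_abstract_state_nnx(): Flax NNX supports evaluating a model constructor abstractly with nnx.eval_shape, so no real arrays are allocated. The real helper in maxtext_utils may differ; only the function name comes from the commit above:

```python
from flax import nnx

def get_abstract_state_nnx(model_ctor):
  # model_ctor is a zero-argument callable that builds the model, e.g.
  # lambda: MyModel(config, rngs=nnx.Rngs(0)).
  abstract_model = nnx.eval_shape(model_ctor)  # shape/dtype only, no buffers
  graphdef, state = nnx.split(abstract_model)  # separate structure and state
  return graphdef, state
```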
…ison utility

- modify print_shardings_params to support NNX (maxtext_utils.py)
- add --pure_nnx flag to run_sharding_dump.py
- add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py)
- add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py)
ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch from e4ebfb3 to bac289f on April 6, 2026
ecnal-cienet force-pushed the feat/nnx-set-defaults-true branch from bac289f to db75887 on April 6, 2026
ecnal-cienet changed the title from "Feat/nnx set defaults true" to "NNX migration prep (5/N): enable NNX by default" on April 6, 2026