
Feat: integrate NNX LoRA support via Qwix with unified configuration #3320

Open

RexBearIU wants to merge 1 commit into main from jackyf/feat/lora-nnx

Conversation

RexBearIU (Collaborator) commented Mar 5, 2026

Description

Overview
This pull request introduces native LoRA support in MaxText by leveraging the NNX model definition and the Qwix library. It enables a seamless workflow for applying LoRA adapters during
training and provides utilities for bidirectional checkpoint conversion with the HuggingFace ecosystem.

Key Changes

  • Core NNX Integration:
    • Refactored NNXDecoder layer application logic to support nnx.scan with dynamic graph initialization, ensuring compatibility with Qwix's parameter materialization.
  • SFT Pipeline Enhancements:
    • Integrated apply_lora_to_model and restore_lora_from_path into the SFT trainer.
    • Added dummy input preparation to materialize LoRA parameters before trainer initialization.
  • Bidirectional Conversion Scripts:
    • hf_lora_to_maxtext.py: Converts HuggingFace PEFT adapters to MaxText checkpoint format. Updated to 2026 copyright and cleaned up comments.
    • maxtext_to_hf_lora.py: Converts MaxText LoRA checkpoints back to HuggingFace format. Updated to use max_logging and 2026 copyright.
  • Configuration & Type System:
    • Added lora_module_path auto-detection logic for popular models (Llama, etc.) via lora_module_path.yml.
    • Updated types.py with specific LoRA/QLoRA fields.
  • Current Limitations:
    • QLoRA flags (lora_weight_qtype, lora_tile_size) are included in the configuration but explicitly marked as TODO / Not Working for this initial release.
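
The adapter mechanics behind apply_lora_to_model can be sketched with the standard low-rank update y = Wx + (alpha/r)·B(Ax). This toy example uses plain Python with illustrative shapes and names; it is not the MaxText or Qwix API:

```python
# Toy illustration of the low-rank update LoRA adds to a frozen weight.
# Shapes and the forward function are illustrative, not MaxText/Qwix code.

def matvec(m, v):
  """Multiply a matrix (list of rows) by a vector."""
  return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]

def lora_forward(w, a, b, x, alpha=16.0, rank=2):
  """y = W x + (alpha / rank) * B (A x); only A and B are trained."""
  base = matvec(w, x)
  delta = matvec(b, matvec(a, x))
  scale = alpha / rank
  return [base_i + scale * d_i for base_i, d_i in zip(base, delta)]

# 2x2 frozen weight, rank-1 adapter (A: 1x2, B: 2x1).
w = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 1.0]]
b = [[0.5], [0.5]]
print(lora_forward(w, a, b, [1.0, 2.0], alpha=2.0, rank=1))  # [4.0, 5.0]
```

During fine-tuning only A and B receive gradients, which is why the trainer compatibility check below matters: the optimizer must see exactly the adapter parameters and nothing else.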

Tests

The Qwix-based LoRA implementation was validated through a new unit test suite and verified via a comprehensive tutorial.

  1. Unit Tests
    Implemented tests/unit/lora_utils_test.py to ensure structural correctness and trainer compatibility. Key areas covered:
  • Model Transformation: Verified that apply_lora_to_model correctly injects nnx.LoRAParam into the model state.
  • Layer Scanning: Confirmed the implementation works with both scan_layers=True and scan_layers=False by handling the resulting differences in the nnx module path tree.
  • Trainer Compatibility: Validated that tunix.sft.peft_trainer.PeftTrainer correctly identifies the LoRA parameters for optimization, ensuring only adapter weights are trained.
  • Path Matching: Tested the regex logic for auto-detecting LoRA target modules across different model architectures (e.g., Llama).
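
The path-matching idea can be sketched as follows. The pattern below is an assumption for illustration only; the actual contents of lora_module_path.yml are not shown in this PR description:

```python
import re

# Hypothetical module-path pattern of the kind lora_module_path.yml might
# hold for attention projections; not the file's actual contents.
LORA_TARGET_PATTERN = r".*(self_attention|attention)\.(query|key|value|out)$"

def is_lora_target(module_path: str) -> bool:
  """Return True if an nnx module path should receive a LoRA adapter."""
  return re.match(LORA_TARGET_PATTERN, module_path) is not None

print(is_lora_target("decoder.layers.0.self_attention.query"))  # True
print(is_lora_target("decoder.layers.0.mlp.wi"))                # False
```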

Command to run unit tests:

# From the maxtext root directory
export PYTHONPATH=$PYTHONPATH:$(pwd)/src:$(pwd)
python3 tests/unit/lora_utils_test.py

  2. Documentation
  • Added docs/tutorials/posttraining/lora.md, which provides a step-by-step guide for running LoRA fine-tuning, including environment setup and checkpoint conversion. This tutorial serves as the reference for end-to-end functional verification.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 11939f9 to 6540bc8 Compare March 5, 2026 10:53
codecov bot commented Mar 5, 2026

Codecov Report

❌ Patch coverage is 13.63636% with 114 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/utils/lora_utils.py 17.47% 85 Missing ⚠️
src/maxtext/layers/nnx_decoders.py 0.00% 28 Missing and 1 partial ⚠️


@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch 7 times, most recently from 69e481b to 80b5592 Compare March 11, 2026 08:19
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch 11 times, most recently from f5a0f6d to 23e79c7 Compare March 25, 2026 08:33
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch 5 times, most recently from 5a05148 to 7570b3d Compare April 14, 2026 02:45
@RexBearIU RexBearIU marked this pull request as ready for review April 14, 2026 04:08
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch 5 times, most recently from 77e692a to a0bb088 Compare April 15, 2026 08:19
@RexBearIU RexBearIU requested a review from jacoguzo as a code owner April 15, 2026 08:45
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch 8 times, most recently from 8bdfc89 to 6c12235 Compare April 15, 2026 10:38
@@ -0,0 +1,203 @@
Collaborator:
Somewhere in here, consider mentioning how to specify which layers are trained by LoRA -- explaining lora_module_path.yml and how to override it via sft.yml or a command-line argument.

Collaborator:
Here or in a separate PR, can we have a variant of an SFT test that shows eval improvement using LoRA SFT? This could be a notebook that we incorporate into CI just like the SFT notebooks.

Collaborator (Author):
Sounds good! I've updated the documentation to explain how to target specific layers using lora_module_path.yml, including how to override it via sft.yml or CLI arguments.

@shralex - I'll tackle the eval notebook in a separate PR. There are a few known issues with QLoRA at the moment, so it makes sense to include the notebook in the follow-up once those are ironed out.
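
For illustration, an override in sft.yml might look like the following sketch. The keys enable_lora and lora_module_path come from this PR; the remaining key names and values are assumptions, not the exact schema:

```yaml
# Illustrative sketch only; lora_rank / lora_alpha names are assumptions.
enable_lora: true
lora_rank: 16
lora_alpha: 32
# Override the targets auto-detected from lora_module_path.yml:
lora_module_path: ".*self_attention\\.(query|key|value|out)"
```

The same key could presumably be overridden on the command line by passing lora_module_path=... to the training entry point, in the usual MaxText key=value style.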

@@ -0,0 +1,276 @@
# Copyright 2023–2026 Google LLC
Collaborator:
I'm wondering if there is some opportunity for code reuse here and in maxtext_to_hf_lora with their non-lora counterparts

Collaborator (Author):
Agreed, there's definitely room for deduplication here. Just to make sure we're on the same page regarding the scope: are you suggesting we also extract and refactor the base to_maxtext and to_huggingface logic into shared utilities, or just focus on the LoRA-specific parts?

@@ -0,0 +1,276 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
Collaborator:
How did we test these?

We should probably do a logits-verification test like we do for model bring-up -- load a LoRA adapter into MaxText and do one training or SFT step, then compare logits with doing the same on an HF model. Convert the MaxText model back to HF and compare against HF again. The results should be described in the testing section of this PR, as we do in model bring-up PRs, e.g., #1858.
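
The comparison step described above can be sketched as a simple tolerance check. The logit values below are placeholders, not real model outputs:

```python
# Sketch of the tolerance check a logits-verification test might use.
# The logit values here are placeholders, not measurements.

def max_abs_diff(xs, ys):
  """Largest element-wise absolute difference between two logit vectors."""
  return max(abs(x - y) for x, y in zip(xs, ys))

def logits_match(xs, ys, atol=1e-2):
  """True if MaxText and HF logits agree within an absolute tolerance."""
  return max_abs_diff(xs, ys) <= atol

maxtext_logits = [0.101, -1.204, 3.050]  # placeholder values
hf_logits = [0.100, -1.200, 3.049]
print(logits_match(maxtext_logits, hf_logits))  # True
```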

Collaborator (Author):
Agree, we are working on it!

Comment thread src/maxtext/configs/pyconfig.py Outdated
"maxtext.checkpoint_conversion.to_maxtext": "base.yml",
"maxtext.checkpoint_conversion.to_huggingface": "base.yml",
"maxtext.checkpoint_conversion.maxtext_to_hf_lora": "base.yml",
"maxtext.checkpoint_conversion.hf_lora_to_maxtext": "post_train/sft.yml",
Collaborator:
This should be base.yml.

Collaborator (Author):
Thanks for pointing it out. Moved to base.yml.

packing: True
learning_rate: 2.e-5

# -------------- LoRA / QLoRA --------------
Collaborator:
Can we have these LoRA configs nested like:

?

Collaborator:
Move this test module to tests/post_training/unit

from flax import nnx

# Skip the entire test suite if dependencies are missing
pytest.importorskip("tunix")
Collaborator:
You won't need this if you add post-training pytest marker

Collaborator (Author):
Good catch! I just tested it with the marker and it works perfectly. I've removed the old logic and updated the PR.

Comment thread src/maxtext/utils/lora_utils.py Outdated
"""Validates that LoRA is active or that target modules were matched."""
from tunix.sft import utils as tunix_sft_utils # pylint: disable=import-outside-toplevel

if tunix_sft_utils.is_lora_enabled(lora_model):
Collaborator:
Can we have is_lora_enabled method in MaxText instead of importing from Tunix?

Collaborator (Author), Apr 16, 2026:
Sure, I've added the method in MaxText. But I'm afraid that if Tunix changes the logic in this method, the SFT trainer might behave differently, and we might quietly fall back to full-weight training.

Comment thread src/maxtext/utils/lora_utils.py Outdated
def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
"""Optionally restores LoRA params from an external checkpoint item path."""
lora_restore_path = getattr(mt_config, "lora_restore_path", "")
if not lora_restore_path:
Collaborator:
Should we just call this method only when `lora_restore_path` is not None?

Comment thread src/maxtext/utils/lora_utils.py Outdated
max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application")
return model

if not getattr(mt_config, "enable_lora", False):
Collaborator:
Should we call this method only when enable_lora is True?

Collaborator (Author):
Yes, that is a better approach.
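
The call-site guard being agreed on here can be sketched as follows. apply_lora_to_model is stubbed out, so this only shows the control flow, not the real utility:

```python
# Sketch of guarding at the call site instead of inside the utility.
# apply_lora_to_model is a stub standing in for the real MaxText function.

def apply_lora_to_model(model, config):
  """Stub: the real function wraps the model with Qwix LoRA adapters."""
  return ("lora", model)

class Config:
  enable_lora = True

config = Config()
model = "base_model"
# The utility is only invoked when LoRA is enabled; no internal early return.
if getattr(config, "enable_lora", False):
  model = apply_lora_to_model(model, config)
print(model)
```

The same shape applies to restore_lora_from_path: check lora_restore_path at the call site and skip the call entirely when it is unset.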

else:
# HF Hub repo
try:
config_file = hf_hub_download(adapter_path, "adapter_config.json", token=os.environ.get("HF_AUTH_TOKEN"))
Collaborator:
HF_AUTH_TOKEN is the same as config.hf_access_token, right?

@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch 2 times, most recently from 0dfeb76 to 2f91ad8 Compare April 16, 2026 10:32
@RexBearIU RexBearIU changed the title Jackyf/feat/lora nnx Feat: integrate NNX LoRA support via Qwix with unified configuration Apr 16, 2026
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 2f91ad8 to f5736a1 Compare April 16, 2026 10:54