Feat: integrate NNX LoRA support via Qwix with unified configuration #3320
Conversation
@@ -0,0 +1,203 @@
<!--
Somewhere in here, consider mentioning how to specify which layers LoRA trains -- explaining lora_module_path.yml and how to override it via sft.yml or a command-line argument.
Here or in a separate PR, can we have a variant of an SFT test that shows eval improvement using LoRA SFT? This could be a notebook that we incorporate into CI just like the SFT notebooks.
Sounds good! I've updated the documentation to explain how to target specific layers using lora_module_path.yml, including how to override it via sft.yml or CLI arguments.
@shralex - I'll tackle the eval notebook in a separate PR. There are a few known issues with QLoRA at the moment, so it makes sense to include the notebook in the follow-up once those are ironed out.
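To make the layer-targeting concrete, here is a hypothetical sketch of what such a config might look like. The key names below are illustrative assumptions only; lora_module_path.yml and sft.yml are the file names from the thread, not verified schema:

```yaml
# Hypothetical sft.yml fragment -- key names are illustrative assumptions.
enable_lora: true
lora_module_path: "configs/lora_module_path.yml"  # which layers LoRA trains
```

The same keys could presumably then be overridden at launch time in MaxText's usual `key=value` command-line style.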
@@ -0,0 +1,276 @@
# Copyright 2023–2026 Google LLC
I'm wondering if there is some opportunity for code reuse here and in maxtext_to_hf_lora with their non-lora counterparts
Agreed, there's definitely room for deduplication here. Just to make sure we're on the same page regarding the scope: are you suggesting we also extract and refactor the base to_maxtext and to_huggingface logic into shared utilities, or just focus on the LoRA-specific parts?
@@ -0,0 +1,276 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
How did we test these?
We should probably do a logits verification test like we do for model bring-up -- load a LoRA adapter into MaxText, and do one training or SFT step. Compare logits against doing the same on an HF model. Then convert the MaxText model back to HF and compare against HF again. The results should be described in the testing section of this PR, as we do in model bring-up PRs, e.g., #1858.
Agree, we are working on it!
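The comparison loop described above can be sketched roughly as follows. This is a generic harness only, with synthetic arrays standing in for the MaxText and HF logits; nothing here comes from the PR's actual test code:

```python
import numpy as np

def compare_logits(maxtext_logits: np.ndarray, hf_logits: np.ndarray,
                   atol: float = 1e-2) -> float:
    """Return the max absolute difference and assert it is within tolerance."""
    assert maxtext_logits.shape == hf_logits.shape, "batch/seq/vocab dims must match"
    max_diff = float(np.max(np.abs(maxtext_logits - hf_logits)))
    assert max_diff < atol, f"logits diverged: max abs diff {max_diff:.4f}"
    return max_diff

# Toy usage: synthetic logits standing in for the two pipelines' outputs.
rng = np.random.default_rng(0)
base = rng.normal(size=(1, 8, 32))                       # (batch, seq, vocab)
noisy = base + rng.normal(scale=1e-3, size=base.shape)   # small perturbation
diff = compare_logits(base, noisy)
```

In a real bring-up test, `base` and `noisy` would be replaced by the logits of the HF reference model and the converted MaxText model on identical inputs.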
"maxtext.checkpoint_conversion.to_maxtext": "base.yml",
"maxtext.checkpoint_conversion.to_huggingface": "base.yml",
"maxtext.checkpoint_conversion.maxtext_to_hf_lora": "base.yml",
"maxtext.checkpoint_conversion.hf_lora_to_maxtext": "post_train/sft.yml",
This should be base.yml.
Thanks for pointing it out. Moved it to base.yml.
packing: True
learning_rate: 2.e-5

# -------------- LoRA / QLoRA --------------
Can we have these LoRA configs nested like: ?
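The reviewer's inline example did not survive extraction. As a purely hypothetical sketch, a nested layout might look like the following; every key name here is an illustrative assumption, not the PR's actual schema:

```yaml
# Hypothetical nested layout -- not the PR's actual schema.
lora:
  enable: true
  rank: 16
  alpha: 32.0
  module_path: "configs/lora_module_path.yml"
qlora:
  enable: false
```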
Move this test module to tests/post_training/unit
from flax import nnx

# Skip the entire test suite if dependencies are missing
pytest.importorskip("tunix")
You won't need this if you add the post-training pytest marker.
Good catch! I just tested it with the marker and it works perfectly. I've removed the old logic and updated the PR.
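A minimal sketch of the marker-based approach, assuming a `post_training` marker is registered in the project's pytest configuration (the marker name is an assumption):

```python
import pytest

# Marking the whole module lets CI select or deselect these tests with
# `pytest -m post_training`, instead of guarding imports at collection time.
pytestmark = pytest.mark.post_training

def test_lora_params_present():
    # Placeholder body; the real tests would build the NNX model and inspect it.
    assert True
```

With the marker in place, the `pytest.importorskip("tunix")` guard becomes unnecessary because CI only collects these tests in the environment that has the post-training dependencies installed.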
"""Validates that LoRA is active or that target modules were matched."""
from tunix.sft import utils as tunix_sft_utils  # pylint: disable=import-outside-toplevel

if tunix_sft_utils.is_lora_enabled(lora_model):
Can we have an is_lora_enabled method in MaxText instead of importing it from Tunix?
Sure, I've added the method in MaxText. But I'm afraid that if Tunix changes the logic in this method, the SFT trainer might behave differently, and we might silently fall back to full-weight training.
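One way such a local helper could look -- a sketch only. The `lora_a`/`lora_b` naming convention is an assumption about how the adapter parameters are tagged, not verified against this PR:

```python
from typing import Any, Mapping

def is_lora_enabled(params: Mapping[str, Any]) -> bool:
    """Return True if any parameter path contains a LoRA adapter name.

    Assumes adapter params are named lora_a/lora_b somewhere in the tree.
    """
    def walk(tree: Any, prefix: str = "") -> bool:
        if isinstance(tree, Mapping):
            return any(walk(v, f"{prefix}/{k}") for k, v in tree.items())
        return "lora_a" in prefix or "lora_b" in prefix

    return walk(params)

# Toy usage on nested dicts standing in for a model's parameter pytree.
full = {"decoder": {"q_proj": {"kernel": 1}}}
lora = {"decoder": {"q_proj": {"kernel": 1, "lora_a": 2, "lora_b": 3}}}
```

Owning the check locally avoids the silent-fallback risk mentioned above, at the cost of keeping the naming convention in sync with the upstream library.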
def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
  """Optionally restores LoRA params from an external checkpoint item path."""
  lora_restore_path = getattr(mt_config, "lora_restore_path", "")
  if not lora_restore_path:
Should we just call this method only when lora_restore_path is not None?
max_logging.log("MaxText LoRA adapters loaded, skipping Qwix LoRA application")
return model

if not getattr(mt_config, "enable_lora", False):
Should we call this method only when enable_lora is True?
Yes, that's a better approach.
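The pattern both reviewers suggest -- checking the flag at the call site so each helper only runs when it is needed -- can be sketched like this. The function names are placeholders echoing the diff, not the final code, and the list-based "model" is a stand-in for illustration:

```python
from types import SimpleNamespace

def apply_qwix_lora(model, mt_config):   # stub standing in for the real helper
    return model + ["qwix_lora"]

def restore_lora(model, path):           # stub standing in for the real helper
    return model + [f"restored:{path}"]

def maybe_apply_lora(model, mt_config):
    """Call-site guards: each helper runs only when its flag is set."""
    if getattr(mt_config, "enable_lora", False):
        model = apply_qwix_lora(model, mt_config)
    if getattr(mt_config, "lora_restore_path", ""):
        model = restore_lora(model, mt_config.lora_restore_path)
    return model

cfg_off = SimpleNamespace(enable_lora=False, lora_restore_path="")
cfg_on = SimpleNamespace(enable_lora=True, lora_restore_path="gs://bucket/adapter")
untouched = maybe_apply_lora([], cfg_off)
applied = maybe_apply_lora([], cfg_on)
```

Guarding at the call site keeps the early-return checks out of the helpers themselves, so each helper can assume its precondition holds.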
else:
  # HF Hub repo
  try:
    config_file = hf_hub_download(adapter_path, "adapter_config.json", token=os.environ.get("HF_AUTH_TOKEN"))
HF_AUTH_TOKEN is the same as config.hf_access_token, right?
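If the two names are meant to be the same credential, a small resolver that prefers the config value and falls back to the environment keeps them consistent. `hf_access_token` and `HF_AUTH_TOKEN` follow the names in the thread; the precedence order is an assumption:

```python
import os
from typing import Optional

def resolve_hf_token(config_token: Optional[str] = None) -> Optional[str]:
    """Prefer the explicit config value; fall back to the env var; else None."""
    return config_token or os.environ.get("HF_AUTH_TOKEN")

# Usage: an explicit config value wins over the environment.
os.environ["HF_AUTH_TOKEN"] = "env-token"
from_env = resolve_hf_token(None)
from_cfg = resolve_hf_token("cfg-token")
```

The resolved value would then be passed as the `token` argument to `hf_hub_download`, so the download path honors whichever credential source is configured.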
Description
Overview
This pull request introduces native LoRA support in MaxText by leveraging the NNX model definition and the Qwix library. It enables a seamless workflow for applying LoRA adapters during
training and provides utilities for bidirectional checkpoint conversion with the HuggingFace ecosystem.
Key Changes
Tests
The Qwix-based LoRA implementation was validated through a new unit test suite and verified via a comprehensive tutorial.
Implemented tests/unit/lora_utils_test.py to ensure structural correctness and trainer compatibility. Key areas covered:
Command to run unit tests:

# From the maxtext root directory
export PYTHONPATH=$PYTHONPATH:$(pwd)/src:$(pwd)
python3 tests/unit/lora_utils_test.py
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.