feat(diffusion): add Wan2.2 T2V-A14B two-stage finetuning support #2284

Open: linnanwang wants to merge 4 commits into main from wan22
@linnanwang (Contributor) commented May 21, 2026

What does this PR do?

Adds end-to-end finetuning and inference support for Wan2.2-T2V-A14B, a two-stage text-to-video diffusion model whose denoising pipeline routes between a high-noise transformer and a low-noise transformer_2 across a configurable timestep boundary.
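For reviewers unfamiliar with the two-stage layout, the routing described above can be sketched roughly as follows. This is an illustration, not code from this PR: the function name `select_transformer` and the default of 1000 training timesteps are assumptions, and the comparison direction (at or above the boundary goes to the high-noise model) is my reading of how the boundary is meant to work.

```python
def select_transformer(
    timestep: float,
    boundary_ratio: float = 0.875,
    num_train_timesteps: int = 1000,  # assumed scheduler setting
) -> str:
    """Return the pipeline attribute name expected to denoise at `timestep`.

    Hypothetical helper: timesteps at or above the boundary are treated as
    high-noise ("transformer"), everything below as low-noise ("transformer_2").
    """
    if not 0.0 <= boundary_ratio <= 1.0:
        raise ValueError(f"boundary_ratio must be in [0, 1], got {boundary_ratio}")
    boundary_timestep = boundary_ratio * num_train_timesteps
    return "transformer" if timestep >= boundary_timestep else "transformer_2"
```

With `boundary_ratio: 0.875`, early (noisy) steps such as `t = 999` map to `transformer`, and later steps such as `t = 500` map to `transformer_2`.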

Changelog

  • NeMoAutoDiffusionPipeline.from_pretrained: new active_transformer kwarg ("transformer" | "transformer_2"); when set on a two-transformer pipeline the unused transformer is dropped before device placement / FSDP2 wrapping so only one ~14B model occupies GPU memory.
  • TrainDiffusionRecipe (recipes/diffusion/train.py): reads model.stage (high_noise | low_noise) and model.boundary_ratio (falls back to pipe.config.boundary_ratio); derives flow_matching.sigma_min / sigma_max from the stage + boundary so each stage only trains on its own noise range; threads active_transformer into the pipeline loader; suffixes the wandb run name with the stage.
  • examples/diffusion/finetune/wan2_2_t2v_flow.yaml: new finetune config with the A14B hub path, a stage knob, boundary_ratio: 0.875, a bumped dp_size, and explicit activation checkpointing.
  • examples/diffusion/generate/configs/generate_wan22.yaml: new inference config with the A14B hub path, two optional checkpoint paths, guidance_scale_2, and VAE CPU offload enabled by default.
  • examples/diffusion/generate/generate.py: load_checkpoint_into_pipeline accepts model.checkpoint_high_noise / model.checkpoint_low_noise (both optional, mutually exclusive with the legacy single model.checkpoint) and loads each into the matching pipe.transformer / pipe.transformer_2 attribute.
  • tools/diffusion/processors/wan.py: new Wan22Processor subclass registered as wan2.2; it defaults to Wan-AI/Wan2.2-T2V-A14B-Diffusers and marks cache files with model_version: "wan2.2" so Wan2.1 and Wan2.2 caches can coexist.
  • tools/diffusion/preprocessing_multiprocess.py: wan2.2 added to the --processor choices.
  • tools/diffusion/processors/__init__.py: export Wan22Processor.
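To make the stage handling in the recipe concrete, here is a minimal sketch of how a stage and boundary could map to a sigma range and an active transformer. It is not the code in recipes/diffusion/train.py: the helper names are hypothetical, and the mapping itself (high_noise trains on sigmas from the boundary up to 1.0, low_noise from 0.0 up to the boundary) is an assumption inferred from the description above.

```python
from typing import Tuple

# Assumed mapping from the stage knob to the pipeline attribute that is kept.
STAGE_TO_TRANSFORMER = {
    "high_noise": "transformer",
    "low_noise": "transformer_2",
}


def stage_sigma_range(stage: str, boundary_ratio: float) -> Tuple[float, float]:
    """Return (sigma_min, sigma_max) for a stage (hypothetical helper)."""
    if not 0.0 < boundary_ratio < 1.0:
        raise ValueError(f"boundary_ratio must be in (0, 1), got {boundary_ratio}")
    if stage == "high_noise":
        # Assumed: the high-noise stage only sees sigmas above the boundary.
        return (boundary_ratio, 1.0)
    if stage == "low_noise":
        # Assumed: the low-noise stage only sees sigmas below the boundary.
        return (0.0, boundary_ratio)
    raise ValueError(f"unknown stage {stage!r}; expected 'high_noise' or 'low_noise'")


def active_transformer_for(stage: str) -> str:
    """Return the transformer attribute to keep for this stage."""
    try:
        return STAGE_TO_TRANSFORMER[stage]
    except KeyError:
        raise ValueError(f"unknown stage {stage!r}") from None
```

Under these assumptions, `model.stage: high_noise` with `boundary_ratio: 0.875` would train on sigmas in [0.875, 1.0] with only `transformer` loaded, and `model.stage: low_noise` would train on [0.0, 0.875] with only `transformer_2` loaded.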

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)
  • Workflow: preprocess the data once with --processor wan2.2; run finetuning twice (one run with model.stage: high_noise and one with model.stage: low_noise), using a distinct checkpoint.checkpoint_dir per stage; then point generate_wan22.yaml at the two resulting consolidated checkpoint dirs. Either or both checkpoint paths are optional: a missing stage falls back to the hub-pretrained weights.
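The checkpoint rules for inference can be sketched as below. This is an illustration of the intended behavior, not the body of load_checkpoint_into_pipeline: the function name `plan_stage_checkpoints` is hypothetical, the config is modeled as a plain dict, and the handling of the legacy model.checkpoint key is deliberately left out beyond the mutual-exclusion check.

```python
from typing import Dict, Optional


def plan_stage_checkpoints(model_cfg: dict) -> Dict[str, Optional[str]]:
    """Map each transformer attribute to a checkpoint path, or None.

    Hypothetical helper. None means "keep the hub-pretrained weights" for
    that stage. The legacy single-checkpoint path is not modeled here.
    """
    legacy = model_cfg.get("checkpoint")
    high = model_cfg.get("checkpoint_high_noise")
    low = model_cfg.get("checkpoint_low_noise")

    # Per-stage keys are mutually exclusive with the legacy single key.
    if legacy and (high or low):
        raise ValueError(
            "model.checkpoint cannot be combined with "
            "model.checkpoint_high_noise / model.checkpoint_low_noise"
        )

    return {
        "transformer": high or None,    # high-noise stage
        "transformer_2": low or None,   # low-noise stage
    }
```

For example, a config that sets only `checkpoint_high_noise` yields a plan where `transformer` loads the finetuned weights and `transformer_2` stays on the hub-pretrained weights.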

Signed-off-by: linnan wang <linnanw@nvidia.com>

copy-pr-bot (Bot) commented May 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.
