From d14ca9a5e4291c4d3660b745257111278961f881 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Wed, 20 May 2026 19:29:46 -0700 Subject: [PATCH 1/4] feat(diffusion): add Wan2.2 T2V-A14B two-stage finetuning support Signed-off-by: linnan wang --- .../diffusion/finetune/wan2_2_t2v_flow.yaml | 104 ++++++++++++++++++ .../generate/configs/generate_wan22.yaml | 44 ++++++++ examples/diffusion/generate/generate.py | 94 +++++++++++----- .../_diffusers/auto_diffusion_pipeline.py | 58 ++++++++++ nemo_automodel/recipes/diffusion/train.py | 58 ++++++++++ tools/diffusion/preprocessing_multiprocess.py | 2 +- tools/diffusion/processors/__init__.py | 3 +- tools/diffusion/processors/wan.py | 42 ++++++- 8 files changed, 373 insertions(+), 32 deletions(-) create mode 100644 examples/diffusion/finetune/wan2_2_t2v_flow.yaml create mode 100644 examples/diffusion/generate/configs/generate_wan22.yaml diff --git a/examples/diffusion/finetune/wan2_2_t2v_flow.yaml b/examples/diffusion/finetune/wan2_2_t2v_flow.yaml new file mode 100644 index 0000000000..5867cdd417 --- /dev/null +++ b/examples/diffusion/finetune/wan2_2_t2v_flow.yaml @@ -0,0 +1,104 @@ +seed: 42 + +wandb: + project: wan-t2v-flow-matching + mode: online + # Stage name is auto-appended (e.g. wan2_2_t2v_fm_high_noise) when model.stage is set. + name: wan2_2_t2v_fm + +dist_env: + backend: nccl + timeout_minutes: 30 + +model: + pretrained_model_name_or_path: Wan-AI/Wan2.2-T2V-A14B-Diffusers + mode: finetune + # Two-stage finetuning knob. Required for Wan2.2-T2V-A14B. + # "high_noise" trains pipe.transformer on sigma in [boundary_ratio, 1.0]. + # "low_noise" trains pipe.transformer_2 on sigma in [0.0, boundary_ratio]. + # Run twice (once per stage) with different checkpoint dirs to fully finetune. + stage: high_noise + # Optional override; if unset the recipe reads boundary_ratio from the + # pipeline config (Wan2.2-T2V-A14B ships with 0.875). + boundary_ratio: 0.875 + +step_scheduler: + global_batch_size: 8 + local_batch_size: 1 + ckpt_every_steps: 1000 + num_epochs: 100 + log_every: 2 + save_checkpoint_every_epoch: false + +data: + dataloader: + _target_: nemo_automodel.components.datasets.diffusion.build_video_multiresolution_dataloader + # Point this at the cache produced by tools/diffusion/processors/wan.py --variant wan22 + cache_dir: PATH_TO_YOUR_WAN22_DATA + model_type: wan + base_resolution: [512, 512] + dynamic_batch_size: false + shuffle: true + drop_last: false + num_workers: 2 + +optim: + learning_rate: 5e-6 + optimizer: + weight_decay: 0.01 + betas: [0.9, 0.999] + foreach: true + +performance: + check_loss: false + grad_clip_foreach: true + +lr_scheduler: + lr_decay_style: cosine + lr_warmup_steps: 0 + min_lr: 1e-6 + +# Flow matching V2 configuration. sigma_min/sigma_max are derived from +# model.stage + boundary_ratio at runtime; any values set here are ignored. +flow_matching: + adapter_type: "simple" + adapter_kwargs: {} + timestep_sampling: "uniform" + logit_mean: 0.0 + logit_std: 1.0 + flow_shift: 3.0 + mix_uniform_ratio: 0.1 + num_train_timesteps: 1000 + i2v_prob: 0.3 + use_loss_weighting: true + log_interval: 100 + summary_log_interval: 10 + +fsdp: + tp_size: 1 + cp_size: 1 + pp_size: 1 + dp_replicate_size: 1 + # 14B transformer needs aggressive sharding; bump dp_size beyond Wan2.1's default. + dp_size: 8 + defer_fsdp_grad_sync: true + enable_fsdp2_prefetch: true + fsdp2_backward_prefetch_depth: 2 + fsdp2_forward_prefetch_depth: 1 + # Explicitly required for the 14B model to fit at bf16. + activation_checkpointing: true + +checkpoint: + enabled: true + # Use a different checkpoint_dir per stage so the two stages stay distinct + # (e.g. .../wan22_high/ and .../wan22_low/). Both are then consumed by + # examples/diffusion/generate/configs/generate_wan22.yaml. + checkpoint_dir: PATH_TO_YOUR_CKPT_DIR + model_save_format: safetensors + save_consolidated: true + diffusers_compatible: true + restore_from: null + +ci: + recipe_owner: linnanw + time: "00:30:00" diff --git a/examples/diffusion/generate/configs/generate_wan22.yaml b/examples/diffusion/generate/configs/generate_wan22.yaml new file mode 100644 index 0000000000..adb9776aa9 --- /dev/null +++ b/examples/diffusion/generate/configs/generate_wan22.yaml @@ -0,0 +1,44 @@ +model: + pretrained_model_name_or_path: "Wan-AI/Wan2.2-T2V-A14B-Diffusers" + # Two-transformer loading. Either, both, or neither may be set: + # - both unset → use the hub-pretrained weights for both stages + # - one set → swap that stage's weights, leave the other stage at hub baseline + # - both set → swap both stages' weights + # Each path should be a training checkpoint dir produced by + # examples/diffusion/finetune/wan2_2_t2v_flow.yaml with the matching model.stage. + checkpoint_high_noise: null # set to: PATH_TO_CKPT/wan22_high/checkpoint-1000 + checkpoint_low_noise: null # set to: PATH_TO_CKPT/wan22_low/checkpoint-1000 + +inference: + num_inference_steps: 50 + # High-noise stage guidance (used for timesteps >= boundary_ratio * num_train_timesteps). + guidance_scale: 5.0 + height: 480 + width: 832 + dtype: "bfloat16" + max_samples: 10 + prompts: + - "A cat sitting on a windowsill watching the rain" + pipeline_kwargs: + num_frames: 81 + negative_prompt: "" + # Low-noise stage guidance (used for timesteps < boundary_ratio * num_train_timesteps). + # Wan2.2 A14B uses asymmetric guidance: a stronger scale on the high-noise stage + # and a milder one on the low-noise stage. Tune per-prompt as needed. + guidance_scale_2: 5.0 + +output: + output_dir: "./inference_outputs" + fps: 16 + +distributed: null + +vae: + enable_slicing: true + enable_tiling: true + # WanPipeline cpu_offload_seq cycles text_encoder->transformer->transformer_2->vae, + # so enabling cpu offload keeps only one transformer on GPU at a time. Strongly + # recommended for A14B (two 14B transformers won't co-reside on a single GPU). + enable_cpu_offload: true + +seed: 42 diff --git a/examples/diffusion/generate/generate.py b/examples/diffusion/generate/generate.py index 6a164041bb..73c23f9682 100644 --- a/examples/diffusion/generate/generate.py +++ b/examples/diffusion/generate/generate.py @@ -183,27 +183,63 @@ def _build_parallel_scheme(scheme_cfg, dist_info): def load_checkpoint_into_pipeline(pipe, cfg): - """Load a training checkpoint into the pipeline's transformer. + """Load training checkpoint(s) into the pipeline's transformer(s). - Expects a consolidated HF safetensors checkpoint produced by training - with model_save_format: safetensors, save_consolidated: true, and - diffusers_compatible: true. The checkpoint directory should contain - model/consolidated/ with diffusion_pytorch_model.safetensors.index.json - and the corresponding safetensors files. + Supports both single-transformer pipelines (Wan2.1, FLUX, HunyuanVideo) and + two-transformer pipelines (Wan2.2-T2V-A14B with ``transformer`` for the + high-noise stage and ``transformer_2`` for the low-noise stage). - Uses the standard diffusers from_pretrained() API for loading. + Single-transformer path: set ``model.checkpoint`` to load into ``pipe.transformer``. + + Two-transformer path (Wan2.2): set ``model.checkpoint_high_noise`` and/or + ``model.checkpoint_low_noise``. Each is independently optional — a missing + one leaves that stage's transformer at its hub-pretrained weights, which is + useful for sanity-checking a partial finetune. + + Expects consolidated HF safetensors checkpoints produced by training with + ``model_save_format: safetensors`` and ``save_consolidated: true``. Args: - pipe: The diffusion pipeline with a `.transformer` attribute. - cfg: Config node with `model.checkpoint` path. + pipe: The diffusion pipeline. May expose ``transformer_2`` for Wan2.2. + cfg: Config node with one of ``model.checkpoint``, + ``model.checkpoint_high_noise``, or ``model.checkpoint_low_noise``. + + Raises: + ValueError: If both single-stage and two-stage checkpoint fields are set. """ checkpoint = getattr(cfg.model, "checkpoint", None) - if not checkpoint: - return + checkpoint_high = getattr(cfg.model, "checkpoint_high_noise", None) + checkpoint_low = getattr(cfg.model, "checkpoint_low_noise", None) + + if checkpoint and (checkpoint_high or checkpoint_low): + raise ValueError( + "model.checkpoint is mutually exclusive with " + "model.checkpoint_high_noise / model.checkpoint_low_noise. " + "Use the latter pair for two-transformer pipelines (Wan2.2) and " + "model.checkpoint for single-transformer pipelines." + ) dtype_str = getattr(cfg.inference, "dtype", "bfloat16") torch_dtype = _resolve_dtype(dtype_str) + if checkpoint: + _load_checkpoint_into_attr(pipe, "transformer", checkpoint, torch_dtype) + return + + if checkpoint_high: + _load_checkpoint_into_attr(pipe, "transformer", checkpoint_high, torch_dtype) + if checkpoint_low: + if getattr(pipe, "transformer_2", None) is None: + raise ValueError( + "model.checkpoint_low_noise is set but the loaded pipeline has no " + "transformer_2 attribute. This option only applies to two-stage " + "models like Wan2.2-T2V-A14B." + ) + _load_checkpoint_into_attr(pipe, "transformer_2", checkpoint_low, torch_dtype) + + +def _load_checkpoint_into_attr(pipe, attr_name, checkpoint, torch_dtype): + """Load a single consolidated/sharded checkpoint into ``pipe.``.""" checkpoint_dir = Path(checkpoint) if not checkpoint_dir.exists(): raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") @@ -213,32 +249,38 @@ def load_checkpoint_into_pipeline(pipe, cfg): consolidated_st_dir = checkpoint_dir / "model" / "consolidated" sharded_dir = checkpoint_dir / "model" + target = getattr(pipe, attr_name) + if target is None: + raise AttributeError(f"Pipeline has no attribute {attr_name!r} to load checkpoint into") + if ema_path.exists(): - logger.info("Loading EMA checkpoint from %s", ema_path) + logger.info("Loading EMA checkpoint from %s into %s", ema_path, attr_name) ema_state = torch.load(ema_path, map_location="cuda", weights_only=True) - pipe.transformer.load_state_dict(ema_state, strict=True) - logger.info("Loaded EMA checkpoint") + target.load_state_dict(ema_state, strict=True) elif consolidated_path.exists(): - logger.info("Loading consolidated checkpoint from %s", consolidated_path) + logger.info("Loading consolidated checkpoint from %s into %s", consolidated_path, attr_name) state_dict = torch.load(consolidated_path, map_location="cuda", weights_only=True) if "model_state_dict" in state_dict: state_dict = state_dict["model_state_dict"] - pipe.transformer.load_state_dict(state_dict, strict=True) - logger.info("Loaded consolidated checkpoint") + target.load_state_dict(state_dict, strict=True) elif consolidated_st_dir.is_dir() and any( name.endswith(".safetensors") for name in os.listdir(consolidated_st_dir) ): - logger.info("Loading consolidated safetensors checkpoint from %s", consolidated_st_dir) - pipe.transformer = type(pipe.transformer).from_pretrained(str(consolidated_st_dir), torch_dtype=torch_dtype) - pipe.transformer.to("cuda") - logger.info("Loaded consolidated safetensors checkpoint") + logger.info("Loading consolidated safetensors checkpoint from %s into %s", consolidated_st_dir, attr_name) + new_module = type(target).from_pretrained(str(consolidated_st_dir), torch_dtype=torch_dtype) + new_module.to("cuda") + setattr(pipe, attr_name, new_module) elif sharded_dir.is_dir() and any(name.endswith(".distcp") for name in os.listdir(sharded_dir)): - logger.info("Loading sharded FSDP checkpoint from %s", sharded_dir) - pipe.transformer = _load_sharded_fsdp_checkpoint(pipe.transformer, str(sharded_dir), torch_dtype) - pipe.transformer.to("cuda", dtype=torch_dtype) - logger.info("Loaded sharded FSDP checkpoint") + logger.info("Loading sharded FSDP checkpoint from %s into %s", sharded_dir, attr_name) + new_module = _load_sharded_fsdp_checkpoint(target, str(sharded_dir), torch_dtype) + new_module.to("cuda", dtype=torch_dtype) + setattr(pipe, attr_name, new_module) else: - logger.warning("No recognized checkpoint format found in %s, using base model weights", checkpoint_dir) + logger.warning( + "No recognized checkpoint format found in %s, leaving %s at base weights", + checkpoint_dir, + attr_name, + ) def load_lora_weights_into_pipeline(pipe, cfg): diff --git a/nemo_automodel/_diffusers/auto_diffusion_pipeline.py b/nemo_automodel/_diffusers/auto_diffusion_pipeline.py index 3ad3a0d495..582e5f1439 100644 --- a/nemo_automodel/_diffusers/auto_diffusion_pipeline.py +++ b/nemo_automodel/_diffusers/auto_diffusion_pipeline.py @@ -188,6 +188,50 @@ def _move_module_to_device(module: nn.Module, device: torch.device, torch_dtype: module.to(device=device) +def _select_active_transformer(pipe, active_transformer: str) -> None: + """Keep only the chosen transformer on a two-transformer pipeline. + + Two-stage diffusion pipelines (Wan2.2 T2V-A14B) register both + ``transformer`` (high-noise) and ``transformer_2`` (low-noise). Finetuning + only needs one at a time. This helper swaps the chosen one into + ``pipe.transformer`` and nulls the other so subsequent device placement, + LoRA injection, and FSDP2 wrapping only touch the active model. + + Args: + pipe: A diffusers pipeline that may expose ``transformer_2``. + active_transformer: Either ``"transformer"`` or ``"transformer_2"``. + + Raises: + ValueError: If ``active_transformer`` is unrecognized. + AttributeError: If ``active_transformer="transformer_2"`` but the pipeline + has no ``transformer_2`` attribute (model is not a two-stage variant). + """ + if active_transformer not in ("transformer", "transformer_2"): + raise ValueError(f"active_transformer must be 'transformer' or 'transformer_2', got {active_transformer!r}") + + has_t2 = getattr(pipe, "transformer_2", None) is not None + if active_transformer == "transformer_2": + if not has_t2: + raise AttributeError( + "active_transformer='transformer_2' requested but the loaded pipeline " + "has no transformer_2 attribute. This option is for two-stage models like " + "Wan2.2-T2V-A14B." + ) + # Move transformer_2 into the transformer slot and free the old one. + old_transformer = pipe.transformer + pipe.transformer = pipe.transformer_2 + pipe.transformer_2 = None + del old_transformer + logger.info("[INFO] Selected transformer_2 as active transformer (low-noise stage)") + else: + if has_t2: + pipe.transformer_2 = None + logger.info("[INFO] Selected transformer as active transformer (high-noise stage); freed transformer_2") + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def _ensure_params_trainable(module: nn.Module, module_name: Optional[str] = None) -> int: """ Ensure that all parameters in the given module are trainable. @@ -373,6 +417,7 @@ def from_pretrained( components_to_load: Optional[Iterable[str]] = None, peft_cfg=None, model_type=None, + active_transformer: Optional[str] = None, **kwargs, ) -> Tuple[DiffusionPipeline, Dict[str, ParallelManager]]: """ @@ -396,6 +441,12 @@ def from_pretrained( before _apply_parallelization() (FSDP2 wrapping). Base weights are frozen after FSDP2; LoRA params are collected pre-FSDP2 and stored on pipe. model_type: "flux" | "wan" | "hunyuan". Required when peft_cfg is provided. + active_transformer: For two-transformer pipelines (e.g. Wan2.2 with + ``transformer`` + ``transformer_2``), selects which one becomes + ``pipe.transformer`` for training. Accepts ``"transformer"`` (default, + high-noise stage in Wan2.2) or ``"transformer_2"`` (low-noise stage). + The unused transformer is replaced with ``None`` and freed before + device placement so only one transformer occupies GPU memory. **kwargs: Additional arguments passed to DiffusionPipeline.from_pretrained Returns: @@ -419,6 +470,13 @@ def from_pretrained( logger.info("[INFO] Loaded pipeline type: %s", type(pipe).__name__) + # Two-transformer pipelines (Wan2.2): keep only the selected transformer + # on the pipeline so device placement / sharding only touches one model. + # We do this before any device move so the dropped transformer never + # occupies GPU memory. + if active_transformer is not None: + _select_active_transformer(pipe, active_transformer) + # Decide device dev = _choose_device(device) diff --git a/nemo_automodel/recipes/diffusion/train.py b/nemo_automodel/recipes/diffusion/train.py index 8b7c291433..3d5fc4a728 100644 --- a/nemo_automodel/recipes/diffusion/train.py +++ b/nemo_automodel/recipes/diffusion/train.py @@ -100,6 +100,7 @@ def build_model_and_optimizer( pipeline_spec: Optional[Dict[str, Any]] = None, peft_cfg=None, model_type=None, + active_transformer: Optional[str] = None, ) -> tuple[NeMoAutoDiffusionPipeline, torch.optim.Optimizer, Any]: """Build the diffusion model, parallel scheme, and optimizer. @@ -122,6 +123,10 @@ def build_model_and_optimizer( peft_cfg: PeftConfig instance or None. When provided, only LoRA params are trained; base weights are frozen and sharded by FSDP2 for memory. model_type: "flux" | "wan" | "hunyuan". Required when peft_cfg is provided. + active_transformer: For two-transformer pipelines (Wan2.2), select which + transformer to finetune. ``"transformer"`` (default for Wan2.2 = high-noise) + or ``"transformer_2"`` (low-noise). The unused transformer is dropped + before device placement so only one transformer lives on GPU. Returns: Tuple of (pipeline, optimizer, device_mesh or None). @@ -212,6 +217,8 @@ def build_model_and_optimizer( if finetune_mode: # Finetuning: load from pretrained weights logging.info("[INFO] Loading pretrained model for finetuning") + if active_transformer is not None: + logging.info("[INFO] Active transformer: %s", active_transformer) pipe, created_managers = NeMoAutoDiffusionPipeline.from_pretrained( model_id, torch_dtype=dtype, @@ -222,6 +229,7 @@ def build_model_and_optimizer( low_cpu_mem_usage=True, peft_cfg=peft_cfg, model_type=model_type, + active_transformer=active_transformer, ) else: # Pretraining: initialize with random weights using pipeline_spec @@ -413,6 +421,13 @@ def setup(self): if self.dist_env.is_main and hasattr(self.cfg, "wandb"): suppress_wandb_log_messages() + # For two-stage Wan2.2 finetuning, suffix the wandb run name with the + # active stage so high-noise and low-noise runs are distinguishable. + stage_for_wandb = self.cfg.get("model.stage", None) + if stage_for_wandb is not None: + current_name = self.cfg.get("wandb.name", None) + if current_name is not None and not str(current_name).endswith(f"_{stage_for_wandb}"): + self.cfg.wandb.name = f"{current_name}_{stage_for_wandb}" run = build_wandb(self.cfg) if run is not None: logging.info("🚀 View run at {}".format(run.url)) @@ -491,6 +506,20 @@ def setup(self): adapter_kwargs.to_dict() if hasattr(adapter_kwargs, "to_dict") else dict(adapter_kwargs or {}) ) + # Two-stage finetuning (Wan2.2 T2V-A14B): each stage trains only one + # transformer against a restricted timestep range. The stage knob both + # selects the active transformer and clamps the sigma sampling window so + # this run only sees noise levels its transformer is responsible for. + self.stage = self.cfg.get("model.stage", None) + self.boundary_ratio = self.cfg.get("model.boundary_ratio", None) + self.active_transformer = None + if self.stage is not None: + stage = str(self.stage).lower() + if stage not in ("high_noise", "low_noise"): + raise ValueError(f"model.stage must be 'high_noise' or 'low_noise', got {self.stage!r}") + self.stage = stage + self.active_transformer = "transformer" if stage == "high_noise" else "transformer_2" + logging.info("[INFO] Flow Matching V2 Pipeline") logging.info(f"[INFO] - Adapter type: {self.adapter_type}") logging.info(f"[INFO] - Timestep sampling: {self.timestep_sampling}") @@ -500,6 +529,8 @@ def setup(self): logging.info(f"[INFO] - CFG dropout prob: {self.cfg_dropout_prob}") logging.info(f"[INFO] - Use loss weighting: {self.use_loss_weighting}") logging.info(f"[INFO] - Loss weighting scheme: {self.loss_weighting_scheme}") + if self.stage is not None: + logging.info(f"[INFO] - Two-stage finetune: stage={self.stage}, active={self.active_transformer}") # Get pipeline_spec for pretraining mode (required when mode != "finetune") pipeline_spec_cfg = self.cfg.get("model.pipeline_spec", None) @@ -541,11 +572,38 @@ def setup(self): pipeline_spec=pipeline_spec, peft_cfg=self.peft_cfg, model_type=self.model_type, + active_transformer=self.active_transformer, ) self.model = self.pipe.transformer self.peft_config = getattr(self.pipe, "_peft_config", None) + # Resolve sigma range for two-stage finetuning now that the pipeline + # is loaded and we can read its boundary_ratio config. + if self.stage is not None: + if self.boundary_ratio is None: + pipe_cfg = getattr(self.pipe, "config", None) + self.boundary_ratio = pipe_cfg.get("boundary_ratio") if pipe_cfg is not None else None + if self.boundary_ratio is None: + raise ValueError( + "model.stage is set but no boundary_ratio could be resolved. " + "Set model.boundary_ratio in YAML, or use a pipeline whose config " + "carries boundary_ratio (e.g. Wan-AI/Wan2.2-T2V-A14B-Diffusers)." + ) + self.boundary_ratio = float(self.boundary_ratio) + if self.stage == "high_noise": + self.sigma_min = self.boundary_ratio + self.sigma_max = 1.0 + else: + self.sigma_min = 0.0 + self.sigma_max = self.boundary_ratio + logging.info( + "[INFO] - Stage sigma range: [%.4f, %.4f] (boundary_ratio=%.4f)", + self.sigma_min, + self.sigma_max, + self.boundary_ratio, + ) + checkpoint_cfg = self.cfg.get("checkpoint", None) self.num_epochs = self.cfg.step_scheduler.num_epochs diff --git a/tools/diffusion/preprocessing_multiprocess.py b/tools/diffusion/preprocessing_multiprocess.py index 2dc1883d7c..f85837da40 100644 --- a/tools/diffusion/preprocessing_multiprocess.py +++ b/tools/diffusion/preprocessing_multiprocess.py @@ -1054,7 +1054,7 @@ def main(): "--processor", type=str, required=True, - choices=["wan", "wan2.1", "hunyuan", "hunyuanvideo", "hunyuanvideo-1.5"], + choices=["wan", "wan2.1", "wan2.2", "hunyuan", "hunyuanvideo", "hunyuanvideo-1.5"], ) video_parser.add_argument("--model_name", type=str, default=None, help="Model name (uses processor default)") video_parser.add_argument("--mode", type=str, default="video", choices=["video", "frames"], help="Processing mode") diff --git a/tools/diffusion/processors/__init__.py b/tools/diffusion/processors/__init__.py index 523627bfb9..fa4971335c 100644 --- a/tools/diffusion/processors/__init__.py +++ b/tools/diffusion/processors/__init__.py @@ -26,7 +26,7 @@ from .hunyuan import HunyuanVideoProcessor from .qwen_image import QwenImageProcessor from .registry import ProcessorRegistry -from .wan import WanProcessor +from .wan import Wan22Processor, WanProcessor __all__ = [ # Base classes @@ -39,6 +39,7 @@ "QwenImageProcessor", # Video processors "WanProcessor", + "Wan22Processor", "HunyuanVideoProcessor", # Caption loaders "CaptionLoader", diff --git a/tools/diffusion/processors/wan.py b/tools/diffusion/processors/wan.py index ee05b06460..a0c63ab923 100644 --- a/tools/diffusion/processors/wan.py +++ b/tools/diffusion/processors/wan.py @@ -13,12 +13,17 @@ # limitations under the License. """ -Wan2.1 video model processor for preprocessing. +Wan video model processors for preprocessing. -Handles Wan2.1-T2V models (1.3B and 14B variants) with: +Handles Wan2.1-T2V and Wan2.2-T2V-A14B with: - AutoencoderKLWan for video encoding - UMT5 text encoder for text conditioning - Latent normalization using latents_mean and latents_std + +Wan2.1 and Wan2.2 share the same VAE class and UMT5 text encoder, but their +hub weights and ``latents_mean`` / ``latents_std`` may differ. Use the +``wan2.2`` processor to preprocess data targeting Wan2.2 finetuning so the +cache stays clearly separated from any existing Wan2.1 cache. """ import html @@ -107,9 +112,10 @@ def load_models(self, model_name: str, device: str) -> Dict[str, Any]: - text_encoder: UMT5EncoderModel - tokenizer: AutoTokenizer """ - from diffusers import AutoencoderKLWan from transformers import AutoTokenizer, UMT5EncoderModel + from diffusers import AutoencoderKLWan + dtype = torch.float16 if "cuda" in device else torch.float32 # UMT5 requires bfloat16 (float16 causes overflow/zeros in attention and layer norm) text_encoder_dtype = torch.bfloat16 if "cuda" in device else torch.float32 @@ -353,7 +359,35 @@ def get_cache_data( "video_path": metadata.get("video_path"), # Processing settings "deterministic_latents": metadata.get("deterministic", True), - "model_version": "wan2.1", + "model_version": self.model_version, "processing_mode": metadata.get("mode", "video"), "model_type": self.model_type, } + + @property + def model_version(self) -> str: + return "wan2.1" + + +@ProcessorRegistry.register("wan2.2") +class Wan22Processor(WanProcessor): + """ + Processor for Wan2.2-T2V-A14B (two-stage) video model. + + Wan2.2 reuses the same ``AutoencoderKLWan`` VAE class and UMT5 text encoder + as Wan2.1, but pulls VAE / text-encoder weights from the A14B hub. Cache + files emitted by this processor record ``model_version: "wan2.2"`` so + Wan2.1 and Wan2.2 caches remain unambiguous side-by-side. + """ + + @property + def model_type(self) -> str: + return "wan22" + + @property + def model_version(self) -> str: + return "wan2.2" + + @property + def default_model_name(self) -> str: + return "Wan-AI/Wan2.2-T2V-A14B-Diffusers" From 960aa6f8f17a149caddc624d2bc0d0bb80a23383 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 21 May 2026 12:53:35 -0700 Subject: [PATCH 2/4] fix the memory management for training large 14B wan model --- .../components/distributed/parallelizer.py | 12 ++++++++++++ nemo_automodel/recipes/diffusion/train.py | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/nemo_automodel/components/distributed/parallelizer.py b/nemo_automodel/components/distributed/parallelizer.py index 23ad6d6478..96995df95b 100644 --- a/nemo_automodel/components/distributed/parallelizer.py +++ b/nemo_automodel/components/distributed/parallelizer.py @@ -568,6 +568,18 @@ def parallelize( except Exception as e: logger.warning(f"Wan strategy: failed to TP blocks/proj_out: {e}") + # Activation checkpointing wraps every WanTransformerBlock so its + # forward activations are recomputed on backward instead of being + # held in memory. Critical for Wan2.2-A14B (14B params, ~30k-token + # video sequence) — without this, fp32 layer-norm casts in the block + # forward will OOM even on 8x80GB H100. + if activation_checkpointing and hasattr(model, "blocks"): + for idx in range(len(model.blocks)): + model.blocks[idx] = checkpoint_wrapper( + model.blocks[idx], + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + # Mixed precision default like Default strategy if not mp_policy: mp_policy = MixedPrecisionPolicy( diff --git a/nemo_automodel/recipes/diffusion/train.py b/nemo_automodel/recipes/diffusion/train.py index 3d5fc4a728..08f175a912 100644 --- a/nemo_automodel/recipes/diffusion/train.py +++ b/nemo_automodel/recipes/diffusion/train.py @@ -24,7 +24,7 @@ import torch.distributed as dist import wandb from huggingface_hub.constants import HF_HUB_CACHE -from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy from nemo_automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig @@ -210,6 +210,10 @@ def build_model_and_optimizer( reduce_dtype=torch.float32, output_dtype=dtype, ), + # CPU offload: when enabled, sharded params + optimizer state live on + # host RAM and are paged to GPU per-block during forward/backward. + # Saves ~21 GB per H100 for a 14B AdamW finetune; adds H2D traffic. + "offload_policy": CPUOffloadPolicy(pin_memory=True) if cpu_offload else None, } parallel_scheme = {"transformer": manager_args} From 17cdb7d035792516f2bd54bbc7cfe76de1a4f00f Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 21 May 2026 13:18:43 -0700 Subject: [PATCH 3/4] fix wan2.2 support --- nemo_automodel/recipes/base_recipe.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nemo_automodel/recipes/base_recipe.py b/nemo_automodel/recipes/base_recipe.py index 2bf3af9350..0ee2e2fbb3 100644 --- a/nemo_automodel/recipes/base_recipe.py +++ b/nemo_automodel/recipes/base_recipe.py @@ -269,6 +269,14 @@ def save_checkpoint( # Wait for any in-flight checkpoint (async case) to complete self.checkpointer.async_wait() + # Free GPU caches before DCP's gather-and-write. DCP allocates NCCL + # workspace and materializes DTensor shards on GPU; with CPU-offloaded + # FSDP2 the residual training-time fragments can leave just enough + # headroom to break the gather (cuda failure 2 / "out of memory"). + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + # If a previous async checkpoint just finished, update the "latest" symlink now prev_pending = getattr(self, "_last_pending_checkpoint_dir", None) is_dist_initialized = torch.distributed.is_initialized() From 03b46bd9628294f729d1e12f1667d7b92a365f70 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 21 May 2026 17:27:55 -0700 Subject: [PATCH 4/4] all good for wan2.2 --- .../generate/configs/generate_wan22.yaml | 12 +++-- examples/diffusion/generate/generate.py | 54 +++++++++++++++++++ nemo_automodel/recipes/base_recipe.py | 7 +++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/examples/diffusion/generate/configs/generate_wan22.yaml b/examples/diffusion/generate/configs/generate_wan22.yaml index adb9776aa9..e8d3532b01 100644 --- a/examples/diffusion/generate/configs/generate_wan22.yaml +++ b/examples/diffusion/generate/configs/generate_wan22.yaml @@ -6,8 +6,8 @@ model: # - both set → swap both stages' weights # Each path should be a training checkpoint dir produced by # examples/diffusion/finetune/wan2_2_t2v_flow.yaml with the matching model.stage. - checkpoint_high_noise: null # set to: PATH_TO_CKPT/wan22_high/checkpoint-1000 - checkpoint_low_noise: null # set to: PATH_TO_CKPT/wan22_low/checkpoint-1000 + checkpoint_high_noise: ./WAN22_CKPT/wan22_high/epoch_19_step_99 + checkpoint_low_noise: ./WAN22_CKPT/wan22_low/epoch_19_step_99 inference: num_inference_steps: 50 @@ -18,9 +18,13 @@ inference: dtype: "bfloat16" max_samples: 10 prompts: - - "A cat sitting on a windowsill watching the rain" + - "The video shows a scene from an anime, featuring two characters standing on a ledge or shelf. The camera is static. A door opens from the left, revealing a man. The man is Caucasian, appears to be in his late 30s, and is dressed in outdoor attire. He wears a tan vest over a white shirt, paired with light brown pants and brown shoes. A rope is slung over his shoulder, and he has a backpack on. He also sports goggles on his forehead. He walks to the center of the frame. To the right of the man, a young girl stands with her arms crossed. She has fair skin and red hair styled with two small buns on top of her head. She wears a red dress with a matching bag slung over her shoulder and brown boots. The background is dark, with a window visible behind them. The scene is dimly lit, creating a somber mood. The man speaks" + - "The video presents a close-up shot of an animated young man with fair skin, likely of East Asian descent, with short, dark blue hair that is neatly styled. He is wearing a white, collared shirt with the top buttons undone. The camera is static, maintaining a steady focus on his face. The background features a soft, blurred mix of green and blue hues, suggesting an outdoor setting with trees and sky. The young man's expression is thoughtful and slightly melancholic. His eyes are downcast, and his mouth is slightly downturned. He appears to be listening intently or reflecting on something. The overall style of the video is reminiscent of a classic anime or animated film, with soft lines and gentle color palettes." + - "The video is a still shot of a scene from an animated movie. The camera is static and shows a narrow, dimly lit hallway. At the end of the hallway, a young girl with short, dark hair stands on a slightly elevated platform. She is wearing a red dress and faces away from the camera. The hallway is cluttered with various objects. On the left side, there are tools leaning against the wall. The walls are made of corrugated metal. On the right side, there is a shelf with a large glass jar on it. In the center of the hallway, there are pipes and metal objects. The overall style is reminiscent of Studio Ghibli's animation style, with detailed backgrounds and a soft color palette." pipeline_kwargs: - num_frames: 81 + # Must satisfy (num_frames - 1) % 4 == 0 due to Wan VAE's 4x temporal downsample. + # Valid choices: 1, 5, 9, ..., 57, 61, 65, ..., 81. 60 is invalid; use 57 or 61. + num_frames: 61 negative_prompt: "" # Low-noise stage guidance (used for timesteps < boundary_ratio * num_train_timesteps). # Wan2.2 A14B uses asymmetric guidance: a stronger scale on the high-noise stage diff --git a/examples/diffusion/generate/generate.py b/examples/diffusion/generate/generate.py index 73c23f9682..05fd923134 100644 --- a/examples/diffusion/generate/generate.py +++ b/examples/diffusion/generate/generate.py @@ -275,6 +275,16 @@ def _load_checkpoint_into_attr(pipe, attr_name, checkpoint, torch_dtype): new_module = _load_sharded_fsdp_checkpoint(target, str(sharded_dir), torch_dtype) new_module.to("cuda", dtype=torch_dtype) setattr(pipe, attr_name, new_module) + elif sharded_dir.is_dir() and any( + name.startswith("shard-") and name.endswith(".safetensors") for name in os.listdir(sharded_dir) + ): + # NeMo-AutoModel sharded HF safetensors: one ``shard-XXXXX-*.safetensors`` + # per FSDP rank from a run that set ``save_consolidated: false``. Use the + # DCP HuggingFaceStorageReader to materialize the full state dict. + logger.info("Loading sharded HF safetensors checkpoint from %s into %s", sharded_dir, attr_name) + new_module = _load_sharded_hf_safetensors_checkpoint(target, str(sharded_dir), torch_dtype) + new_module.to("cuda", dtype=torch_dtype) + setattr(pipe, attr_name, new_module) else: logger.warning( "No recognized checkpoint format found in %s, leaving %s at base weights", @@ -362,6 +372,50 @@ def _load_sharded_fsdp_checkpoint(transformer, sharded_dir, torch_dtype=torch.bf dist.destroy_process_group() +def _load_sharded_hf_safetensors_checkpoint(transformer, sharded_dir, torch_dtype=torch.bfloat16): + """Load NeMo-AutoModel sharded HF safetensors checkpoint into a transformer. + + Handles directories containing ``shard-XXXXX-model-XXXXX-of-XXXXX.safetensors`` + files produced by training runs with ``save_consolidated: false``. Uses DCP's + ``HuggingFaceStorageReader`` to gather all shards into the target state dict. + + Args: + transformer: The transformer nn.Module to load weights into. + sharded_dir: Path to the directory containing shard-*.safetensors files. + torch_dtype: The dtype to cast the transformer to before loading. + + Returns: + The transformer module with the merged state dict loaded. + """ + from torch.distributed.checkpoint import load as dist_load + + # Prefer the upstream HF storage reader; fall back to NeMo's backport if + # the torch version is too old to ship it. + try: + from torch.distributed.checkpoint.hf_storage import HuggingFaceStorageReader + except ImportError: + from nemo_automodel.components.checkpoint._backports.hf_storage import ( + _HuggingFaceStorageReader as HuggingFaceStorageReader, + ) + + init_dist = False + if not dist.is_initialized(): + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + dist.init_process_group(backend="gloo", rank=0, world_size=1) + init_dist = True + + try: + transformer.to(device="cuda", dtype=torch_dtype) + state_dict = transformer.state_dict() + dist_load(state_dict=state_dict, storage_reader=HuggingFaceStorageReader(path=sharded_dir)) + transformer.load_state_dict(state_dict, strict=True) + return transformer + finally: + if init_dist: + dist.destroy_process_group() + + def apply_optimizations(pipe, cfg): """Apply VAE and memory optimizations to the pipeline. diff --git a/nemo_automodel/recipes/base_recipe.py b/nemo_automodel/recipes/base_recipe.py index 0ee2e2fbb3..ccd30994cb 100644 --- a/nemo_automodel/recipes/base_recipe.py +++ b/nemo_automodel/recipes/base_recipe.py @@ -415,6 +415,13 @@ def to_item(x): if is_dist_initialized: torch.distributed.barrier() + # Release NCCL workspace and DCP gather scratch back to the allocator. + # Without this, the next training step's backward sees a fragmented + # heap (~74 GB still resident on tight 14B FSDP2 runs) and OOMs. + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + def _update_checkpoint_symlink(self, link_name: str, target_dir: str) -> None: """ Create or update a symlink named `link_name` under the checkpoint root