Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions examples/diffusion/finetune/wan2_2_t2v_flow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
seed: 42

wandb:
project: wan-t2v-flow-matching
mode: online
# Stage name is auto-appended (e.g. wan2_2_t2v_fm_high_noise) when model.stage is set.
name: wan2_2_t2v_fm

dist_env:
backend: nccl
timeout_minutes: 30

model:
pretrained_model_name_or_path: Wan-AI/Wan2.2-T2V-A14B-Diffusers
mode: finetune
# Two-stage finetuning knob. Required for Wan2.2-T2V-A14B.
# "high_noise" trains pipe.transformer on sigma in [boundary_ratio, 1.0].
# "low_noise" trains pipe.transformer_2 on sigma in [0.0, boundary_ratio].
# Run twice (once per stage) with different checkpoint dirs to fully finetune.
stage: high_noise
# Optional override; if unset the recipe reads boundary_ratio from the
# pipeline config (Wan2.2-T2V-A14B ships with 0.875).
boundary_ratio: 0.875

step_scheduler:
global_batch_size: 8
local_batch_size: 1
ckpt_every_steps: 1000
num_epochs: 100
log_every: 2
save_checkpoint_every_epoch: false

data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_video_multiresolution_dataloader
# Point this at the cache produced by tools/diffusion/processors/wan.py --variant wan22
cache_dir: PATH_TO_YOUR_WAN22_DATA
model_type: wan
base_resolution: [512, 512]
dynamic_batch_size: false
shuffle: true
drop_last: false
num_workers: 2

optim:
learning_rate: 5e-6
optimizer:
weight_decay: 0.01
betas: [0.9, 0.999]
foreach: true

performance:
check_loss: false
grad_clip_foreach: true

lr_scheduler:
lr_decay_style: cosine
lr_warmup_steps: 0
min_lr: 1e-6

# Flow matching V2 configuration. sigma_min/sigma_max are derived from
# model.stage + boundary_ratio at runtime; any values set here are ignored.
flow_matching:
adapter_type: "simple"
adapter_kwargs: {}
timestep_sampling: "uniform"
logit_mean: 0.0
logit_std: 1.0
flow_shift: 3.0
mix_uniform_ratio: 0.1
num_train_timesteps: 1000
i2v_prob: 0.3
use_loss_weighting: true
log_interval: 100
summary_log_interval: 10

fsdp:
tp_size: 1
cp_size: 1
pp_size: 1
dp_replicate_size: 1
# 14B transformer needs aggressive sharding; bump dp_size beyond Wan2.1's default.
dp_size: 8
defer_fsdp_grad_sync: true
enable_fsdp2_prefetch: true
fsdp2_backward_prefetch_depth: 2
fsdp2_forward_prefetch_depth: 1
# Explicitly required for the 14B model to fit at bf16.
activation_checkpointing: true

checkpoint:
enabled: true
# Use a different checkpoint_dir per stage so the two stages stay distinct
# (e.g. .../wan22_high/ and .../wan22_low/). Both are then consumed by
# examples/diffusion/generate/configs/generate_wan22.yaml.
checkpoint_dir: PATH_TO_YOUR_CKPT_DIR
model_save_format: safetensors
save_consolidated: true
diffusers_compatible: true
restore_from: null

ci:
recipe_owner: linnanw
time: "00:30:00"
48 changes: 48 additions & 0 deletions examples/diffusion/generate/configs/generate_wan22.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
model:
pretrained_model_name_or_path: "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
# Two-transformer loading. Either one, both, or neither of checkpoint_high_noise / checkpoint_low_noise may be set:
# - both unset → use the hub-pretrained weights for both stages
# - one set → swap that stage's weights, leave the other stage at hub baseline
# - both set → swap both stages' weights
# Each path should be a training checkpoint dir produced by
# examples/diffusion/finetune/wan2_2_t2v_flow.yaml with the matching model.stage.
checkpoint_high_noise: ./WAN22_CKPT/wan22_high/epoch_19_step_99
checkpoint_low_noise: ./WAN22_CKPT/wan22_low/epoch_19_step_99

inference:
num_inference_steps: 50
# High-noise stage guidance (used for timesteps >= boundary_ratio * num_train_timesteps).
guidance_scale: 5.0
height: 480
width: 832
dtype: "bfloat16"
max_samples: 10
prompts:
- "The video shows a scene from an anime, featuring two characters standing on a ledge or shelf. The camera is static. A door opens from the left, revealing a man. The man is Caucasian, appears to be in his late 30s, and is dressed in outdoor attire. He wears a tan vest over a white shirt, paired with light brown pants and brown shoes. A rope is slung over his shoulder, and he has a backpack on. He also sports goggles on his forehead. He walks to the center of the frame. To the right of the man, a young girl stands with her arms crossed. She has fair skin and red hair styled with two small buns on top of her head. She wears a red dress with a matching bag slung over her shoulder and brown boots. The background is dark, with a window visible behind them. The scene is dimly lit, creating a somber mood. The man speaks"
- "The video presents a close-up shot of an animated young man with fair skin, likely of East Asian descent, with short, dark blue hair that is neatly styled. He is wearing a white, collared shirt with the top buttons undone. The camera is static, maintaining a steady focus on his face. The background features a soft, blurred mix of green and blue hues, suggesting an outdoor setting with trees and sky. The young man's expression is thoughtful and slightly melancholic. His eyes are downcast, and his mouth is slightly downturned. He appears to be listening intently or reflecting on something. The overall style of the video is reminiscent of a classic anime or animated film, with soft lines and gentle color palettes."
- "The video is a still shot of a scene from an animated movie. The camera is static and shows a narrow, dimly lit hallway. At the end of the hallway, a young girl with short, dark hair stands on a slightly elevated platform. She is wearing a red dress and faces away from the camera. The hallway is cluttered with various objects. On the left side, there are tools leaning against the wall. The walls are made of corrugated metal. On the right side, there is a shelf with a large glass jar on it. In the center of the hallway, there are pipes and metal objects. The overall style is reminiscent of Studio Ghibli's animation style, with detailed backgrounds and a soft color palette."
pipeline_kwargs:
# Must satisfy (num_frames - 1) % 4 == 0 due to Wan VAE's 4x temporal downsample.
# Valid choices: 1, 5, 9, ..., 57, 61, 65, ..., 81. 60 is invalid; use 57 or 61.
num_frames: 61
negative_prompt: ""
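# Low-noise stage guidance (used for timesteps < boundary_ratio * num_train_timesteps).
# Wan2.2 A14B supports asymmetric guidance: guidance_scale applies to the high-noise
# stage and guidance_scale_2 to the low-noise stage, so the two can differ (typically
# stronger on high-noise, milder on low-noise). This example sets both to 5.0.
# Tune per-prompt as needed.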
guidance_scale_2: 5.0

output:
output_dir: "./inference_outputs"
fps: 16

distributed: null

vae:
enable_slicing: true
enable_tiling: true
# WanPipeline cpu_offload_seq cycles text_encoder->transformer->transformer_2->vae,
# so enabling cpu offload keeps only one transformer on GPU at a time. Strongly
# recommended for A14B (two 14B transformers won't co-reside on a single GPU).
enable_cpu_offload: true

seed: 42
148 changes: 122 additions & 26 deletions examples/diffusion/generate/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,27 +183,63 @@ def _build_parallel_scheme(scheme_cfg, dist_info):


def load_checkpoint_into_pipeline(pipe, cfg):
"""Load a training checkpoint into the pipeline's transformer.
"""Load training checkpoint(s) into the pipeline's transformer(s).

Expects a consolidated HF safetensors checkpoint produced by training
with model_save_format: safetensors, save_consolidated: true, and
diffusers_compatible: true. The checkpoint directory should contain
model/consolidated/ with diffusion_pytorch_model.safetensors.index.json
and the corresponding safetensors files.
Supports both single-transformer pipelines (Wan2.1, FLUX, HunyuanVideo) and
two-transformer pipelines (Wan2.2-T2V-A14B with ``transformer`` for the
high-noise stage and ``transformer_2`` for the low-noise stage).

Uses the standard diffusers from_pretrained() API for loading.
Single-transformer path: set ``model.checkpoint`` to load into ``pipe.transformer``.

Two-transformer path (Wan2.2): set ``model.checkpoint_high_noise`` and/or
``model.checkpoint_low_noise``. Each is independently optional — a missing
one leaves that stage's transformer at its hub-pretrained weights, which is
useful for sanity-checking a partial finetune.

Expects consolidated HF safetensors checkpoints produced by training with
``model_save_format: safetensors`` and ``save_consolidated: true``.

Args:
pipe: The diffusion pipeline with a `.transformer` attribute.
cfg: Config node with `model.checkpoint` path.
pipe: The diffusion pipeline. May expose ``transformer_2`` for Wan2.2.
cfg: Config node with one of ``model.checkpoint``,
``model.checkpoint_high_noise``, or ``model.checkpoint_low_noise``.

Raises:
ValueError: If both single-stage and two-stage checkpoint fields are set.
"""
checkpoint = getattr(cfg.model, "checkpoint", None)
if not checkpoint:
return
checkpoint_high = getattr(cfg.model, "checkpoint_high_noise", None)
checkpoint_low = getattr(cfg.model, "checkpoint_low_noise", None)

if checkpoint and (checkpoint_high or checkpoint_low):
raise ValueError(
"model.checkpoint is mutually exclusive with "
"model.checkpoint_high_noise / model.checkpoint_low_noise. "
"Use the latter pair for two-transformer pipelines (Wan2.2) and "
"model.checkpoint for single-transformer pipelines."
)

dtype_str = getattr(cfg.inference, "dtype", "bfloat16")
torch_dtype = _resolve_dtype(dtype_str)

if checkpoint:
_load_checkpoint_into_attr(pipe, "transformer", checkpoint, torch_dtype)
return

if checkpoint_high:
_load_checkpoint_into_attr(pipe, "transformer", checkpoint_high, torch_dtype)
if checkpoint_low:
if getattr(pipe, "transformer_2", None) is None:
raise ValueError(
"model.checkpoint_low_noise is set but the loaded pipeline has no "
"transformer_2 attribute. This option only applies to two-stage "
"models like Wan2.2-T2V-A14B."
)
_load_checkpoint_into_attr(pipe, "transformer_2", checkpoint_low, torch_dtype)


def _load_checkpoint_into_attr(pipe, attr_name, checkpoint, torch_dtype):
"""Load a single consolidated/sharded checkpoint into ``pipe.<attr_name>``."""
checkpoint_dir = Path(checkpoint)
if not checkpoint_dir.exists():
raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
Expand All @@ -213,32 +249,48 @@ def load_checkpoint_into_pipeline(pipe, cfg):
consolidated_st_dir = checkpoint_dir / "model" / "consolidated"
sharded_dir = checkpoint_dir / "model"

target = getattr(pipe, attr_name)
if target is None:
raise AttributeError(f"Pipeline has no attribute {attr_name!r} to load checkpoint into")

if ema_path.exists():
logger.info("Loading EMA checkpoint from %s", ema_path)
logger.info("Loading EMA checkpoint from %s into %s", ema_path, attr_name)
ema_state = torch.load(ema_path, map_location="cuda", weights_only=True)
pipe.transformer.load_state_dict(ema_state, strict=True)
logger.info("Loaded EMA checkpoint")
target.load_state_dict(ema_state, strict=True)
elif consolidated_path.exists():
logger.info("Loading consolidated checkpoint from %s", consolidated_path)
logger.info("Loading consolidated checkpoint from %s into %s", consolidated_path, attr_name)
state_dict = torch.load(consolidated_path, map_location="cuda", weights_only=True)
if "model_state_dict" in state_dict:
state_dict = state_dict["model_state_dict"]
pipe.transformer.load_state_dict(state_dict, strict=True)
logger.info("Loaded consolidated checkpoint")
target.load_state_dict(state_dict, strict=True)
elif consolidated_st_dir.is_dir() and any(
name.endswith(".safetensors") for name in os.listdir(consolidated_st_dir)
):
logger.info("Loading consolidated safetensors checkpoint from %s", consolidated_st_dir)
pipe.transformer = type(pipe.transformer).from_pretrained(str(consolidated_st_dir), torch_dtype=torch_dtype)
pipe.transformer.to("cuda")
logger.info("Loaded consolidated safetensors checkpoint")
logger.info("Loading consolidated safetensors checkpoint from %s into %s", consolidated_st_dir, attr_name)
new_module = type(target).from_pretrained(str(consolidated_st_dir), torch_dtype=torch_dtype)
new_module.to("cuda")
setattr(pipe, attr_name, new_module)
elif sharded_dir.is_dir() and any(name.endswith(".distcp") for name in os.listdir(sharded_dir)):
logger.info("Loading sharded FSDP checkpoint from %s", sharded_dir)
pipe.transformer = _load_sharded_fsdp_checkpoint(pipe.transformer, str(sharded_dir), torch_dtype)
pipe.transformer.to("cuda", dtype=torch_dtype)
logger.info("Loaded sharded FSDP checkpoint")
logger.info("Loading sharded FSDP checkpoint from %s into %s", sharded_dir, attr_name)
new_module = _load_sharded_fsdp_checkpoint(target, str(sharded_dir), torch_dtype)
new_module.to("cuda", dtype=torch_dtype)
setattr(pipe, attr_name, new_module)
elif sharded_dir.is_dir() and any(
name.startswith("shard-") and name.endswith(".safetensors") for name in os.listdir(sharded_dir)
):
# NeMo-AutoModel sharded HF safetensors: one ``shard-XXXXX-*.safetensors``
# per FSDP rank from a run that set ``save_consolidated: false``. Use the
# DCP HuggingFaceStorageReader to materialize the full state dict.
logger.info("Loading sharded HF safetensors checkpoint from %s into %s", sharded_dir, attr_name)
new_module = _load_sharded_hf_safetensors_checkpoint(target, str(sharded_dir), torch_dtype)
new_module.to("cuda", dtype=torch_dtype)
setattr(pipe, attr_name, new_module)
else:
logger.warning("No recognized checkpoint format found in %s, using base model weights", checkpoint_dir)
logger.warning(
"No recognized checkpoint format found in %s, leaving %s at base weights",
checkpoint_dir,
attr_name,
)


def load_lora_weights_into_pipeline(pipe, cfg):
Expand Down Expand Up @@ -320,6 +372,50 @@ def _load_sharded_fsdp_checkpoint(transformer, sharded_dir, torch_dtype=torch.bf
dist.destroy_process_group()


def _load_sharded_hf_safetensors_checkpoint(transformer, sharded_dir, torch_dtype=torch.bfloat16):
    """Materialize a sharded HF safetensors checkpoint into ``transformer``.

    Targets directories holding ``shard-XXXXX-model-XXXXX-of-XXXXX.safetensors``
    files, as written by training runs with ``save_consolidated: false``. The
    shards are read through DCP's ``HuggingFaceStorageReader`` into the
    module's own state dict, which is then loaded back strictly.

    If no process group is initialized, a temporary single-rank ``gloo``
    group is created for the duration of the load and destroyed afterwards;
    an already-initialized group is left untouched.

    Args:
        transformer: The transformer nn.Module that receives the weights.
        sharded_dir: Directory containing the shard-*.safetensors files.
        torch_dtype: Dtype the transformer is cast to (on CUDA) before loading.

    Returns:
        The same transformer module, with the checkpoint weights loaded.
    """
    from torch.distributed.checkpoint import load as dist_load

    # Upstream torch ships the HF storage reader only in newer releases;
    # otherwise use the NeMo backport of the same reader.
    try:
        from torch.distributed.checkpoint.hf_storage import HuggingFaceStorageReader
    except ImportError:
        from nemo_automodel.components.checkpoint._backports.hf_storage import (
            _HuggingFaceStorageReader as HuggingFaceStorageReader,
        )

    # Only tear down the process group at the end if this call created it.
    owns_process_group = not dist.is_initialized()
    if owns_process_group:
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    try:
        transformer.to(device="cuda", dtype=torch_dtype)
        # The module's current state dict serves as the destination for the load.
        merged_state = transformer.state_dict()
        shard_reader = HuggingFaceStorageReader(path=sharded_dir)
        dist_load(state_dict=merged_state, storage_reader=shard_reader)
        transformer.load_state_dict(merged_state, strict=True)
    finally:
        if owns_process_group:
            dist.destroy_process_group()

    return transformer


def apply_optimizations(pipe, cfg):
"""Apply VAE and memory optimizations to the pipeline.

Expand Down
Loading
Loading