feat: Abstract LLM/VLM forward-backward step#2228

Draft
HuiyingLi wants to merge 4 commits into main from huiyingl/abstract-forward-backward-step
@HuiyingLi

Summary

  • Add a shared forward/backward helper for LLM and VLM fine-tuning paths.
  • Route LLM and VLM recipes through the shared helper while keeping recipe-specific hooks for CP prep, FP8 context, and VLM PP media staging.
  • Enable VLM pipeline-parallel validation by preparing validation dataloader media chunks for PP and running validation through schedule.eval.
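
A minimal sketch of the "shared helper plus recipe-specific hooks" shape described above. It is an illustration only, not the code in this PR: `StepHooks`, `shared_step`, and `ToyLoss` are hypothetical names, and `ToyLoss` stands in for a `torch.Tensor` loss so the example runs without PyTorch.

```python
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable


class ToyLoss:
    """Stand-in for a tensor loss (assumption: only .backward() is needed)."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.backward_called = False

    def backward(self) -> None:
        self.backward_called = True


@dataclass
class StepHooks:
    """Hypothetical bundle of recipe-specific hooks."""

    # Rewrites the batch before the forward pass (e.g. CP prep, media staging).
    prepare_batch: Callable[[dict], dict] = lambda batch: batch
    # Returns a context manager wrapped around the forward pass (e.g. FP8).
    forward_context: Callable[[], Any] = nullcontext


def shared_step(model, batch, loss_fn, hooks: StepHooks, is_train: bool):
    """One control flow for both recipes; only the hooks differ."""
    batch = hooks.prepare_batch(batch)
    with hooks.forward_context():
        output = model(batch["inputs"])
    loss = loss_fn(output, batch["labels"])
    if is_train:
        loss.backward()  # validation skips the backward pass
    return loss


# Toy model and loss so the sketch is runnable end to end.
model = lambda xs: sum(xs)
loss_fn = lambda out, labels: ToyLoss(abs(out - labels))

llm_hooks = StepHooks()
# The "VLM" hook folds a hypothetical media field into the inputs.
vlm_hooks = StepHooks(
    prepare_batch=lambda b: {**b, "inputs": b["inputs"] + b["media"]}
)

train_loss = shared_step(
    model, {"inputs": [1, 2], "labels": 5}, loss_fn, llm_hooks, is_train=True
)
val_loss = shared_step(
    model,
    {"inputs": [1, 2], "media": [3], "labels": 5},
    loss_fn,
    vlm_hooks,
    is_train=False,
)
```

In this sketch the training call runs the backward pass and the validation call does not, while both go through the same function body.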

Why

The LLM and VLM recipes had largely duplicated forward/backward control flow. VLM PP validation was previously skipped because validation used a direct model-forward path and did not prepare VLM media tensors for PP microbatches.

Validation

  • python -m py_compile nemo_automodel/components/training/forward_backward.py nemo_automodel/recipes/llm/train_ft.py nemo_automodel/recipes/vlm/finetune.py
  • python -m ruff check nemo_automodel/components/training/forward_backward.py nemo_automodel/recipes/llm/train_ft.py nemo_automodel/recipes/vlm/finetune.py
  • git diff --check
  • pytest -q tests/unit_tests/recipes/test_train_ft.py
  • pytest -q tests/unit_tests/recipes/test_finetune_vlm_helpers.py
  • Local Qwen3 VL MoE 30B EP4/PP2 smoke run with validation:
    • train: loss 2.3266, num_label_tokens 930
    • val: loss 2.2712, num_label_tokens 941

HuiyingLi added 2 commits May 13, 2026 11:23
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi HuiyingLi changed the title from "Abstract LLM/VLM forward-backward step" to "feat: Abstract LLM/VLM forward-backward step" on May 13, 2026
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
distributed_config: Any,
loss_fn: Callable[..., torch.Tensor] | None,
calculate_loss_fn: Callable[..., torch.Tensor],
loss_buffer: list[torch.Tensor],

I would let the caller handle this and instead return the loss to the caller.
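
A rough sketch of the difference this suggestion makes, using plain floats so it runs without PyTorch. `toy_loss`, `step_with_buffer`, and `step_returning_loss` are hypothetical names, not functions from this PR.

```python
def toy_loss(batch: list[float]) -> float:
    """Hypothetical loss: the mean of the batch values."""
    return sum(batch) / len(batch)


def step_with_buffer(batch: list[float], loss_buffer: list[float]) -> None:
    """Current shape: the step mutates a buffer owned by the caller."""
    loss_buffer.append(toy_loss(batch))


def step_returning_loss(batch: list[float]) -> float:
    """Suggested shape: the step returns the loss and has no side effect."""
    return toy_loss(batch)


microbatches = [[1.0, 3.0], [2.0, 4.0]]

# With the buffer argument, accumulation happens inside the step.
buffer: list[float] = []
for mb in microbatches:
    step_with_buffer(mb, buffer)

# With a return value, the caller decides how to accumulate.
losses = [step_returning_loss(mb) for mb in microbatches]
```

Both paths collect the same values here; the second drops one parameter from the step's signature and leaves the accumulation policy with the recipe.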

Comment on lines +48 to +49
loss_fn: Callable[..., torch.Tensor] | None,
calculate_loss_fn: Callable[..., torch.Tensor],

Can we consolidate the two, `loss_fn` and `calculate_loss_fn`?
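
One way this could look, under a loud assumption: the thread does not say how the two callables differ, so this sketch assumes `calculate_loss_fn` is a wrapper that takes `loss_fn` as its first argument. Every name and signature below is hypothetical.

```python
from functools import partial


def mean_abs_error(logits: list[float], labels: list[float]) -> float:
    """Hypothetical stand-in for `loss_fn`."""
    return sum(abs(p - t) for p, t in zip(logits, labels)) / len(labels)


def calculate_loss(loss_fn, logits: list[float], labels: list[float]) -> float:
    """Hypothetical stand-in for `calculate_loss_fn` (assumed wrapper)."""
    return loss_fn(logits, labels)


# The recipe binds the loss once, so the step needs only one callable.
compute_loss = partial(calculate_loss, mean_abs_error)


def step(logits: list[float], labels: list[float], compute_loss) -> float:
    """The step takes a single loss callable instead of two."""
    return compute_loss(logits, labels)


result = step([1.0, 2.0], [1.5, 1.0], compute_loss)
```

If the two callables are unrelated in the real code, this binding does not apply as written.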

Comment on lines +54 to +55
pp_enabled: bool,
pp: Any | None,

Is there a strong use case for `pp_enabled`? If not, I would do something like `pp_enabled = (pp is not None)` inside the body of `forward_backward_step`.
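
For illustration, a stripped-down sketch of that suggestion. `forward_backward_step_sketch` is a hypothetical stand-in with a much smaller signature than the function under review, and the returned strings only mark which branch ran.

```python
from typing import Any, Optional


def forward_backward_step_sketch(batch: Any, pp: Optional[Any] = None) -> str:
    """Derive the flag from `pp` instead of taking it as a parameter."""
    pp_enabled = pp is not None
    if pp_enabled:
        return "pipeline"  # would go through the PP schedule
    return "direct"  # would call the model directly


without_pp = forward_backward_step_sketch(batch={})
with_pp = forward_backward_step_sketch(batch={}, pp=object())
```

This removes the possibility of passing `pp_enabled=True` together with `pp=None`.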

is_train: bool,
pp_enabled: bool,
pp: Any | None,
dp_group_size: int,

Can we derive `dp_group_size` from `device_mesh` by placing a convention on mesh dimension naming?
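
A sketch of what such a convention might look like, assuming every data-parallel dimension name starts with `dp` and that the listed dimensions are disjoint (no flattened aliases). The helper and the prefix are hypothetical. It relies only on `mesh_dim_names` and `size(mesh_dim)`, which `torch.distributed.device_mesh.DeviceMesh` exposes; `FakeMesh` mimics those two members so the example runs without a distributed setup.

```python
from math import prod


class FakeMesh:
    """Stand-in exposing the two DeviceMesh members the helper uses."""

    def __init__(self, dims: dict[str, int]) -> None:
        self.mesh_dim_names = tuple(dims)
        self._sizes = tuple(dims.values())

    def size(self, mesh_dim: int) -> int:
        return self._sizes[mesh_dim]


def dp_group_size_from_mesh(mesh, dp_prefix: str = "dp") -> int:
    """Multiply the sizes of all dims whose name starts with `dp_prefix`.

    Returns 1 when the mesh has no matching dimension.
    """
    names = mesh.mesh_dim_names or ()
    return prod(
        mesh.size(i) for i, name in enumerate(names) if name.startswith(dp_prefix)
    )


mesh = FakeMesh({"pp": 2, "dp_replicate": 2, "dp_shard": 4, "tp": 1})
dp_size = dp_group_size_from_mesh(mesh)
```

The convention only works if every recipe names its mesh dimensions the same way, which is the trade-off the question raises.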

return value


def forward_backward_step(

I feel it is challenging to get this function to the right abstraction, because it tries to abstract LLM/VLM + train/eval + PP/non-PP + CP + FP8 + FSDP sync + loss calculation + modality hooks. I'm wondering what ways we could explore to simplify it without losing critical functionality.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>