feat: Abstract LLM/VLM forward-backward step #2228
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
    distributed_config: Any,
    loss_fn: Callable[..., torch.Tensor] | None,
    calculate_loss_fn: Callable[..., torch.Tensor],
    loss_buffer: list[torch.Tensor],
I would let the caller handle this and return the loss to the caller instead.
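A minimal sketch of this suggestion, using plain floats in place of tensors: the step returns its loss, and the caller decides whether to collect it in a buffer. The names `step_returning_loss`, `run_microbatches`, and `mean_loss` are hypothetical and not taken from the PR.

```python
from typing import Callable, Sequence

Batch = Sequence[float]


def step_returning_loss(batch: Batch, calculate_loss_fn: Callable[[Batch], float]) -> float:
    """Hypothetical step: compute the loss and hand it back to the caller."""
    return calculate_loss_fn(batch)


def run_microbatches(batches: Sequence[Batch], calculate_loss_fn: Callable[[Batch], float]) -> list[float]:
    """The caller owns the buffer; the step never mutates shared state."""
    loss_buffer: list[float] = []
    for batch in batches:
        loss_buffer.append(step_returning_loss(batch, calculate_loss_fn))
    return loss_buffer


def mean_loss(batch: Batch) -> float:
    """Toy loss for illustration only."""
    return sum(batch) / len(batch)


losses = run_microbatches([[1.0, 3.0], [2.0, 4.0]], mean_loss)
```

With this shape the step needs no `loss_buffer` parameter. Whether the pipeline-parallel path can return its loss the same way is not shown here.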
    loss_fn: Callable[..., torch.Tensor] | None,
    calculate_loss_fn: Callable[..., torch.Tensor],
can we consolidate the two
- loss_fn
- calculate_loss_fn
?
    pp_enabled: bool,
    pp: Any | None,
is there a strong use-case for pp_enabled? if not, i would do something like pp_enabled = (pp is not None) inside forward_backward_step's body.
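A minimal sketch of that derivation, assuming `pp` is `None` exactly when pipeline parallelism is off. The function name `forward_backward_step_sketch` and its string return values are hypothetical stand-ins, not the PR's actual signature.

```python
from typing import Any, Optional


def forward_backward_step_sketch(*, pp: Optional[Any] = None) -> str:
    """Hypothetical stand-in showing only how the flag is derived."""
    # The flag is computed from `pp`, so callers cannot pass an
    # inconsistent pair such as pp_enabled=True with pp=None.
    pp_enabled = pp is not None
    return "pipeline" if pp_enabled else "direct"
```

Dropping the separate flag removes one parameter and one way for the two arguments to disagree.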
    is_train: bool,
    pp_enabled: bool,
    pp: Any | None,
    dp_group_size: int,
can we derive this from device_mesh by placing a convention on mesh naming?
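One way this could look, as a sketch under an assumed naming convention: any mesh dimension named `dp`, `dp_replicate`, or `dp_shard` counts toward the data-parallel group size. Those names and the helper `dp_group_size_from_mesh` are assumptions, not taken from the PR. `FakeMesh` is a stand-in that mimics only the `mesh_dim_names` attribute and the `size(mesh_dim)` method of `torch.distributed.device_mesh.DeviceMesh`, so the example runs without initializing a process group.

```python
from dataclasses import dataclass
from typing import Optional

# Assumed convention: these dimension names denote data parallelism.
DP_DIM_NAMES = ("dp", "dp_replicate", "dp_shard")


@dataclass
class FakeMesh:
    """Stand-in exposing the two DeviceMesh members the helper uses."""
    mesh_dim_names: Optional[tuple[str, ...]]
    shape: tuple[int, ...]

    def size(self, mesh_dim: int) -> int:
        return self.shape[mesh_dim]


def dp_group_size_from_mesh(mesh, dp_dim_names: tuple[str, ...] = DP_DIM_NAMES) -> int:
    """Multiply the sizes of all mesh dimensions whose name marks them as DP."""
    names = mesh.mesh_dim_names or ()
    size = 1
    for dim, name in enumerate(names):
        if name in dp_dim_names:
            size *= mesh.size(dim)
    return size


mesh = FakeMesh(mesh_dim_names=("pp", "dp_shard", "tp"), shape=(2, 4, 2))
dp_size = dp_group_size_from_mesh(mesh)
```

A mesh with no matching dimension yields 1 here. Whether context-parallel ranks should be folded into this count depends on how the loss is normalized and is left open.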
    return value


def forward_backward_step(
I feel it is challenging to find the right abstraction for this function, because it tries to abstract LLM/VLM + train/eval + PP/non-PP + CP + FP8 + FSDP sync + loss calculation + modality hooks. I'm wondering what could be explored to simplify it without losing critical functionality.
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Summary
schedule.eval.
Why
The LLM and VLM recipes had largely duplicated forward/backward control flow. VLM PP validation was previously skipped because validation used a direct model-forward path and did not prepare VLM media tensors for PP microbatches.
Validation
- python -m py_compile nemo_automodel/components/training/forward_backward.py nemo_automodel/recipes/llm/train_ft.py nemo_automodel/recipes/vlm/finetune.py
- python -m ruff check nemo_automodel/components/training/forward_backward.py nemo_automodel/recipes/llm/train_ft.py nemo_automodel/recipes/vlm/finetune.py
- git diff --check
- pytest -q tests/unit_tests/recipes/test_train_ft.py
- pytest -q tests/unit_tests/recipes/test_finetune_vlm_helpers.py
- loss 2.3266, num_label_tokens 930
- loss 2.2712, num_label_tokens 941