diff --git a/scripts/setup.sh b/scripts/setup.sh
index d7c8fd209..5d616d72a 100755
--- a/scripts/setup.sh
+++ b/scripts/setup.sh
@@ -53,18 +53,14 @@ else
     echo "Skipping git reset/clean (GIT_RESET_CLEAN is not true). Preserving synced working tree."
 fi
 
-# Install astral-uv
-if ! command -v uv >/dev/null 2>&1; then
-    if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
-        echo "Failed to install uv." >&2
-        exit 1
-    fi
-    export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
+# Install astral-uv (standalone version)
+# Always prepend standalone install path so it takes precedence over system/conda uv
+export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
+if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
+    echo "Failed to install uv." >&2
+    exit 1
 fi
 
-# Update uv
-uv self update
-
 # Sync the dependencies
 if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then
     uv sync --all-extras
diff --git a/src/art/_backend_training.py b/src/art/_backend_training.py
new file mode 100644
index 000000000..e698a7f1d
--- /dev/null
+++ b/src/art/_backend_training.py
@@ -0,0 +1,105 @@
+from collections.abc import Iterable
+import time
+from typing import Literal
+
+from . import dev
+from .metrics_taxonomy import (
+    average_metric_samples,
+    build_training_summary_metrics,
+    summarize_trajectory_groups,
+)
+from .trajectories import TrajectoryGroup
+from .types import TrainConfig
+
+
+def build_rl_train_configs(
+    *,
+    learning_rate: float,
+    advantage_balance: float = 0.0,
+    scale_rewards: bool = True,
+    importance_sampling_level: Literal[
+        "token", "sequence", "average", "geometric_average"
+    ] = "token",
+    mask_prob_ratio: bool = False,
+    ppo: bool = False,
+    precalculate_logprobs: bool = False,
+    epsilon: float | None = None,
+    epsilon_high: float | None = None,
+    max_negative_advantage_importance_sampling_weight: float | None = None,
+    kimi_k2_tau: float | None = None,
+    kl_penalty_coef: float = 0.0,
+    allow_training_without_logprobs: bool | None = None,
+    plot_tensors: bool | None = None,
+    truncated_importance_sampling: float | None = None,
+    scale_learning_rate_by_reward_std_dev: bool | None = None,
+    logprob_calculation_chunk_size: int | None = None,
+    num_trajectories_learning_rate_multiplier_power: float | None = None,
+    kl_ref_adapter_path: str | None = None,
+) -> tuple[TrainConfig, dev.TrainConfig]:
+    config = TrainConfig(
+        learning_rate=learning_rate,
+        kl_penalty_coef=kl_penalty_coef,
+    )
+    dev_config: dev.TrainConfig = {
+        "advantage_balance": advantage_balance,
+        "importance_sampling_level": importance_sampling_level,
+        "kl_penalty_coef": kl_penalty_coef,
+        "mask_prob_ratio": mask_prob_ratio,
+        "ppo": ppo,
+        "precalculate_logprobs": precalculate_logprobs,
+        "scale_rewards": scale_rewards,
+    }
+
+    if allow_training_without_logprobs is not None:
+        dev_config["allow_training_without_logprobs"] = allow_training_without_logprobs
+    if plot_tensors is not None:
+        dev_config["plot_tensors"] = plot_tensors
+    if truncated_importance_sampling is not None:
+        dev_config["truncated_importance_sampling"] = truncated_importance_sampling
+    if scale_learning_rate_by_reward_std_dev is not None:
+        dev_config["scale_learning_rate_by_reward_std_dev"] = (
+            scale_learning_rate_by_reward_std_dev
+        )
+    if logprob_calculation_chunk_size is not None:
+        dev_config["logprob_calculation_chunk_size"] = logprob_calculation_chunk_size
+    if num_trajectories_learning_rate_multiplier_power is not None:
+        dev_config["num_trajectories_learning_rate_multiplier_power"] = (
+            num_trajectories_learning_rate_multiplier_power
+        )
+ if epsilon is not None: + dev_config["epsilon"] = epsilon + if epsilon_high is not None: + dev_config["epsilon_high"] = epsilon_high + if max_negative_advantage_importance_sampling_weight is not None: + dev_config["max_negative_advantage_importance_sampling_weight"] = ( + max_negative_advantage_importance_sampling_weight + ) + if kimi_k2_tau is not None: + dev_config["kimi_k2_tau"] = kimi_k2_tau + if kl_ref_adapter_path is not None: + dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path + + return config, dev_config + + +def aggregate_rl_training_metrics( + *, + training_metrics: list[dict[str, float]], + trajectory_groups: Iterable[TrajectoryGroup], + trainer_started: float, +) -> dict[str, float]: + groups_list = list(trajectory_groups) + avg_metrics = average_metric_samples(training_metrics) + summary = summarize_trajectory_groups(groups_list) + avg_metrics.setdefault("time/step_trainer_s", time.monotonic() - trainer_started) + avg_metrics.update( + { + key: value + for key, value in build_training_summary_metrics( + summary, + include_trainable_groups=True, + ).items() + if key not in avg_metrics + } + ) + return avg_metrics diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 43e35449b..c2f7153f8 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -43,11 +43,14 @@ from mp_actors import close_proxy, move_to_child_process from .. import dev +from .._backend_training import ( + aggregate_rl_training_metrics, + build_rl_train_configs, +) from ..backend import AnyTrainableModel, Backend from ..costs import build_cost_calculator, get_model_pricing from ..metrics_taxonomy import ( TRAIN_GRADIENT_STEPS_KEY, - average_metric_samples, build_training_summary_metrics, summarize_trajectory_groups, ) @@ -642,45 +645,36 @@ async def train( # type: ignore[override] if adam_params is not None: raise ValueError("LocalBackend requires adam_params=None.") - # Build config objects from explicit kwargs - config = TrainConfig( - learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef - ) - dev_config: dev.TrainConfig = { - "advantage_balance": advantage_balance, - "allow_training_without_logprobs": allow_training_without_logprobs, - "importance_sampling_level": importance_sampling_level, - "kl_penalty_coef": kl_penalty_coef, - "mask_prob_ratio": mask_prob_ratio, - "plot_tensors": plot_tensors, - "ppo": loss_fn == "ppo", - "precalculate_logprobs": precalculate_logprobs, - "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev, - "scale_rewards": scale_rewards, - "logprob_calculation_chunk_size": logprob_calculation_chunk_size, - "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power, - } - # Only include optional fields if they're set - if epsilon is not None: - dev_config["epsilon"] = epsilon - if epsilon_high is not None: - dev_config["epsilon_high"] = epsilon_high - if max_negative_advantage_importance_sampling_weight is not None: - dev_config["max_negative_advantage_importance_sampling_weight"] = ( - max_negative_advantage_importance_sampling_weight - ) - if kimi_k2_tau is not None: - dev_config["kimi_k2_tau"] = kimi_k2_tau - if truncated_importance_sampling is not None: - dev_config["truncated_importance_sampling"] = truncated_importance_sampling - if kl_ref_adapter_path is not None: - dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path - elif kl_penalty_reference_step is not None: - ref_checkpoint_dir = get_step_checkpoint_dir( + resolved_kl_ref_adapter_path = kl_ref_adapter_path + if 
( + resolved_kl_ref_adapter_path is None + and kl_penalty_reference_step is not None + ): + resolved_kl_ref_adapter_path = get_step_checkpoint_dir( get_model_dir(model=model, art_path=self._path), kl_penalty_reference_step, ) - dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir + config, dev_config = build_rl_train_configs( + learning_rate=learning_rate, + advantage_balance=advantage_balance, + scale_rewards=scale_rewards, + importance_sampling_level=importance_sampling_level, + mask_prob_ratio=mask_prob_ratio, + ppo=loss_fn == "ppo", + precalculate_logprobs=precalculate_logprobs, + epsilon=epsilon, + epsilon_high=epsilon_high, + max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, + kimi_k2_tau=kimi_k2_tau, + kl_penalty_coef=kl_penalty_coef, + allow_training_without_logprobs=allow_training_without_logprobs, + plot_tensors=plot_tensors, + truncated_importance_sampling=truncated_importance_sampling, + scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev, + logprob_calculation_chunk_size=logprob_calculation_chunk_size, + num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power, + kl_ref_adapter_path=resolved_kl_ref_adapter_path, + ) # Collect metrics from training training_metrics: list[dict[str, float]] = [] @@ -690,21 +684,10 @@ async def train( # type: ignore[override] ): training_metrics.append(metrics) - # Aggregate metrics - avg_metrics = average_metric_samples(training_metrics) - summary = summarize_trajectory_groups(groups_list) - avg_metrics.setdefault( - "time/step_trainer_s", time.monotonic() - trainer_started - ) - avg_metrics.update( - { - key: value - for key, value in build_training_summary_metrics( - summary, - include_trainable_groups=True, - ).items() - if key not in avg_metrics - } + avg_metrics = aggregate_rl_training_metrics( + training_metrics=training_metrics, + trajectory_groups=groups_list, + trainer_started=trainer_started, ) # Get step and checkpoint path @@ -822,7 +805,9 @@ async def _train_model( packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors" ) # Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train()) - grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences)) + grad_accumulation_sequences = max( + 1, int(config.grad_accumulation_sequences or 1) + ) estimated_gradient_steps = math.ceil( disk_packed_tensors["num_sequences"] / grad_accumulation_sequences ) diff --git a/src/art/loss.py b/src/art/loss.py index 27f49a6fb..7a195f5e6 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -8,7 +8,7 @@ from . 
import dev if TYPE_CHECKING: - from art.unsloth.service import TrainInputs + from art.preprocessing.inputs import TrainInputs class Loss(BaseModel): diff --git a/src/art/megatron/client.py b/src/art/megatron/client.py new file mode 100644 index 000000000..9e915c872 --- /dev/null +++ b/src/art/megatron/client.py @@ -0,0 +1,58 @@ +import asyncio +import datetime +import json +import os +from typing import Any, AsyncIterator + +from .jobs import DEFAULT_JOBS_DIR, MegatronJob +from .merge import merge_lora_adapter + +DEFAULT_TRAINING_LOG_DIR = "/tmp/megatron_training_logs" + + +def create_megatron_job_paths() -> tuple[str, str]: + timestamp = datetime.datetime.now().isoformat() + os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True) + os.makedirs(DEFAULT_TRAINING_LOG_DIR, exist_ok=True) + return ( + os.path.join(DEFAULT_JOBS_DIR, f"{timestamp}.json"), + os.path.join(DEFAULT_TRAINING_LOG_DIR, f"{timestamp}.jsonl"), + ) + + +def write_megatron_job(job: MegatronJob, *, job_path: str) -> None: + os.makedirs(os.path.dirname(job_path), exist_ok=True) + with open(job_path, "w", encoding="utf-8") as handle: + handle.write(job.model_dump_json()) + + +async def stream_megatron_job( + job: MegatronJob, + *, + job_path: str, + poll_interval: float = 0.1, +) -> AsyncIterator[dict[str, Any]]: + num_lines = 0 + try: + while True: + await asyncio.sleep(poll_interval) + try: + with open(job.log_path, "a+", encoding="utf-8") as log_file: + log_file.seek(0) + lines = log_file.readlines()[num_lines:] + except FileNotFoundError: + continue + + for line in lines: + if not (line := line.strip()): + continue + if line == "all done": + merge_lora_adapter(job.lora_path) + return + num_lines += 1 + yield json.loads(line) + finally: + if os.path.exists(job_path): + os.remove(job_path) + if os.path.exists(job.log_path): + os.remove(job.log_path) diff --git a/src/art/megatron/jobs.py b/src/art/megatron/jobs.py new file mode 100644 index 000000000..88ab112f3 --- /dev/null +++ b/src/art/megatron/jobs.py @@ -0,0 +1,36 @@ +from typing import Any, Literal + +from pydantic import BaseModel + +from .. 
import types +from ..preprocessing.pack import DiskPackedTensors + +DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl" +DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs" +DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking" + + +class MegatronTrainingJob(BaseModel): + lora_path: str + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dict[str, Any] + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True + log_path: str = DEFAULT_TRAINING_LOG_PATH + + +class MegatronSFTTrainingJob(BaseModel): + job_type: Literal["sft"] = "sft" + lora_path: str + optimizer_state_path: str + sft_data_dir: str + num_batches: int + learning_rates: list[float] + weight_decay: float = 0.0 + max_grad_norm: float = 1.0 + log_path: str = DEFAULT_TRAINING_LOG_PATH + + +MegatronJob = MegatronTrainingJob | MegatronSFTTrainingJob diff --git a/src/art/megatron/merge.py b/src/art/megatron/merge.py new file mode 100644 index 000000000..643a94617 --- /dev/null +++ b/src/art/megatron/merge.py @@ -0,0 +1,102 @@ +import importlib +import json +from pathlib import Path +from typing import Any + +import torch + +safetensors = importlib.import_module("safetensors") +safetensors_torch = importlib.import_module("safetensors.torch") +safe_open = safetensors.safe_open +save_file = safetensors_torch.save_file + + +def merge_lora_adapter(lora_path: str) -> None: + base_dir = Path(lora_path) + shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors")) + if not shard_filenames: + return + + shard_files_by_suffix = { + path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path + for path in shard_filenames + } + manifest_filenames = sorted(base_dir.glob("adapter_manifest-*-of-*.json")) + manifest_files_by_suffix = { + path.name.removeprefix("adapter_manifest-").removesuffix(".json"): path + for path in manifest_filenames + } + + if set(shard_files_by_suffix) != set(manifest_files_by_suffix): + raise RuntimeError( + "Shard/manifest coverage mismatch: " + f"shards={sorted(shard_files_by_suffix)}, " + f"manifests={sorted(manifest_files_by_suffix)}" + ) + + entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {} + for suffix in sorted(shard_files_by_suffix): + shard_path = shard_files_by_suffix[suffix] + manifest_path = manifest_files_by_suffix[suffix] + with open(manifest_path, "r", encoding="utf-8") as manifest_file: + shard_manifest: dict[str, dict[str, Any]] = json.load(manifest_file) + with safe_open(shard_path, framework="pt") as file: + shard_tensors = {key: file.get_tensor(key) for key in file.keys()} + + if set(shard_tensors) != set(shard_manifest): + raise RuntimeError( + f"Tensor/manifest key mismatch for shard suffix={suffix}: " + f"tensor_keys={sorted(shard_tensors)}, " + f"manifest_keys={sorted(shard_manifest)}" + ) + for key, tensor in shard_tensors.items(): + entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor)) + + adapter_model: dict[str, torch.Tensor] = {} + for key, key_entries in entries_by_key.items(): + first_manifest = key_entries[0][0] + sharded = bool(first_manifest["sharded"]) + shard_world_size = int(first_manifest["shard_world_size"]) + for manifest_entry, _tensor in key_entries: + if bool(manifest_entry["sharded"]) != sharded: + raise RuntimeError(f"Inconsistent sharded flag for key={key}") + if int(manifest_entry["shard_world_size"]) != shard_world_size: + raise RuntimeError(f"Inconsistent shard world size for key={key}") + + if 
not sharded: + if len(key_entries) != 1: + raise RuntimeError( + f"Replicated key={key} expected 1 shard, got {len(key_entries)}" + ) + tensor = key_entries[0][1] + else: + shard_rank_to_tensor: dict[int, torch.Tensor] = {} + for manifest_entry, shard_tensor in key_entries: + shard_rank = int(manifest_entry["shard_rank"]) + if shard_rank in shard_rank_to_tensor: + raise RuntimeError( + f"Duplicate shard_rank={shard_rank} for key={key}" + ) + shard_rank_to_tensor[shard_rank] = shard_tensor + + expected_shard_ranks = set(range(shard_world_size)) + if set(shard_rank_to_tensor) != expected_shard_ranks: + raise RuntimeError( + f"Shard rank coverage mismatch for key={key}: " + f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" + ) + + ordered_shards = [ + shard_rank_to_tensor[shard_rank] + for shard_rank in range(shard_world_size) + ] + concat_dim = 1 if "lora_A" in key else 0 + tensor = torch.cat(ordered_shards, dim=concat_dim) + adapter_model[key] = tensor + + adapter_model_path = base_dir / "adapter_model.safetensors" + save_file(adapter_model, adapter_model_path) + for filename in shard_filenames: + filename.unlink() + for filename in manifest_filenames: + filename.unlink() diff --git a/src/art/megatron/routing_replay.py b/src/art/megatron/routing_replay.py index 0705a69a7..926868457 100644 --- a/src/art/megatron/routing_replay.py +++ b/src/art/megatron/routing_replay.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import defaultdict +import importlib import json from pathlib import Path import re @@ -13,9 +14,12 @@ ) from megatron.core.transformer.moe.moe_utils import permute, sort_chunks_by_idxs from pydantic import BaseModel, ConfigDict, model_validator -from safetensors.torch import load_file, save_file import torch +safetensors_torch = importlib.import_module("safetensors.torch") +load_file = safetensors_torch.load_file +save_file = safetensors_torch.save_file + ROUTER_NAME_TOKEN = ".mlp.router" ROUTER_KEY_FORMAT_VERSION = "moe_routing_replay_v1" GLOBAL_TOKEN_UIDS_KEY = "global_token_uids" diff --git a/src/art/megatron/runtime_env.py b/src/art/megatron/runtime_env.py new file mode 100644 index 000000000..c74a4b661 --- /dev/null +++ b/src/art/megatron/runtime_env.py @@ -0,0 +1,15 @@ +import os + + +def _set_cache_dir(env_var: str, default_path: str) -> None: + if not os.environ.get(env_var): + os.environ[env_var] = os.path.expanduser(default_path) + os.makedirs(os.environ[env_var], exist_ok=True) + + +def configure_megatron_runtime_env() -> None: + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" + _set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor") + _set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache") diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 9ac3c14be..a1df86755 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -1,19 +1,15 @@ import asyncio from dataclasses import dataclass -import datetime from functools import cached_property -import json +import importlib import os from pathlib import Path import shlex import shutil import subprocess -from typing import Any, AsyncIterator +from typing import Any, AsyncIterator, cast from peft.tuners.lora.config import LoraConfig -from pydantic import BaseModel -from safetensors import safe_open -from safetensors.torch import save_file import torch from vllm import AsyncEngineArgs from vllm.lora.request import 
LoRARequest @@ -29,24 +25,96 @@ from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, openai_server_task, run_on_workers -from .routing_replay import MoeRoutingReplayBundle +from .client import create_megatron_job_paths, stream_megatron_job, write_megatron_job +from .jobs import ( + DEFAULT_JOBS_DIR, + DEFAULT_VLLM_WAKE_LOCK_PATH, + MegatronSFTTrainingJob, + MegatronTrainingJob, +) +from .sft_batches import materialize_sft_batches +safetensors = importlib.import_module("safetensors") +safe_open = safetensors.safe_open -class MegatronTrainingJob(BaseModel): - """Job format for communication with train.py""" - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig - moe_routing_replay_path: str | None = None - moe_routing_replay_strict: bool = True +def create_identity_lora(base_model: str, lora_path: str, rank: int = 1, lora_alpha: int = 32) -> None: + """Create an identity LoRA adapter for a Megatron model. + For MoE models, this targets fused expert parameters and converts them to + per-expert format. The conversion swaps lora_A/lora_B, producing A=zeros and + B=Kaiming — which is critical for stable training when alpha/rank is large. -MegatronTrainingJob.model_rebuild( - force=True, _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle} -) + Args: + base_model: HuggingFace model identifier. + lora_path: Directory to save the adapter files. + rank: LoRA rank (default 1 for Megatron models). + lora_alpha: LoRA alpha scaling factor. + """ + from unittest.mock import patch + + from accelerate import init_empty_weights + from peft import get_peft_model + from transformers import AutoConfig, AutoModelForCausalLM + + base_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True) + with init_empty_weights(): + model = AutoModelForCausalLM.from_config( + base_config, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + model.name_or_path = base_model + + lora_config = LoraConfig( + base_model_name_or_path=base_model, + r=rank, + lora_alpha=lora_alpha, + target_modules=[], + target_parameters=[ + name + for name, _ in model.named_parameters() + if name.endswith( + ( + "q_proj.weight", + "k_proj.weight", + "v_proj.weight", + "o_proj.weight", + "mlp.experts.gate_up_proj", + "mlp.experts.down_proj", + ) + ) + ], + bias="none", + ) + + meta = torch.device("meta") + orig_to = torch.nn.Module.to + + def _skip_meta_to(module: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module: + device = kwargs.get("device") or (args[0] if args else None) + if device == meta or str(device) == "meta": + return module + return orig_to(module, *args, **kwargs) + + with patch.object(torch.nn.Module, "to", _skip_meta_to): + peft_model = get_peft_model(model, lora_config) + + os.makedirs(lora_path, exist_ok=True) + peft_model.save_pretrained(lora_path) + convert_checkpoint_if_needed(lora_path) + + # Write final adapter_config with per-expert target_modules + LoraConfig( + base_model_name_or_path=base_model, + r=rank, + lora_alpha=lora_alpha, + target_modules=default_target_modules(base_model), + bias="none", + ).save_pretrained(lora_path) + + del peft_model, model + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() @dataclass @@ -97,60 +165,7 @@ def _adapter_has_weights(self, lora_path: str) -> bool: return False def _create_identity_lora(self, lora_path: str) -> None: - from 
unittest.mock import patch - - from accelerate import init_empty_weights - from peft import get_peft_model - from transformers import AutoConfig, AutoModelForCausalLM - - base_config = AutoConfig.from_pretrained( - self.base_model, - trust_remote_code=True, - ) - with init_empty_weights(): - model = AutoModelForCausalLM.from_config( - base_config, - torch_dtype=torch.bfloat16, - trust_remote_code=True, - ) - model.name_or_path = self.base_model - lora_config = self._default_lora_adapter_config() - lora_config.target_modules = [] - lora_config.target_parameters = [ - name - for name, _ in model.named_parameters() - if name.endswith( - ( - "q_proj.weight", - "k_proj.weight", - "v_proj.weight", - "o_proj.weight", - "mlp.experts.gate_up_proj", - "mlp.experts.down_proj", - ) - ) - ] - - meta = torch.device("meta") - orig_to = torch.nn.Module.to - - def _skip_meta_to(module: torch.nn.Module, *args: Any, **kwargs: Any): - device = kwargs.get("device") or (args[0] if args else None) - if device == meta or str(device) == "meta": - return module - return orig_to(module, *args, **kwargs) - - with patch.object(torch.nn.Module, "to", _skip_meta_to): - peft_model = get_peft_model(model, lora_config) - - os.makedirs(lora_path, exist_ok=True) - peft_model.save_pretrained(lora_path) - convert_checkpoint_if_needed(lora_path) - self._default_lora_adapter_config().save_pretrained(lora_path) - del peft_model, model - if torch.cuda.is_available(): - torch.cuda.synchronize() - torch.cuda.empty_cache() + create_identity_lora(self.base_model, lora_path) def _ensure_identity_lora(self, lora_path: str) -> None: if self._adapter_has_weights(lora_path): @@ -221,6 +236,63 @@ async def _ensure_megatron_running(self) -> None: cwd=str(project_root), ) + def _clear_pending_jobs(self) -> None: + os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True) + for job_name in os.listdir(DEFAULT_JOBS_DIR): + if job_name.endswith(".json"): + os.remove(os.path.join(DEFAULT_JOBS_DIR, job_name)) + + def _resolve_training_lora_path(self) -> str: + lora_path = get_last_checkpoint_dir(self.output_dir) + if lora_path is None: + lora_path = get_step_checkpoint_dir(self.output_dir, 0) + self._latest_step = 0 + self._ensure_identity_lora(lora_path) + self._ensure_lora_adapter_config(lora_path) + return lora_path + + async def _prepare_for_training(self) -> tuple[AsyncLLM, str]: + llm = await self.llm + await llm.pause_generation() + await llm.reset_prefix_cache() + await run_on_workers(llm, do_sleep, level=2) + self._is_sleeping = True + gc_and_empty_cuda_cache() + + await self._ensure_megatron_running() + lora_path = self._resolve_training_lora_path() + self._optimizer_state_path = self._get_optimizer_state_path() + self._clear_pending_jobs() + return llm, lora_path + + async def _publish_training_checkpoint( + self, + *, + llm: AsyncLLM, + lora_path: str, + ) -> None: + next_step = self._latest_step + 1 + new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) + os.makedirs(new_checkpoint_dir, exist_ok=True) + shutil.copy( + f"{lora_path}/adapter_model.safetensors", + f"{new_checkpoint_dir}/adapter_model.safetensors", + ) + self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) + + wake_lock_path = DEFAULT_VLLM_WAKE_LOCK_PATH + try: + with open(wake_lock_path, "w") as lock_file: + lock_file.write("waking vllm\n") + await run_on_workers(llm, do_wake_up) + self._is_sleeping = False + finally: + if os.path.exists(wake_lock_path): + os.remove(wake_lock_path) + + await self._add_lora_aliases(llm, next_step, 
new_checkpoint_dir) + await llm.resume_generation() + async def start_openai_server( self, config: dev.OpenAIServerConfig | None ) -> tuple[str, int]: @@ -259,192 +331,56 @@ async def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: - llm = await self.llm - await llm.pause_generation() - await llm.reset_prefix_cache() - await run_on_workers(llm, do_sleep, level=2) - self._is_sleeping = True - gc_and_empty_cuda_cache() - - # Start Megatron after vLLM has freed GPU memory. - await self._ensure_megatron_running() - - lora_path = get_last_checkpoint_dir(self.output_dir) - if lora_path is None: - lora_path = get_step_checkpoint_dir(self.output_dir, 0) - self._latest_step = 0 - self._ensure_identity_lora(lora_path) - self._ensure_lora_adapter_config(lora_path) - - self._optimizer_state_path = self._get_optimizer_state_path() - - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) - for job_name in os.listdir(jobs_dir): - if job_name.endswith(".json"): - os.remove(os.path.join(jobs_dir, job_name)) + llm, lora_path = await self._prepare_for_training() if _config.get("moe_routing_replay_bundle") is not None: raise RuntimeError( "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " "MegatronService subprocess jobs must use moe_routing_replay_path." ) + job_path, log_path = create_megatron_job_paths() job = MegatronTrainingJob( lora_path=lora_path, optimizer_state_path=self._optimizer_state_path, disk_packed_tensors=disk_packed_tensors, config=config, - experimental_config=_config, + experimental_config=cast(dict[str, Any], _config), moe_routing_replay_path=_config.get("moe_routing_replay_path"), moe_routing_replay_strict=_config.get("moe_routing_replay_strict", True), + log_path=log_path, ) - job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json") - with open(job_path, "w") as f: - f.write(job.model_dump_json()) - - num_lines = 0 - while True: - await asyncio.sleep(0.1) - try: - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_file.seek(0) - lines = log_file.readlines()[num_lines:] - for line in lines: - if line := line.strip(): - if line == "all done": - self._merge_lora_adapter(lora_path) - os.remove("/tmp/megatron_training_log.jsonl") - break - num_lines += 1 - yield json.loads(line) - else: - continue - break - except FileNotFoundError: - continue + write_megatron_job(job, job_path=job_path) - next_step = self._latest_step + 1 - new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) - os.makedirs(new_checkpoint_dir, exist_ok=True) - shutil.copy( - f"{lora_path}/adapter_model.safetensors", - f"{new_checkpoint_dir}/adapter_model.safetensors", - ) - self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) - - wake_lock_path = "/tmp/megatron_vllm_waking" - try: - with open(wake_lock_path, "w") as lock_file: - lock_file.write("waking vllm\n") - await run_on_workers(llm, do_wake_up) - self._is_sleeping = False - finally: - if os.path.exists(wake_lock_path): - os.remove(wake_lock_path) + async for result in stream_megatron_job(job, job_path=job_path): + yield {key: float(value) for key, value in result.items()} - await self._add_lora_aliases(llm, next_step, new_checkpoint_dir) - await llm.resume_generation() + await self._publish_training_checkpoint(llm=llm, lora_path=lora_path) - # SFT not supported for MegatronService async def train_sft( self, - batches: list[Any], + batches: list[SFTBatch], verbose: bool = False, ) -> 
AsyncIterator[dict[str, float]]: - raise NotImplementedError("SFT training is not supported for MegatronService") - yield {} # Make this a generator - - def _merge_lora_adapter(self, lora_path: str) -> None: - """Merge sharded LoRA adapters from distributed training.""" - base_dir = Path(lora_path) - shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors")) - if not shard_filenames: - return - - shard_files_by_suffix = { - path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path - for path in shard_filenames - } - manifest_filenames = sorted(base_dir.glob("adapter_manifest-*-of-*.json")) - manifest_files_by_suffix = { - path.name.removeprefix("adapter_manifest-").removesuffix(".json"): path - for path in manifest_filenames - } - - if set(shard_files_by_suffix) != set(manifest_files_by_suffix): - raise RuntimeError( - "Shard/manifest coverage mismatch: " - f"shards={sorted(shard_files_by_suffix)}, " - f"manifests={sorted(manifest_files_by_suffix)}" - ) + llm, lora_path = await self._prepare_for_training() + serialized_batches = materialize_sft_batches(batches) + job_path, log_path = create_megatron_job_paths() + job = MegatronSFTTrainingJob( + lora_path=lora_path, + optimizer_state_path=self._optimizer_state_path, + sft_data_dir=serialized_batches.sft_data_dir, + num_batches=serialized_batches.num_batches, + learning_rates=serialized_batches.learning_rates, + log_path=log_path, + ) + write_megatron_job(job, job_path=job_path) - entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {} - for suffix in sorted(shard_files_by_suffix): - shard_path = shard_files_by_suffix[suffix] - manifest_path = manifest_files_by_suffix[suffix] - with open(manifest_path, "r", encoding="utf-8") as manifest_file: - shard_manifest: dict[str, dict[str, Any]] = json.load(manifest_file) - - with safe_open(shard_path, framework="pt") as file: - shard_tensors = {key: file.get_tensor(key) for key in file.keys()} - - if set(shard_tensors) != set(shard_manifest): - raise RuntimeError( - f"Tensor/manifest key mismatch for shard suffix={suffix}: " - f"tensor_keys={sorted(shard_tensors)}, " - f"manifest_keys={sorted(shard_manifest)}" - ) + async for result in stream_megatron_job(job, job_path=job_path): + yield { + "loss/train": float(result["loss"]), + "loss/learning_rate": float(result["learning_rate"]), + "loss/grad_norm": float(result["grad_norm"]), + } - for key, tensor in shard_tensors.items(): - entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor)) - - adapter_model: dict[str, torch.Tensor] = {} - for key, key_entries in entries_by_key.items(): - first_manifest = key_entries[0][0] - sharded = bool(first_manifest["sharded"]) - shard_world_size = int(first_manifest["shard_world_size"]) - - for manifest_entry, _tensor in key_entries: - if bool(manifest_entry["sharded"]) != sharded: - raise RuntimeError(f"Inconsistent sharded flag for key={key}") - if int(manifest_entry["shard_world_size"]) != shard_world_size: - raise RuntimeError(f"Inconsistent shard world size for key={key}") - - if not sharded: - if len(key_entries) != 1: - raise RuntimeError( - f"Replicated key={key} expected 1 shard, got {len(key_entries)}" - ) - tensor = key_entries[0][1] - else: - shard_rank_to_tensor: dict[int, torch.Tensor] = {} - for manifest_entry, shard_tensor in key_entries: - shard_rank = int(manifest_entry["shard_rank"]) - if shard_rank in shard_rank_to_tensor: - raise RuntimeError( - f"Duplicate shard_rank={shard_rank} for key={key}" - ) - 
shard_rank_to_tensor[shard_rank] = shard_tensor - - expected_shard_ranks = set(range(shard_world_size)) - if set(shard_rank_to_tensor.keys()) != expected_shard_ranks: - raise RuntimeError( - f"Shard rank coverage mismatch for key={key}: " - f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor.keys())}" - ) - - ordered_shards = [ - shard_rank_to_tensor[i] for i in range(shard_world_size) - ] - concat_dim = 1 if "lora_A" in key else 0 - tensor = torch.cat(ordered_shards, dim=concat_dim) - adapter_model[key] = tensor - - adapter_model_path = base_dir / "adapter_model.safetensors" - save_file(adapter_model, adapter_model_path) - for filename in shard_filenames: - filename.unlink() - for filename in manifest_filenames: - filename.unlink() + await self._publish_training_checkpoint(llm=llm, lora_path=lora_path) @cached_property def llm(self) -> asyncio.Task[AsyncLLM]: diff --git a/src/art/megatron/sft_batches.py b/src/art/megatron/sft_batches.py new file mode 100644 index 000000000..ea456bf63 --- /dev/null +++ b/src/art/megatron/sft_batches.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass +import json +import os +from typing import TYPE_CHECKING, Any, Iterable +import uuid + +from safetensors.torch import load_file, save_file +import torch + +if TYPE_CHECKING: + from ..preprocessing.tokenize import SFTBatch + + +DEFAULT_SFT_DATA_DIR = "/tmp/megatron_sft_data" + + +@dataclass(frozen=True) +class SerializedSFTBatches: + sft_data_dir: str + num_batches: int + learning_rates: list[float] + + +def serialize_sft_batch_to_disk(batch: "SFTBatch", batch_dir: str) -> None: + os.makedirs(batch_dir, exist_ok=True) + metadata = { + "learning_rate": batch.learning_rate, + "num_trajectories": batch.num_trajectories, + "num_trainable_tokens": batch.num_trainable_tokens, + "num_trajectory_tensors": len(batch.trajectory_tensors), + } + with open(os.path.join(batch_dir, "metadata.json"), "w", encoding="utf-8") as f: + json.dump(metadata, f) + for index, trajectory_tensors in enumerate(batch.trajectory_tensors): + save_file( + { + key: value.squeeze(0) if value.dim() > 0 else value + for key, value in trajectory_tensors.items() + }, + os.path.join(batch_dir, f"trajectory_{index}.safetensors"), + ) + + +def materialize_sft_batches( + batches: Iterable["SFTBatch"], + *, + sft_data_dir: str | None = None, +) -> SerializedSFTBatches: + if sft_data_dir is None: + sft_data_dir = os.path.join(DEFAULT_SFT_DATA_DIR, uuid.uuid4().hex) + + learning_rates: list[float] = [] + num_batches = 0 + for batch_index, batch in enumerate(batches): + batch_dir = os.path.join(sft_data_dir, f"batch_{batch_index:06d}") + serialize_sft_batch_to_disk(batch, batch_dir) + learning_rates.append(batch.learning_rate) + num_batches += 1 + + return SerializedSFTBatches( + sft_data_dir=sft_data_dir, + num_batches=num_batches, + learning_rates=learning_rates, + ) + + +def load_sft_batch_from_disk( + batch_dir: str, +) -> tuple[dict[str, Any], list[dict[str, torch.Tensor]]]: + with open(os.path.join(batch_dir, "metadata.json"), encoding="utf-8") as f: + metadata = json.load(f) + + trajectory_tensors = [] + for index in range(metadata["num_trajectory_tensors"]): + trajectory_tensors.append( + load_file(os.path.join(batch_dir, f"trajectory_{index}.safetensors")) + ) + return metadata, trajectory_tensors diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index cc86c126f..9cfa85105 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -1,23 +1,23 @@ # isort: off -import os - +from 
art.megatron.runtime_env import configure_megatron_runtime_env -def _set_cache_dir(env_var: str, default_path: str) -> None: - if not os.environ.get(env_var): - os.environ[env_var] = os.path.expanduser(default_path) - os.makedirs(os.environ[env_var], exist_ok=True) +configure_megatron_runtime_env() +# isort: on +"""Megatron training runtime and public worker API. -os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" -_set_cache_dir("TORCHINDUCTOR_CACHE_DIR", "~/.cache/torchinductor") -_set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache") -# isort: on +Public cross-repo API consumed by serverless-training: +- build_training_runtime +- run_megatron_worker_loop +- merge_lora_adapter +""" import gc +import importlib import json import math +import os +from pathlib import Path import shutil import time from typing import Any, Callable, cast @@ -28,15 +28,23 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer from megatron.core.transformer.module import MegatronModule from pydantic import BaseModel, ConfigDict -from safetensors.torch import load_file, save_file import torch from torch._inductor.runtime.cache_dir_utils import cache_dir as inductor_cache_dir from art import dev, types from art.loss import loss_fn, shift_tensor +from art.megatron.client import create_megatron_job_paths, write_megatron_job from art.megatron.finalize_grads import finalize_model_grads_extended from art.megatron.flex_attention import create_shared_prefix_attention_state +from art.megatron.jobs import ( + DEFAULT_JOBS_DIR, + DEFAULT_VLLM_WAKE_LOCK_PATH, + MegatronJob, + MegatronSFTTrainingJob, + MegatronTrainingJob, +) from art.megatron.lora import apply_lora_adapters +from art.megatron.merge import merge_lora_adapter from art.megatron.offload import ( OffloadState, clear_optimizer_state, @@ -48,29 +56,30 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: MoeRoutingReplayBundle, MoeRoutingReplayController, ) +from art.megatron.sft_batches import load_sft_batch_from_disk from art.preprocessing.pack import ( - DiskPackedTensors, PackedTensors, packed_tensors_from_dir, ) -DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507" - - -class TrainingJob(BaseModel): - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig - moe_routing_replay_path: str | None = None - moe_routing_replay_strict: bool = True +safetensors = importlib.import_module("safetensors") +safetensors_torch = importlib.import_module("safetensors.torch") +safe_open = safetensors.safe_open +load_file = safetensors_torch.load_file +save_file = safetensors_torch.save_file +DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507" -TrainingJob.model_rebuild( - force=True, - _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle}, -) +__all__ = [ + "DEFAULT_MODEL_IDENTIFIER", + "TrainingRuntime", + "build_training_runtime", + "run_megatron_worker_loop", + "run_megatron_rl_job", + "run_megatron_sft_job", + "finalize_megatron_job", + "merge_lora_adapter", +] class TrainingRuntime(BaseModel): @@ -299,6 +308,385 @@ def build_training_runtime( return runtime +def run_megatron_worker_loop( + runtime: TrainingRuntime, + *, + supports_sft: bool, + wait_until_ready: Callable[[], None] | None = None, + before_job: Callable[[], None] | None = None, + 
after_job: Callable[[], None] | None = None, +) -> None: + while True: + torch.distributed.barrier() # type: ignore[possibly-missing-attribute] + os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True) + job_names = sorted( + job_name + for job_name in os.listdir(DEFAULT_JOBS_DIR) + if job_name.endswith(".json") + ) + if not job_names: + time.sleep(1) + continue + + if wait_until_ready is not None: + wait_until_ready() + if before_job is not None: + before_job() + + job_path = os.path.join(DEFAULT_JOBS_DIR, job_names[0]) + job = _load_megatron_job(job_path, supports_sft=supports_sft) + print0(runtime.rank, "Loaded job from", job_path) + print0(runtime.rank, "Job:", job) + + try: + _run_megatron_job(runtime, job) + finally: + if after_job is not None: + after_job() + + finalize_megatron_job( + runtime, + job_path=job_path, + log_path=job.log_path, + cleanup_path=_job_cleanup_path(job), + ) + + +def run_megatron_rl_job( + runtime: TrainingRuntime, + job: MegatronTrainingJob, +) -> None: + packed_tensors = None + adapter_model = None + template = None + zero_template = None + + try: + configure_moe_routing_replay( + runtime, + replay_bundle_path=job.moe_routing_replay_path, + strict=job.moe_routing_replay_strict, + ) + adapter_model = _load_lora_and_optimizer( + runtime, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + + print0( + runtime.rank, + "Loading packed tensors from", + job.disk_packed_tensors["dir"], + ) + packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) + template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0)) + zero_template = _zero_contribution_inputs(template) + num_sequences = job.disk_packed_tensors["num_sequences"] + global_grad_accumulation_sequences = resolve_global_grad_accumulation_sequences( + job.config.grad_accumulation_sequences + ) + num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences) + for step_index in range(num_steps): + micro_indices = build_micro_sample_indices( + step_index=step_index, + num_sequences=num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + micro_inputs = select_micro_inputs( + packed_tensors, + micro_indices, + zero_template, + ) + step_result = run_training_step( + model_chunks=runtime.model, + optimizer=runtime.optimizer, + learning_rate=job.config.learning_rate, + inputs=micro_inputs, + config=job.config, + experimental_config=cast(dev.TrainConfig, job.experimental_config), + ref_logprobs=None, + step_index=step_index, + sample_index=micro_indices, + moe_routing_replay_controller=runtime.moe_routing_replay_controller, + ) + print0( + runtime.rank, + "Correlation between old and new probabilities:", + step_result.probs_corr, + ) + + if runtime.rank == 0: + with open(job.log_path, "a+", encoding="utf-8") as log_file: + log_msg = json.dumps( + { + "loss": step_result.reduced_loss.item(), + "grad_norm": step_result.grad_norm, + "probs_corr": step_result.probs_corr, + } + ) + print("Logging", log_msg) + log_file.write(log_msg + "\n") + + _save_lora_and_optimizer( + runtime, + adapter_model=adapter_model, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + finally: + if packed_tensors is not None: + del packed_tensors + if adapter_model is not None: + del adapter_model + if template is not None: + del template + if zero_template is not None: + del zero_template + if "micro_inputs" in locals(): + del micro_inputs + gc.collect() + torch.cuda.empty_cache() + + +def run_megatron_sft_job( + runtime: 
TrainingRuntime, + job: MegatronSFTTrainingJob, +) -> None: + adapter_model = None + + try: + configure_moe_routing_replay(runtime) + adapter_model = _load_lora_and_optimizer( + runtime, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + + runtime.optimizer.config.clip_grad = job.max_grad_norm + for param_group in runtime.optimizer.param_groups: + param_group["weight_decay"] = job.weight_decay + + device = next(runtime.model[0].parameters()).device + dp_rank = ps.get_data_parallel_rank() + dp_world_size = ps.get_data_parallel_world_size() + + for batch_idx in range(job.num_batches): + batch_start_time = time.perf_counter() + batch_dir = os.path.join(job.sft_data_dir, f"batch_{batch_idx:06d}") + batch_metadata, trajectory_tensors = load_sft_batch_from_disk(batch_dir) + global_trainable_tokens = max( + int(batch_metadata["num_trainable_tokens"]), + 1, + ) + local_trajectory_tensors = trajectory_tensors[dp_rank::dp_world_size] + + for chunk in runtime.model: + chunk.zero_grad_buffer() # type: ignore[call-non-callable] + + batch_loss = torch.tensor(0.0, device=device) + local_trainable_tokens = 0.0 + for param_group in runtime.optimizer.param_groups: + param_group["lr"] = job.learning_rates[batch_idx] + + for traj_tensors in local_trajectory_tensors: + attention_mask_1d = traj_tensors["attention_mask"] + actual_len = int(attention_mask_1d.sum().item()) + input_ids = ( + traj_tensors["input_ids"][:actual_len].unsqueeze(0).to(device) + ) + labels = traj_tensors["labels"][:actual_len].unsqueeze(0).to(device) + seq_len = input_ids.shape[1] + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) + shifted_labels = shift_tensor(labels, -100) + mask = shifted_labels != -100 + local_trainable_tokens += float(mask.sum().item()) + + per_token_loss: torch.Tensor = runtime.model[0]( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=_placeholder_attention_mask(device), + labels=shifted_labels, + extra_block_kwargs={ + "attention_bias": _causal_attention_state(seq_len, device), + }, + ) + masked_loss = per_token_loss[mask].sum() + masked_loss.backward() + batch_loss += masked_loss.detach() + + num_tokens = torch.tensor( + [local_trainable_tokens], + device=device, + dtype=torch.float32, + ) + finalize_model_grads_extended(runtime.model, num_tokens=num_tokens) + update_successful, grad_norm, num_zeros_in_grad = runtime.optimizer.step() + runtime.optimizer.zero_grad() + del update_successful, num_zeros_in_grad + + torch.distributed.all_reduce( + batch_loss, + op=torch.distributed.ReduceOp.SUM, + group=ps.get_data_parallel_group(with_context_parallel=True), + ) + avg_loss = batch_loss / float(global_trainable_tokens) + batch_time = time.perf_counter() - batch_start_time + tokens_per_second = ( + global_trainable_tokens / batch_time if batch_time > 0 else 0.0 + ) + + if runtime.rank == 0: + with open(job.log_path, "a+", encoding="utf-8") as log_file: + log_msg = json.dumps( + { + "loss": avg_loss.item(), + "learning_rate": job.learning_rates[batch_idx], + "grad_norm": float(grad_norm), + "num_trajectories": float( + batch_metadata["num_trajectories"] + ), + "num_trainable_tokens": float(global_trainable_tokens), + "tokens_per_second": tokens_per_second, + } + ) + print("Logging SFT", log_msg) + log_file.write(log_msg + "\n") + + _save_lora_and_optimizer( + runtime, + adapter_model=adapter_model, + lora_path=job.lora_path, + optimizer_state_path=job.optimizer_state_path, + ) + finally: + if adapter_model is not None: + del adapter_model + gc.collect() 
+ torch.cuda.empty_cache() + + +def _load_megatron_job(job_path: str, *, supports_sft: bool) -> MegatronJob: + with open(job_path, "rb") as handle: + job_data = json.loads(handle.read()) + if job_data.get("job_type") == "sft": + if not supports_sft: + raise NotImplementedError("SFT jobs are not supported in this worker loop") + return MegatronSFTTrainingJob.model_validate(job_data) + return MegatronTrainingJob.model_validate(job_data) + + +def _run_megatron_job(runtime: TrainingRuntime, job: MegatronJob) -> None: + if isinstance(job, MegatronSFTTrainingJob): + run_megatron_sft_job(runtime, job) + return + run_megatron_rl_job(runtime, job) + + +def _job_cleanup_path(job: MegatronJob) -> str: + if isinstance(job, MegatronSFTTrainingJob): + return job.sft_data_dir + return job.disk_packed_tensors["dir"] + + +def _load_lora_and_optimizer( + runtime: TrainingRuntime, + *, + lora_path: str, + optimizer_state_path: str, +) -> dict[str, torch.Tensor]: + adapter_model_path = os.path.join(lora_path, "adapter_model.safetensors") + if not os.path.exists(adapter_model_path): + raise FileNotFoundError(f"No adapter model found at {adapter_model_path}") + print0(runtime.rank, "Loading adapter model from", adapter_model_path) + adapter_model = load_file(adapter_model_path) + load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer) + + optimizer_shard_path = os.path.join( + optimizer_state_path, + f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt", + ) + if os.path.exists(optimizer_shard_path): + print0(runtime.rank, "Loading optimizer state from", optimizer_shard_path) + runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path)) + else: + print0( + runtime.rank, + "No optimizer state found at", + optimizer_shard_path, + "- resetting optimizer for new run", + ) + clear_optimizer_state(runtime.optimizer) + runtime.optimizer.reload_model_params() + return adapter_model + + +def _save_lora_and_optimizer( + runtime: TrainingRuntime, + *, + adapter_model: dict[str, torch.Tensor], + lora_path: str, + optimizer_state_path: str, +) -> None: + sharded_state_dict, sharded_state_manifest = collect_sharded_lora_state( + runtime.model, + adapter_model, + ) + shard_path = os.path.join( + lora_path, + f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors", + ) + manifest_path = os.path.join( + lora_path, + f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json", + ) + print("Saving adapter shard to", shard_path) + os.makedirs(lora_path, exist_ok=True) + save_file(sharded_state_dict, shard_path) + print("Saving adapter shard manifest to", manifest_path) + with open(manifest_path, "w", encoding="utf-8") as manifest_file: + json.dump(sharded_state_manifest, manifest_file, sort_keys=True) + + optimizer_shard_path = os.path.join( + optimizer_state_path, + f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt", + ) + print("Saving optimizer shard to", optimizer_shard_path) + os.makedirs(optimizer_state_path, exist_ok=True) + torch.save(runtime.optimizer.state_dict(), optimizer_shard_path) + + +def finalize_megatron_job( + runtime: TrainingRuntime, + *, + job_path: str | None, + log_path: str, + cleanup_path: str, +) -> None: + torch.distributed.barrier() # type: ignore[possibly-missing-attribute] + if runtime.rank != 0: + return + + if job_path is not None and os.path.exists(job_path): + os.remove(job_path) + if os.path.exists(cleanup_path): + shutil.rmtree(cleanup_path) + with open(log_path, "a+", encoding="utf-8") as log_file: + log_file.write("all 
done\n") + + +def _placeholder_attention_mask(device: torch.device) -> torch.Tensor: + return torch.zeros((1, 1, 1, 1), dtype=torch.bool, device=device) + + +def _causal_attention_state(seq_len: int, device: torch.device) -> Any: + group_ids = torch.zeros((1, seq_len), dtype=torch.int64, device=device) + parent_ids = torch.zeros_like(group_ids) + return create_shared_prefix_attention_state( + group_ids=group_ids, + parent_ids=parent_ids, + ) + + def iter_modules(model_chunks: list[MegatronModule]) -> Any: for chunk in model_chunks: for module in chunk.modules(): @@ -377,35 +765,54 @@ def _zero_contribution_inputs(template: PackedTensors) -> PackedTensors: return dummy +def resolve_global_grad_accumulation_sequences( + global_grad_accumulation_sequences: int | None, +) -> int: + dp_world_size = ps.get_data_parallel_world_size() + if global_grad_accumulation_sequences is None: + return dp_world_size + return global_grad_accumulation_sequences + + def resolve_local_grad_accumulation_sequences( - global_grad_accumulation_sequences: int, + global_grad_accumulation_sequences: int | None, ) -> int: + resolved_global_grad_accumulation_sequences = ( + resolve_global_grad_accumulation_sequences( + global_grad_accumulation_sequences=global_grad_accumulation_sequences + ) + ) dp_world_size = ps.get_data_parallel_world_size() if ( - global_grad_accumulation_sequences <= 0 - or global_grad_accumulation_sequences % dp_world_size != 0 + resolved_global_grad_accumulation_sequences <= 0 + or resolved_global_grad_accumulation_sequences % dp_world_size != 0 ): raise RuntimeError( "Invalid global grad accumulation / DP world size combination: " - f"global_grad_accumulation_sequences={global_grad_accumulation_sequences}, " + f"global_grad_accumulation_sequences={resolved_global_grad_accumulation_sequences}, " f"dp_world_size={dp_world_size}" ) - return global_grad_accumulation_sequences // dp_world_size + return resolved_global_grad_accumulation_sequences // dp_world_size def build_micro_sample_indices( step_index: int, num_sequences: int, - global_grad_accumulation_sequences: int, + global_grad_accumulation_sequences: int | None, ) -> list[int | None]: dp_rank = ps.get_data_parallel_rank() + resolved_global_grad_accumulation_sequences = ( + resolve_global_grad_accumulation_sequences( + global_grad_accumulation_sequences=global_grad_accumulation_sequences + ) + ) dp_world_size = ps.get_data_parallel_world_size() local_grad_accumulation_sequences = resolve_local_grad_accumulation_sequences( - global_grad_accumulation_sequences=global_grad_accumulation_sequences, + global_grad_accumulation_sequences=resolved_global_grad_accumulation_sequences, ) - base_global_sample_index = step_index * global_grad_accumulation_sequences + base_global_sample_index = step_index * resolved_global_grad_accumulation_sequences global_step_indices: list[int | None] = [] - for offset in range(global_grad_accumulation_sequences): + for offset in range(resolved_global_grad_accumulation_sequences): global_sample_index = base_global_sample_index + offset global_step_indices.append( global_sample_index if global_sample_index < num_sequences else None @@ -504,10 +911,15 @@ def run_training_step( micro_sample_indices = [sample_index] if moe_routing_replay_controller is not None: + resolved_global_grad_accumulation_sequences = ( + resolve_global_grad_accumulation_sequences( + config.grad_accumulation_sequences + ) + ) moe_routing_replay_controller.set_step( step_index=step_index, sample_index=micro_sample_indices, - 
global_grad_accumulation_sequences=config.grad_accumulation_sequences, + global_grad_accumulation_sequences=resolved_global_grad_accumulation_sequences, ) device = next(model_chunks[0].parameters()).device @@ -557,6 +969,7 @@ def run_training_step( if new_logprobs is None or raw_loss_sum is None: raise RuntimeError("run_training_step did not produce outputs") + # num_tokens is reduced in place across ranks by finalize_model_grads(). finalize_model_grads_extended(model_chunks, num_tokens=num_tokens) update_successful, grad_norm, num_zeros_in_grad = _optimizer_step( optimizer, @@ -586,156 +999,21 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: offload_state = OffloadState() offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) - while True: - torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) - job_names = sorted( - job_name for job_name in os.listdir(jobs_dir) if job_name.endswith(".json") - ) - if not job_names: - time.sleep(1) - continue - - wake_lock_path = "/tmp/megatron_vllm_waking" - while os.path.exists(wake_lock_path): + def wait_until_ready() -> None: + while os.path.exists(DEFAULT_VLLM_WAKE_LOCK_PATH): time.sleep(0.2) - reload_to_gpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) - - job_name = job_names[0] - job_path = os.path.join(jobs_dir, job_name) - with open(job_path, "rb") as handle: - job = TrainingJob.model_validate_json(handle.read()) - config = job.config - experimental_config = job.experimental_config - - configure_moe_routing_replay( - runtime, - replay_bundle_path=job.moe_routing_replay_path, - strict=job.moe_routing_replay_strict, - ) - - print0(runtime.rank, "Loaded job from", job_path) - print0(runtime.rank, "Job:", job) - - adapter_model_path = f"{job.lora_path}/adapter_model.safetensors" - if not os.path.exists(adapter_model_path): - raise FileNotFoundError(f"No adapter model found at {adapter_model_path}") - print0(runtime.rank, "Loading adapter model from", adapter_model_path) - adapter_model = load_file(adapter_model_path) - load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer) - - optimizer_shard_path = os.path.join( - job.optimizer_state_path, - f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt", - ) - if os.path.exists(optimizer_shard_path): - print("Loading optimizer state from", optimizer_shard_path) - runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path)) - else: - print( - "No optimizer state found at", - optimizer_shard_path, - "- resetting optimizer for new run", - ) - clear_optimizer_state(runtime.optimizer) - runtime.optimizer.reload_model_params() - - print0( - runtime.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"] - ) - packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) - template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0)) - zero_template = _zero_contribution_inputs(template) - num_sequences = job.disk_packed_tensors["num_sequences"] - global_grad_accumulation_sequences = config.grad_accumulation_sequences - num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences) - for step_index in range(num_steps): - micro_indices = build_micro_sample_indices( - step_index=step_index, - num_sequences=num_sequences, - global_grad_accumulation_sequences=global_grad_accumulation_sequences, - ) - micro_inputs = select_micro_inputs( - packed_tensors, micro_indices, zero_template - ) - try: - 
step_result = run_training_step( - model_chunks=runtime.model, - optimizer=runtime.optimizer, - learning_rate=config.learning_rate, - inputs=micro_inputs, - config=config, - experimental_config=experimental_config, - ref_logprobs=None, - step_index=step_index, - sample_index=micro_indices, - moe_routing_replay_controller=runtime.moe_routing_replay_controller, - ) - except Exception: - raise - print0( - runtime.rank, - "Correlation between old and new probabilities:", - step_result.probs_corr, - ) - - if runtime.rank == 0: - with open( - "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" - ) as log_file: - log_msg = json.dumps( - { - "loss": step_result.reduced_loss.item(), - "grad_norm": step_result.grad_norm, - "probs_corr": step_result.probs_corr, - } - ) - print("Logging", log_msg) - log_file.write(log_msg + "\n") - - sharded_state_dict, sharded_state_manifest = collect_sharded_lora_state( - runtime.model, - adapter_model, - ) - shard_path = os.path.join( - job.lora_path, - f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors", - ) - manifest_path = os.path.join( - job.lora_path, - f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json", - ) - print("Saving adapter shard to", shard_path) - save_file(sharded_state_dict, shard_path) - print("Saving adapter shard manifest to", manifest_path) - with open(manifest_path, "w", encoding="utf-8") as manifest_file: - json.dump(sharded_state_manifest, manifest_file, sort_keys=True) - - print("Saving optimizer shard to", optimizer_shard_path) - os.makedirs(job.optimizer_state_path, exist_ok=True) - torch.save(runtime.optimizer.state_dict(), optimizer_shard_path) - - offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) - - del packed_tensors - del template - del zero_template - del adapter_model - if "micro_inputs" in locals(): - del micro_inputs - gc.collect() - torch.cuda.empty_cache() - - torch.distributed.barrier() # ty: ignore[possibly-missing-attribute] - if runtime.rank == 0: - os.remove(job_path) - with open( - "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" - ) as log_file: - log_file.write("all done\n") - shutil.rmtree(job.disk_packed_tensors["dir"]) + run_megatron_worker_loop( + runtime, + supports_sft=True, + wait_until_ready=wait_until_ready, + before_job=lambda: reload_to_gpu( + runtime.model, runtime.optimizer, runtime.rank, offload_state + ), + after_job=lambda: offload_to_cpu( + runtime.model, runtime.optimizer, runtime.rank, offload_state + ), + ) def main() -> None: diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index ce530fe58..fb469eb9c 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -9,10 +9,13 @@ from art.serverless.client import Client, ExperimentalTrainingConfig from .. 
import dev +from .._backend_training import ( + aggregate_rl_training_metrics, + build_rl_train_configs, +) from ..backend import AnyTrainableModel, Backend from ..metrics_taxonomy import ( TRAIN_GRADIENT_STEPS_KEY, - average_metric_samples, build_training_summary_metrics, summarize_trajectory_groups, ) @@ -254,27 +257,19 @@ async def train( # type: ignore[override] """ groups_list = list(trajectory_groups) - # Build config objects from explicit kwargs - config = TrainConfig(learning_rate=learning_rate) - dev_config: dev.TrainConfig = { - "advantage_balance": advantage_balance, - "importance_sampling_level": importance_sampling_level, - "mask_prob_ratio": mask_prob_ratio, - "ppo": ppo, - "precalculate_logprobs": precalculate_logprobs, - "scale_rewards": scale_rewards, - } - # Only include optional fields if they're set - if epsilon is not None: - dev_config["epsilon"] = epsilon - if epsilon_high is not None: - dev_config["epsilon_high"] = epsilon_high - if max_negative_advantage_importance_sampling_weight is not None: - dev_config["max_negative_advantage_importance_sampling_weight"] = ( - max_negative_advantage_importance_sampling_weight - ) - if kimi_k2_tau is not None: - dev_config["kimi_k2_tau"] = kimi_k2_tau + config, dev_config = build_rl_train_configs( + learning_rate=learning_rate, + advantage_balance=advantage_balance, + scale_rewards=scale_rewards, + importance_sampling_level=importance_sampling_level, + mask_prob_ratio=mask_prob_ratio, + ppo=ppo, + precalculate_logprobs=precalculate_logprobs, + epsilon=epsilon, + epsilon_high=epsilon_high, + max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, + kimi_k2_tau=kimi_k2_tau, + ) # Collect metrics from training training_metrics: list[dict[str, float]] = [] @@ -284,21 +279,10 @@ async def train( # type: ignore[override] ): training_metrics.append(metrics) - # Aggregate metrics - avg_metrics = average_metric_samples(training_metrics) - summary = summarize_trajectory_groups(groups_list) - avg_metrics.setdefault( - "time/step_trainer_s", time.monotonic() - trainer_started - ) - avg_metrics.update( - { - key: value - for key, value in build_training_summary_metrics( - summary, - include_trainable_groups=True, - ).items() - if key not in avg_metrics - } + avg_metrics = aggregate_rl_training_metrics( + training_metrics=training_metrics, + trajectory_groups=groups_list, + trainer_started=trainer_started, ) # Get step and artifact name diff --git a/src/art/tinker/client.py b/src/art/tinker/client.py index 540faae02..8ece9c7dc 100644 --- a/src/art/tinker/client.py +++ b/src/art/tinker/client.py @@ -23,7 +23,14 @@ def _message_or_choice_to_dict(message_or_choice: MessageOrChoice) -> dict[str, Any]: if isinstance(message_or_choice, dict): return cast(dict[str, Any], message_or_choice) - return cast(dict[str, Any], message_or_choice.to_dict()) + if isinstance(message_or_choice, BaseModel): + return cast(dict[str, Any], message_or_choice.to_dict()) + to_dict = getattr(message_or_choice, "to_dict", None) + if to_dict is None: + raise TypeError( + "message_or_choice must be a dict or OpenAI model with to_dict()" + ) + return cast(dict[str, Any], to_dict()) class MessagesAndChoicesWithLogprobs(BaseModel): diff --git a/src/art/types.py b/src/art/types.py index f905d8817..a39276371 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -17,7 +17,7 @@ class TrainConfig(pydantic.BaseModel): learning_rate: float = 5e-6 kl_penalty_coef: float = 0.0 - grad_accumulation_sequences: int = pydantic.Field(default=1, 
ge=1) + grad_accumulation_sequences: int | None = pydantic.Field(default=None, ge=1) class TrainSFTConfig(pydantic.BaseModel): diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index cb55ce18a..37c6de193 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -9,15 +9,10 @@ import socket import subprocess import sys -from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, Protocol, cast +from typing import Any, AsyncIterator, Literal, cast -from datasets import Dataset -import peft import torch -from torch.optim import Optimizer -from transformers import GenerationMixin, PreTrainedModel -from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from trl import GRPOConfig, GRPOTrainer +from trl import GRPOTrainer from vllm import AsyncEngineArgs from vllm.lora.request import LoRARequest from vllm.v1.engine.async_llm import AsyncLLM @@ -25,136 +20,23 @@ from .. import dev, types from ..dev.validate import is_dedicated_mode from ..local.checkpoints import get_last_checkpoint_dir -from ..preprocessing.inputs import TrainInputs, create_train_inputs -from ..preprocessing.pack import ( - DiskPackedTensors, - PackedTensors, - packed_tensors_from_dir, -) +from ..preprocessing.inputs import TrainInputs +from ..preprocessing.pack import DiskPackedTensors from ..preprocessing.tokenize import SFTBatch from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers -from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train +from .train import ( + UnslothTrainContext, + create_unsloth_train_context, + gc_and_empty_cuda_cache, + run_unsloth_rl_training, + run_unsloth_sft_training, +) logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from peft.peft_model import PeftModelForCausalLM - from trl import GRPOTrainer - - -# ============================================================================ -# Shared Utilities -# ============================================================================ - - -class SupportsLoadLora(Protocol): - """Protocol for models that support the optimized load_lora method.""" - - def load_lora(self, lora_path: str, load_tensors: bool = True) -> LoRARequest: ... 
- - -class _StopTrainInputs: - """Dedicated sentinel for stopping the background trainer loop.""" - - -_STOP_TRAIN_INPUT = _StopTrainInputs() -_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0 -_TrainLoopInput = TrainInputs | _StopTrainInputs - - -def precalculate_new_logprobs( - trainer: "GRPOTrainer", - peft_model: "PeftModelForCausalLM", - packed_tensors: PackedTensors, - config: types.TrainConfig, - _config: dev.TrainConfig, -) -> torch.Tensor: - """Precalculate logprobs for all offsets and return as a tensor.""" - return torch.cat( - [ - trainer.compute_loss( - peft_model, - TrainInputs( # ty:ignore[missing-typed-dict-key] - **{ - k: v[_offset : _offset + 1] - for k, v in packed_tensors.items() - if isinstance(v, torch.Tensor) - }, - pixel_values=packed_tensors["pixel_values"][_offset : _offset + 1], - image_grid_thw=packed_tensors["image_grid_thw"][ - _offset : _offset + 1 - ], - config=config, - _config=_config, - return_new_logprobs=True, - ), - ) - for _offset in range(0, packed_tensors["tokens"].shape[0]) - ] - ).to("cpu") - - -async def process_train_batch( - packed_tensors: PackedTensors, - config: types.TrainConfig, - _config: dev.TrainConfig, - inputs_queue: asyncio.Queue[_TrainLoopInput], - results_queue: asyncio.Queue[dict[str, float]], - train_task: asyncio.Task[None], - trainer: "GRPOTrainer", - peft_model: "PeftModelForCausalLM", - warmup: bool, - verbose: bool = False, -): - """ - Process training batches and yield results. - - Yields tuples of (result, warmup_done) where warmup_done indicates if warmup just finished. - """ - precalculate_logprobs = _config.get("precalculate_logprobs", False) - - for offset in range(0, packed_tensors["tokens"].shape[0]): - for _ in range(2 if warmup else 1): - if precalculate_logprobs and not warmup: - # Preserve original logprobs before overwriting - packed_tensors["original_logprobs"] = packed_tensors["logprobs"] # type: ignore - packed_tensors["logprobs"] = precalculate_new_logprobs( - trainer, peft_model, packed_tensors, config, _config - ) - precalculate_logprobs = False - - inputs_queue.put_nowait( - create_train_inputs(packed_tensors, offset, config, _config, warmup) - ) - - # Wait for a result from the queue or for the training task to, - # presumably, raise an exception - done, _ = await asyncio.wait( - [ - asyncio.create_task(results_queue.get()), - train_task, - ], - return_when=asyncio.FIRST_COMPLETED, - ) - if verbose: - print( - "Done waiting for a result from the queue or for the training task to, presumably, raise an exception" - ) - for task in done: - result = task.result() - # If `result` is `None`, the training task finished somehow. - assert result is not None, "The training task should never finish." 
- results_queue.task_done() - if warmup: - gc_and_empty_cuda_cache() - await asyncio.sleep(0.1) - warmup = False - else: - yield result - def save_checkpoint( trainer: "GRPOTrainer", @@ -201,17 +83,10 @@ def save_checkpoint( return checkpoint_dir -def _get_trainer_optimizer(trainer: GRPOTrainer) -> Optimizer: - optimizer = cast(Optimizer | None, getattr(trainer, "optimizer", None)) - if optimizer is None: - raise RuntimeError("Trainer optimizer must be initialized before training") - return optimizer - - def _find_free_tcp_port() -> int: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind(("127.0.0.1", 0)) - return cast(int, sock.getsockname()[1]) + return int(sock.getsockname()[1]) def _normalize_merged_checkpoint_name(name: str) -> str: @@ -223,104 +98,6 @@ def _normalize_merged_checkpoint_name(name: str) -> str: return normalized -# ============================================================================ -# Model Classes -# ============================================================================ - - -class CausalLM(PreTrainedModel, GenerationMixin): - """Dummy class for type checking.""" - - pass - - -@dataclass -class UnslothState: - model: CausalLM - tokenizer: PreTrainedTokenizerBase - peft_model: peft.peft_model.PeftModelForCausalLM - trainer: GRPOTrainer - inputs_queue: asyncio.Queue[_TrainLoopInput] - results_queue: asyncio.Queue[dict[str, float]] - _is_offloaded: bool = False - _pinned_buffers: dict[str, torch.Tensor] | None = None - - def offload_to_cpu(self) -> None: - """Offload training model and optimizer to CPU using pinned memory for faster transfers.""" - if self._is_offloaded: - return - - # Initialize pinned buffer storage - if self._pinned_buffers is None: - self._pinned_buffers = {} - - # Offload model parameters to pinned memory for faster reload - for name, param in self.peft_model.named_parameters(): - if param.device.type == "cuda": - # Create pinned buffer if not exists or wrong size - if ( - name not in self._pinned_buffers - or self._pinned_buffers[name].shape != param.shape - ): - self._pinned_buffers[name] = torch.empty( - param.shape, dtype=param.dtype, device="cpu", pin_memory=True - ) - # Async copy to pinned memory - self._pinned_buffers[name].copy_(param.data, non_blocking=True) - param.data = self._pinned_buffers[name] - - # Offload optimizer state to pinned memory - optimizer = getattr(self.trainer, "optimizer", None) - if optimizer is not None and hasattr(optimizer, "state"): - for param_id, state in optimizer.state.items(): - for k, v in state.items(): - if isinstance(v, torch.Tensor) and v.device.type == "cuda": - key = f"opt_{id(param_id)}_{k}" - if ( - key not in self._pinned_buffers - or self._pinned_buffers[key].shape != v.shape - ): - self._pinned_buffers[key] = torch.empty( - v.shape, dtype=v.dtype, device="cpu", pin_memory=True - ) - self._pinned_buffers[key].copy_(v, non_blocking=True) - state[k] = self._pinned_buffers[key] - - # Sync to ensure all copies are complete before freeing GPU memory - torch.cuda.synchronize() - - self._is_offloaded = True - gc_and_empty_cuda_cache() - - def reload_to_gpu(self, device: str = "cuda:0") -> None: - """Reload training model and optimizer back to GPU using async transfers.""" - if not self._is_offloaded: - return - - # Reload model parameters from pinned memory (fast async transfer) - for name, param in self.peft_model.named_parameters(): - if param.device.type == "cpu": - # Allocate on GPU and async copy from pinned memory - gpu_tensor = torch.empty(param.shape, 
dtype=param.dtype, device=device) - gpu_tensor.copy_(param.data, non_blocking=True) - param.data = gpu_tensor - - # Reload optimizer state - optimizer = getattr(self.trainer, "optimizer", None) - if optimizer is not None and hasattr(optimizer, "state"): - for state in optimizer.state.values(): - for k, v in state.items(): - if isinstance(v, torch.Tensor) and v.device.type == "cpu": - gpu_tensor = torch.empty(v.shape, dtype=v.dtype, device=device) - gpu_tensor.copy_(v, non_blocking=True) - state[k] = gpu_tensor - - # Sync to ensure all copies are complete before training - torch.cuda.synchronize() - - self._is_offloaded = False - - # ============================================================================ # Service # ============================================================================ @@ -333,7 +110,6 @@ class UnslothService: config: dev.InternalModelConfig output_dir: str _is_sleeping: bool = False - _last_training_mode: Literal["sft", "rl"] | None = None _latest_step: int = 0 _lora_id_counter: int = 1 # Start from 1 since 0 is reserved # Dedicated mode subprocess state @@ -342,7 +118,6 @@ class UnslothService: _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 _weight_transfer_group: Any = field(default=None, init=False, repr=False) - _train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False) @property def is_dedicated(self) -> bool: @@ -364,21 +139,9 @@ def _next_lora_id(self) -> int: return self._lora_id_counter async def aclose(self) -> None: - train_task = self._train_task - self._train_task = None - if train_task is None or train_task.done(): - self.close() - return - - # `_state` is a cached_property. Read from __dict__ directly so - # closing does not instantiate trainer state only to stop a task. state = self.__dict__.get("_state") - assert isinstance(state, UnslothState) - state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT) - try: - await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_SHUTDOWN_TIMEOUT_S) - except asyncio.TimeoutError: - train_task.cancel() + if isinstance(state, UnslothTrainContext): + await state.stop_background_training() self.close() # ========================================================================= @@ -798,27 +561,6 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: self._latest_step = step await llm.resume_generation() - def _reset_optimizer_if_mode_changed( - self, - mode: Literal["sft", "rl"], - ) -> None: - """Reset optimizer state if training mode changed. - - Uses a single shared optimizer (trainer.optimizer) for both SFT and RL. - Resets optimizer state (momentum, variance) only when switching between - training modes to avoid stale state from a different loss landscape. 
- """ - mode_changed = ( - self._last_training_mode is not None and self._last_training_mode != mode - ) - optimizer = _get_trainer_optimizer(self._state.trainer) - - if mode_changed: - # Clear all optimizer state (exp_avg, exp_avg_sq, step for each param) - optimizer.state.clear() - - self._last_training_mode = mode - async def train( self, disk_packed_tensors: DiskPackedTensors, @@ -846,38 +588,11 @@ async def _train_dedicated( verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: """Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU.""" - self._reset_optimizer_if_mode_changed("rl") - optimizer = _get_trainer_optimizer(self._state.trainer) - - rl_weight_decay = 0.1 - for param_group in optimizer.param_groups: - param_group["weight_decay"] = rl_weight_decay - - packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) - - await self._state.results_queue.join() - - if self._train_task is None: - self._train_task = asyncio.create_task( - train( - trainer=self._state.trainer, - results_queue=self._state.results_queue, - ) - ) - warmup = True - else: - warmup = False - - async for result in process_train_batch( - packed_tensors=packed_tensors, + async for result in run_unsloth_rl_training( + self._state, + disk_packed_tensors=disk_packed_tensors, config=config, _config=_config, - inputs_queue=self._state.inputs_queue, - results_queue=self._state.results_queue, - train_task=self._train_task, - trainer=self._state.trainer, - peft_model=self._state.peft_model, - warmup=warmup, verbose=verbose, ): yield result @@ -938,44 +653,11 @@ async def _train_shared( # Reload training model to GPU (after vLLM is asleep) self._state.reload_to_gpu() - # Reset optimizer state if switching from SFT to RL - self._reset_optimizer_if_mode_changed("rl") - optimizer = _get_trainer_optimizer(self._state.trainer) - - # Set RL-specific hyperparameters - rl_weight_decay = 0.1 - for param_group in optimizer.param_groups: - param_group["weight_decay"] = rl_weight_decay - - # Load packed tensors - packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) - - # Wait for existing batches to finish - await self._state.results_queue.join() - - # If we haven't already, start the training task - if self._train_task is None: - self._train_task = asyncio.create_task( - train( - trainer=self._state.trainer, - results_queue=self._state.results_queue, - ) - ) - warmup = True - else: - warmup = False - - # Train on the batch using shared logic - async for result in process_train_batch( - packed_tensors=packed_tensors, + async for result in run_unsloth_rl_training( + self._state, + disk_packed_tensors=disk_packed_tensors, config=config, _config=_config, - inputs_queue=self._state.inputs_queue, - results_queue=self._state.results_queue, - train_task=self._train_task, - trainer=self._state.trainer, - peft_model=self._state.peft_model, - warmup=warmup, verbose=verbose, ): yield result @@ -1070,92 +752,19 @@ async def train_sft( # Reload training model to GPU (after vLLM is asleep) self._state.reload_to_gpu() - - # Get model and optimizer - peft_model = self._state.peft_model - self._reset_optimizer_if_mode_changed("sft") - optimizer = _get_trainer_optimizer(self._state.trainer) - - # Set SFT-specific hyperparameters - sft_weight_decay = 0.01 - for param_group in optimizer.param_groups: - param_group["weight_decay"] = sft_weight_decay - - # Reset environment variable that may be set by RL training - os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" - - peft_model.train() - device = 
next(peft_model.parameters()).device - max_grad_norm = 1.0 - if verbose: print("SFT training started") - # === Process batches === - batch_idx = 0 - for batch in batches: - batch_start_time = time.perf_counter() - batch_loss = 0.0 - - # Update learning rate for this batch - for param_group in optimizer.param_groups: - param_group["lr"] = batch.learning_rate - - # Total trainable tokens for loss normalization - num_items_in_batch = torch.tensor( - batch.num_trainable_tokens, dtype=torch.long, device=device - ) - - # Process each trajectory in the batch (gradient accumulation) - for trajectory_tensor in batch.trajectory_tensors: - # Move tensors to device - input_ids = trajectory_tensor["input_ids"].to(device) - attention_mask = trajectory_tensor["attention_mask"].to(device) - labels = trajectory_tensor["labels"].to(device) - - # Forward pass with num_items_in_batch for proper loss normalization - outputs = peft_model( - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - num_items_in_batch=num_items_in_batch, - ) - - loss = outputs.loss - - # Backward pass - accumulate gradients - loss.backward() - - # Track metrics - batch_loss += loss.item() - - # Gradient clipping - grad_norm = torch.nn.utils.clip_grad_norm_( - peft_model.parameters(), max_grad_norm - ).item() - - # Optimizer step at the end of each batch - optimizer.step() - optimizer.zero_grad() - - # Compute timing metrics - batch_time = time.perf_counter() - batch_start_time - tokens_per_second = ( - batch.num_trainable_tokens / batch_time if batch_time > 0 else 0.0 - ) - - if verbose: - print( - f"Batch {batch_idx}: loss={batch_loss:.4f}, lr={batch.learning_rate:.2e}, " - f"grad_norm={grad_norm:.4f}, tok/s={tokens_per_second:.1f}" - ) - - batch_idx += 1 - + async for result in run_unsloth_sft_training( + self._state, + batches, + verbose=verbose, + max_grad_norm=1.0, + ): yield { - "loss/train": batch_loss, - "loss/learning_rate": batch.learning_rate, - "loss/grad_norm": grad_norm, + "loss/train": result["loss"], + "loss/learning_rate": result["learning_rate"], + "loss/grad_norm": result["grad_norm"], } # === Cleanup === @@ -1199,82 +808,17 @@ async def train_sft( print("SFT training finished") @cached_property - def _state(self) -> UnslothState: - import unsloth - - # Initialize Unsloth model - init_args = self.config.get("init_args", {}) + def _state(self) -> UnslothTrainContext: + init_args = dict(self.config.get("init_args", {})) checkpoint_dir = get_last_checkpoint_dir(self.output_dir) if checkpoint_dir: init_args["model_name"] = checkpoint_dir else: init_args["model_name"] = self.base_model - - model, tokenizer = cast( - tuple[CausalLM, PreTrainedTokenizerBase], - unsloth.FastLanguageModel.from_pretrained(**init_args), - ) - - # Initialize PEFT model - skip if already a PeftModel (e.g. loaded from checkpoint) - if ( - hasattr(model, "peft_config") - and getattr(model, "peft_config", None) is not None - ): - # Model already has LoRA adapters (loaded from checkpoint) - peft_model = cast(peft.peft_model.PeftModelForCausalLM, model) - else: - peft_model = cast( - peft.peft_model.PeftModelForCausalLM, - unsloth.FastLanguageModel.get_peft_model( - model, **self.config.get("peft_args", {}) - ), - ) - - # Unsloth's model patching can leave the PEFT model without - # `warnings_issued`, which GRPOTrainer expects during init. 
- if not hasattr(peft_model, "warnings_issued"): - peft_model.warnings_issued = {} # type: ignore[attr-defined] - - # Initialize trainer with dummy dataset - data = {"prompt": ""} - trainer = GRPOTrainer( - model=peft_model, # type: ignore - reward_funcs=[], - args=GRPOConfig(**self.config.get("trainer_args", {})), - train_dataset=Dataset.from_list([data for _ in range(10_000_000)]), - processing_class=tokenizer, - ) - - # Initialize optimizer eagerly using trainer's configured settings. - if trainer.optimizer is None: - trainer.create_optimizer() - - # Initialize queues - inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue() - results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() - - # Patch trainer _prepare_inputs() to pull from queue - def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: - async def get_inputs() -> _TrainLoopInput: - return await inputs_queue.get() - - # Force otherwise synchronous _prepare_inputs() to yield - # with nested asyncio.run() call - inputs = asyncio.run(get_inputs()) - if isinstance(inputs, _StopTrainInputs): - raise StopTrainingLoop() - - return cast(dict[str, torch.Tensor], inputs) - - trainer._prepare_inputs = _async_prepare_inputs - - return UnslothState( - model=model, - tokenizer=tokenizer, - peft_model=peft_model, - trainer=trainer, - inputs_queue=inputs_queue, - results_queue=results_queue, + return create_unsloth_train_context( + init_args=init_args, + peft_args=cast(dict[str, Any], self.config.get("peft_args", {})), + trainer_args=cast(dict[str, Any], self.config.get("trainer_args", {})), ) @cached_property diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py index 20147e85b..8d22d80c3 100644 --- a/src/art/unsloth/train.py +++ b/src/art/unsloth/train.py @@ -1,24 +1,55 @@ +"""Unsloth training runtime and public API. + +Public cross-repo API consumed by serverless-training: +- create_unsloth_train_context +- run_unsloth_rl_training +- run_unsloth_sft_training +""" + import asyncio from collections import defaultdict from contextlib import contextmanager, nullcontext +from dataclasses import dataclass import gc import os -from typing import TYPE_CHECKING, Any, Callable, cast +import time +from typing import Any, AsyncIterator, Callable, Iterable, Literal, cast +from datasets import Dataset import nest_asyncio +import peft from peft.peft_model import PeftModel import torch -from trl import GRPOTrainer +from torch.optim import Optimizer +from transformers import GenerationMixin, PreTrainedModel +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from trl import GRPOConfig, GRPOTrainer -from .. import dev +from .. 
import dev, types from ..loss import loss_fn, shift_tensor +from ..preprocessing.inputs import TrainInputs, create_train_inputs +from ..preprocessing.pack import ( + DiskPackedTensors, + PackedTensors, + packed_tensors_from_dir, +) +from ..preprocessing.tokenize import SFTBatch from ..types import TrainConfig -if TYPE_CHECKING: - from .service import TrainInputs - nest_asyncio.apply() +__all__ = [ + "CausalLM", + "StopTrainingLoop", + "UnslothTrainContext", + "create_unsloth_train_context", + "gc_and_empty_cuda_cache", + "run_unsloth_rl_training", + "run_unsloth_sft_training", +] + +_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0 + _UPSTREAM_TRAIN_METRIC_KEYS = { "reward": "reward", "reward_std_dev": "reward_std_dev", @@ -43,6 +74,223 @@ class StopTrainingLoop(Exception): """Signal that the background trainer loop should exit cleanly.""" +class _StopTrainInputs: + """Sentinel used to stop the background trainer loop cleanly.""" + + +_STOP_TRAIN_INPUT = _StopTrainInputs() +_TrainLoopInput = TrainInputs | _StopTrainInputs + + +class CausalLM(PreTrainedModel, GenerationMixin): + """Dummy class for type checking.""" + + pass + + +@dataclass +class UnslothTrainContext: + model: CausalLM + tokenizer: PreTrainedTokenizerBase + peft_model: peft.peft_model.PeftModelForCausalLM + trainer: GRPOTrainer + inputs_queue: asyncio.Queue[_TrainLoopInput] + results_queue: asyncio.Queue[dict[str, float]] + train_task: asyncio.Task[None] | None = None + warmup_pending: bool = True + last_training_mode: Literal["sft", "rl"] | None = None + _is_offloaded: bool = False + _pinned_buffers: dict[str, torch.Tensor] | None = None + + def offload_to_cpu(self) -> None: + if self._is_offloaded: + return + + if self._pinned_buffers is None: + self._pinned_buffers = {} + + for name, param in self.peft_model.named_parameters(): + if param.device.type != "cuda": + continue + if ( + name not in self._pinned_buffers + or self._pinned_buffers[name].shape != param.shape + ): + self._pinned_buffers[name] = torch.empty( + param.shape, + dtype=param.dtype, + device="cpu", + pin_memory=True, + ) + self._pinned_buffers[name].copy_(param.data, non_blocking=True) + param.data = self._pinned_buffers[name] + + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None and hasattr(optimizer, "state"): + for param_id, state in optimizer.state.items(): + for key, value in state.items(): + if ( + not isinstance(value, torch.Tensor) + or value.device.type != "cuda" + ): + continue + buffer_key = f"opt_{id(param_id)}_{key}" + if ( + buffer_key not in self._pinned_buffers + or self._pinned_buffers[buffer_key].shape != value.shape + ): + self._pinned_buffers[buffer_key] = torch.empty( + value.shape, + dtype=value.dtype, + device="cpu", + pin_memory=True, + ) + self._pinned_buffers[buffer_key].copy_(value, non_blocking=True) + state[key] = self._pinned_buffers[buffer_key] + + torch.cuda.synchronize() + self._is_offloaded = True + gc_and_empty_cuda_cache() + + def reload_to_gpu(self, device: str = "cuda:0") -> None: + if not self._is_offloaded: + return + + for _, param in self.peft_model.named_parameters(): + if param.device.type != "cpu": + continue + gpu_tensor = torch.empty(param.shape, dtype=param.dtype, device=device) + gpu_tensor.copy_(param.data, non_blocking=True) + param.data = gpu_tensor + + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None and hasattr(optimizer, "state"): + for state in optimizer.state.values(): + for key, value in state.items(): + if ( + not isinstance(value, torch.Tensor) + or 
value.device.type != "cpu" + ): + continue + gpu_tensor = torch.empty( + value.shape, dtype=value.dtype, device=device + ) + gpu_tensor.copy_(value, non_blocking=True) + state[key] = gpu_tensor + + torch.cuda.synchronize() + self._is_offloaded = False + + async def load_lora_adapter(self, lora_path: str) -> None: + try: + await self.results_queue.join() + except Exception: + pass + try: + torch.cuda.synchronize() + except Exception: + pass + + try: + import importlib + + load_safetensors = importlib.import_module("safetensors.torch").load_file + except Exception: + load_safetensors = None # type: ignore[assignment] + + state_dict = None + st_path = os.path.join(lora_path, "adapter_model.safetensors") + bin_path = os.path.join(lora_path, "adapter_model.bin") + alt_st_path = os.path.join(lora_path, "model.safetensors") + alt_bin_path = os.path.join(lora_path, "pytorch_model.bin") + try: + if os.path.exists(st_path) and load_safetensors is not None: + state_dict = load_safetensors(st_path, device="cpu") + elif os.path.exists(bin_path): + state_dict = torch.load(bin_path, map_location="cpu") # type: ignore[call-arg] + elif os.path.exists(alt_st_path) and load_safetensors is not None: + state_dict = load_safetensors(alt_st_path, device="cpu") + elif os.path.exists(alt_bin_path): + state_dict = torch.load(alt_bin_path, map_location="cpu") # type: ignore[call-arg] + else: + raise FileNotFoundError(f"No adapter weights found in {lora_path}") + except Exception as exc: + raise RuntimeError(f"Failed to load LoRA adapter weights: {exc}") from exc + + with torch.no_grad(): + self.peft_model.zero_grad(set_to_none=True) + optimizer = getattr(self.trainer, "optimizer", None) + if optimizer is not None: + optimizer = getattr(optimizer, "optimizer", optimizer) + if hasattr(optimizer, "zero_grad"): + optimizer.zero_grad(set_to_none=True) # type: ignore[arg-type] + if hasattr(optimizer, "state") and isinstance(optimizer.state, dict): + optimizer.state.clear() + + try: + try: + from peft.utils.save_and_load import ( + set_peft_model_state_dict as _set_peft_model_state_dict, + ) + except Exception: + from peft import ( + set_peft_model_state_dict as _set_peft_model_state_dict, # type: ignore + ) + + active_adapter = getattr(self.peft_model, "active_adapter", "default") + _set_peft_model_state_dict( + self.peft_model, + state_dict, + adapter_name=active_adapter, + ) + self.peft_model.set_adapter(active_adapter) + except Exception as exc: + raise RuntimeError(f"Failed to set LoRA weights in-place: {exc}") from exc + + try: + torch.cuda.synchronize() + except Exception: + pass + + async def load_optimizer_state(self, checkpoint_dir: str) -> None: + try: + await self.results_queue.join() + except Exception: + pass + try: + torch.cuda.synchronize() + except Exception: + pass + + optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt") + if os.path.exists(optimizer_path): + optimizer_state = torch.load(optimizer_path, map_location="cpu") + self.trainer.optimizer.load_state_dict(optimizer_state) + + def save_lora_adapter(self, lora_path: str) -> None: + self.trainer.save_model(lora_path) + + def save_optimizer_state(self, checkpoint_dir: str) -> None: + optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt") + torch.save(self.trainer.optimizer.state_dict(), optimizer_path) + + async def stop_background_training( + self, + *, + timeout_s: float = _TRAIN_TASK_SHUTDOWN_TIMEOUT_S, + ) -> None: + train_task = self.train_task + self.train_task = None + if train_task is None or train_task.done(): + return + + 
self.inputs_queue.put_nowait(_STOP_TRAIN_INPUT) + try: + await asyncio.wait_for(train_task, timeout=timeout_s) + except asyncio.TimeoutError: + train_task.cancel() + + def _canonicalize_upstream_metric_key(metric: str) -> str: if "/" in metric: return metric @@ -415,3 +663,255 @@ def _calculate_logprobs( def gc_and_empty_cuda_cache(n: int = 3) -> None: [gc.collect() >= 0 and torch.cuda.empty_cache() for _ in range(n)] + + +def create_unsloth_train_context( + *, + init_args: dict[str, Any], + peft_args: dict[str, Any], + trainer_args: dict[str, Any], + use_fast_model: bool = False, +) -> UnslothTrainContext: + import unsloth + + loader_cls = unsloth.FastModel if use_fast_model else unsloth.FastLanguageModel + model, tokenizer = cast( + tuple[CausalLM, PreTrainedTokenizerBase], + loader_cls.from_pretrained(**init_args), + ) + + if ( + hasattr(model, "peft_config") + and getattr(model, "peft_config", None) is not None + ): + peft_model = cast(peft.peft_model.PeftModelForCausalLM, model) + else: + peft_model = cast( + peft.peft_model.PeftModelForCausalLM, + loader_cls.get_peft_model(model, **peft_args), + ) + + if not hasattr(peft_model, "warnings_issued"): + peft_model.warnings_issued = {} # type: ignore[attr-defined] + + trainer = GRPOTrainer( + model=peft_model, # type: ignore[arg-type] + reward_funcs=[], + args=GRPOConfig(**trainer_args), + train_dataset=Dataset.from_list([{"prompt": ""} for _ in range(10_000_000)]), + processing_class=tokenizer, + ) + if trainer.optimizer is None: + trainer.create_optimizer() + + inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue() + results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue() + + def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]: + async def get_inputs() -> _TrainLoopInput: + return await inputs_queue.get() + + inputs = asyncio.run(get_inputs()) + if isinstance(inputs, _StopTrainInputs): + raise StopTrainingLoop() + return cast(dict[str, torch.Tensor], inputs) + + trainer._prepare_inputs = _async_prepare_inputs + + return UnslothTrainContext( + model=model, + tokenizer=tokenizer, + peft_model=peft_model, + trainer=trainer, + inputs_queue=inputs_queue, + results_queue=results_queue, + ) + + +def _get_trainer_optimizer(ctx: UnslothTrainContext) -> Optimizer: + optimizer = cast(Optimizer | None, getattr(ctx.trainer, "optimizer", None)) + if optimizer is None: + raise RuntimeError("Trainer optimizer must be initialized before training") + return optimizer + + +def _reset_optimizer_if_mode_changed( + ctx: UnslothTrainContext, + mode: Literal["sft", "rl"], +) -> None: + mode_changed = ctx.last_training_mode is not None and ctx.last_training_mode != mode + if mode_changed: + _get_trainer_optimizer(ctx).state.clear() + ctx.last_training_mode = mode + + +def _precalculate_new_logprobs( + ctx: UnslothTrainContext, + packed_tensors: PackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, +) -> torch.Tensor: + return torch.cat( + [ + ctx.trainer.compute_loss( + ctx.peft_model, + TrainInputs( # ty:ignore[missing-typed-dict-key] + **{ + key: value[offset : offset + 1] + for key, value in packed_tensors.items() + if isinstance(value, torch.Tensor) + }, + pixel_values=packed_tensors["pixel_values"][offset : offset + 1], + image_grid_thw=packed_tensors["image_grid_thw"][ + offset : offset + 1 + ], + config=config, + _config=_config, + return_new_logprobs=True, + ), + ) + for offset in range(0, packed_tensors["tokens"].shape[0]) + ] + ).to("cpu") + + +async def run_unsloth_rl_training( + ctx: 
UnslothTrainContext, + disk_packed_tensors: DiskPackedTensors, + config: types.TrainConfig, + _config: dev.TrainConfig, + verbose: bool = False, +) -> AsyncIterator[dict[str, float]]: + _reset_optimizer_if_mode_changed(ctx, "rl") + optimizer = _get_trainer_optimizer(ctx) + for param_group in optimizer.param_groups: + param_group["weight_decay"] = 0.1 + + packed_tensors = packed_tensors_from_dir(**disk_packed_tensors) + await ctx.results_queue.join() + + if ctx.train_task is None: + ctx.train_task = asyncio.create_task( + train( + trainer=ctx.trainer, + results_queue=ctx.results_queue, + ) + ) + + warmup = ctx.warmup_pending + precalculate_logprobs = _config.get("precalculate_logprobs", False) + + for offset in range(0, packed_tensors["tokens"].shape[0]): + for _ in range(2 if warmup else 1): + if precalculate_logprobs and not warmup: + packed_tensors["original_logprobs"] = packed_tensors["logprobs"] # type: ignore[index] + packed_tensors["logprobs"] = _precalculate_new_logprobs( + ctx, + packed_tensors, + config, + _config, + ) + precalculate_logprobs = False + + ctx.inputs_queue.put_nowait( + create_train_inputs(packed_tensors, offset, config, _config, warmup) + ) + + done, _ = await asyncio.wait( + [ + asyncio.create_task(ctx.results_queue.get()), + ctx.train_task, + ], + return_when=asyncio.FIRST_COMPLETED, + ) + if verbose: + print( + "Done waiting for a result from the queue or for the training task to, presumably, raise an exception" + ) + for task in done: + result = task.result() + assert result is not None, "The training task should never finish." + ctx.results_queue.task_done() + if warmup: + gc_and_empty_cuda_cache() + await asyncio.sleep(0.1) + warmup = False + ctx.warmup_pending = False + else: + yield result + + +async def run_unsloth_sft_training( + ctx: UnslothTrainContext, + batches: Iterable[SFTBatch], + verbose: bool = False, + *, + weight_decay: float = 0.0, + max_grad_norm: float = 1.0, +) -> AsyncIterator[dict[str, float]]: + _reset_optimizer_if_mode_changed(ctx, "sft") + optimizer = _get_trainer_optimizer(ctx) + + os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" + + for param_group in optimizer.param_groups: + param_group["weight_decay"] = weight_decay + + ctx.peft_model.train() + device = next(ctx.peft_model.parameters()).device + + for batch_idx, batch in enumerate(batches): + batch_start_time = time.perf_counter() + batch_loss = 0.0 + + for param_group in optimizer.param_groups: + param_group["lr"] = batch.learning_rate + + num_trainable_tokens = torch.tensor( + batch.num_trainable_tokens, + dtype=torch.long, + device=device, + ) + + for trajectory_tensor in batch.trajectory_tensors: + input_ids = trajectory_tensor["input_ids"].to(device) + attention_mask = trajectory_tensor["attention_mask"].to(device) + labels = trajectory_tensor["labels"].to(device) + + outputs = ctx.peft_model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + num_items_in_batch=num_trainable_tokens, + ) + loss = outputs.loss + loss.backward() + batch_loss += loss.item() + + grad_norm = torch.nn.utils.clip_grad_norm_( + ctx.peft_model.parameters(), + max_grad_norm, + ).item() + + optimizer.step() + optimizer.zero_grad() + + batch_time = time.perf_counter() - batch_start_time + tokens_per_second = ( + batch.num_trainable_tokens / batch_time if batch_time > 0 else 0.0 + ) + + if verbose: + print( + f"Batch {batch_idx}: loss={batch_loss:.4f}, lr={batch.learning_rate:.2e}, " + f"grad_norm={grad_norm:.4f}, tok/s={tokens_per_second:.1f}" + ) + + yield { + "loss": batch_loss, 
+ "learning_rate": batch.learning_rate, + "grad_norm": grad_norm, + "num_trajectories": float(batch.num_trajectories), + "num_trainable_tokens": float(batch.num_trainable_tokens), + "tokens_per_second": tokens_per_second, + } diff --git a/src/art/utils/convert_moe_lora.py b/src/art/utils/convert_moe_lora.py index 0ea80f63a..ff3e893ce 100644 --- a/src/art/utils/convert_moe_lora.py +++ b/src/art/utils/convert_moe_lora.py @@ -12,13 +12,15 @@ ... """ +import importlib import json import os import re -import safetensors.torch import torch +safetensors_torch = importlib.import_module("safetensors.torch") + def _has_fused_moe_lora(tensors: dict[str, torch.Tensor]) -> bool: """Check if the adapter contains fused MoE LoRA tensors.""" @@ -152,7 +154,7 @@ def convert_checkpoint_if_needed(checkpoint_dir: str) -> None: if not os.path.exists(adapter_path) or not os.path.exists(config_path): return - tensors = safetensors.torch.load_file(adapter_path) + tensors = safetensors_torch.load_file(adapter_path) if not _has_fused_moe_lora(tensors): return @@ -168,7 +170,7 @@ def convert_checkpoint_if_needed(checkpoint_dir: str) -> None: ) # Overwrite the adapter with the converted tensors - safetensors.torch.save_file(new_tensors, adapter_path) + safetensors_torch.save_file(new_tensors, adapter_path) # Update adapter_config.json target_modules adapter_config["target_modules"] = [ diff --git a/tests/integration/megatron_oracle_harness.py b/tests/integration/megatron_oracle_harness.py index db0f74ad6..ad3b552f1 100644 --- a/tests/integration/megatron_oracle_harness.py +++ b/tests/integration/megatron_oracle_harness.py @@ -742,8 +742,9 @@ def _load_output_tensor(topology_dir: Path, step: StepTrace): def _load_safetensor_map(path: Path) -> dict[str, Any]: """Loads one safetensor map from disk.""" - from safetensors.torch import load_file + import importlib + load_file = importlib.import_module("safetensors.torch").load_file return load_file(str(path)) diff --git a/tests/integration/megatron_oracle_worker.py b/tests/integration/megatron_oracle_worker.py index 0f135c5df..048777683 100644 --- a/tests/integration/megatron_oracle_worker.py +++ b/tests/integration/megatron_oracle_worker.py @@ -711,13 +711,18 @@ def _scaled_loss_fn(*args: Any, **kwargs: Any): def _worker_run(request: WorkerRunRequest) -> None: """Executes one full distributed training trace generation worker run.""" - from safetensors.torch import load_file, save_file + import importlib + import torch from art import dev, types from art.megatron import train as megatron_train from art.preprocessing.pack import packed_tensors_from_dir + safetensors_torch = importlib.import_module("safetensors.torch") + load_file = safetensors_torch.load_file + save_file = safetensors_torch.save_file + local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local_rank) torch.distributed.init_process_group(backend="nccl") # ty: ignore[possibly-missing-attribute] diff --git a/tests/unit/test_megatron_sft_batches.py b/tests/unit/test_megatron_sft_batches.py new file mode 100644 index 000000000..06177187c --- /dev/null +++ b/tests/unit/test_megatron_sft_batches.py @@ -0,0 +1,51 @@ +from pathlib import Path + +import torch + +from art.megatron.sft_batches import load_sft_batch_from_disk, materialize_sft_batches +from art.preprocessing.tokenize import SFTBatch + + +def test_materialize_and_load_sft_batches_round_trip(tmp_path: Path) -> None: + batches = [ + SFTBatch( + trajectory_tensors=[ + { + "input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64), + 
"attention_mask": torch.tensor([[1, 1, 1]], dtype=torch.int64), + "labels": torch.tensor([[-100, 2, 3]], dtype=torch.int64), + }, + { + "input_ids": torch.tensor([[4, 5]], dtype=torch.int64), + "attention_mask": torch.tensor([[1, 1]], dtype=torch.int64), + "labels": torch.tensor([[-100, 5]], dtype=torch.int64), + }, + ], + learning_rate=1e-4, + num_trajectories=2, + num_trainable_tokens=3, + ) + ] + + serialized = materialize_sft_batches( + batches, + sft_data_dir=str(tmp_path / "megatron-sft"), + ) + + assert serialized.num_batches == 1 + assert serialized.learning_rates == [1e-4] + + metadata, trajectories = load_sft_batch_from_disk( + str(Path(serialized.sft_data_dir) / "batch_000000") + ) + + assert metadata == { + "learning_rate": 1e-4, + "num_trajectories": 2, + "num_trainable_tokens": 3, + "num_trajectory_tensors": 2, + } + assert len(trajectories) == 2 + assert torch.equal(trajectories[0]["input_ids"], torch.tensor([1, 2, 3])) + assert torch.equal(trajectories[0]["labels"], torch.tensor([-100, 2, 3])) + assert torch.equal(trajectories[1]["attention_mask"], torch.tensor([1, 1]))