diff --git a/src/art/megatron/job_protocol.py b/src/art/megatron/job_protocol.py new file mode 100644 index 000000000..43458cb7b --- /dev/null +++ b/src/art/megatron/job_protocol.py @@ -0,0 +1,68 @@ +from typing import Annotated, Literal, TypeAlias + +from pydantic import BaseModel, Field, TypeAdapter + +from art import dev, types +from art.megatron.routing_replay import MoeRoutingReplayBundle +from art.preprocessing.pack import DiskPackedTensors + + +class MergedWeightTransferInitInfo(BaseModel): + master_address: str + master_port: int + rank_offset: int + world_size: int + + +class MergedWeightTransferSpec(BaseModel): + init_info: MergedWeightTransferInitInfo + vllm_base_url: str + served_model_name: str + + +class MegatronSyncJob(BaseModel): + kind: Literal["sync"] + lora_path: str + merged_weight_transfer: MergedWeightTransferSpec + + +class _MegatronTrainJobBase(BaseModel): + lora_path: str + optimizer_state_path: str + disk_packed_tensors: DiskPackedTensors + config: types.TrainConfig + experimental_config: dev.TrainConfig + moe_routing_replay_path: str | None = None + moe_routing_replay_strict: bool = True + + +class MegatronLoraTrainJob(_MegatronTrainJobBase): + kind: Literal["train_lora"] + + +class MegatronMergedTrainJob(_MegatronTrainJobBase): + kind: Literal["train_merged"] + merged_weight_transfer: MergedWeightTransferSpec + + +MegatronLoraTrainJob.model_rebuild( + force=True, + _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle}, +) +MegatronMergedTrainJob.model_rebuild( + force=True, + _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle}, +) + +MegatronJob: TypeAlias = Annotated[ + MegatronSyncJob | MegatronLoraTrainJob | MegatronMergedTrainJob, + Field(discriminator="kind"), +] + + +def dump_megatron_job(job: MegatronJob) -> str: + return TypeAdapter(MegatronJob).dump_json(job).decode() + + +def load_megatron_job(raw: str | bytes) -> MegatronJob: + return TypeAdapter(MegatronJob).validate_json(raw) diff --git 
a/src/art/megatron/provider.py b/src/art/megatron/provider.py index 7629d4272..5527ab8e0 100644 --- a/src/art/megatron/provider.py +++ b/src/art/megatron/provider.py @@ -80,6 +80,7 @@ def get_provider( ) ) provider = bridge.to_megatron_provider() + setattr(provider, "art_bridge", bridge) base_layer_spec = provider.transformer_layer_spec def _flex_attention_layer_spec( diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 5402a5d36..e53e91b5a 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -8,12 +8,12 @@ from pathlib import Path import shlex import shutil +import signal import subprocess import sys from typing import Any, AsyncIterator, Literal from peft.tuners.lora.config import LoraConfig -from pydantic import BaseModel from safetensors import safe_open from safetensors.torch import save_file import torch @@ -30,28 +30,19 @@ from ..unsloth.service import do_sleep, do_wake_up, gc_and_empty_cuda_cache from ..utils.convert_moe_lora import convert_checkpoint_if_needed from ..utils.get_model_step import get_step_from_dir +from ..utils.network import find_free_tcp_port from ..utils.output_dirs import get_step_checkpoint_dir from ..vllm import get_llm, openai_server_task, run_on_workers -from .routing_replay import MoeRoutingReplayBundle - - -class MegatronTrainingJob(BaseModel): - """Job format for communication with train.py""" - - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig - moe_routing_replay_path: str | None = None - moe_routing_replay_strict: bool = True - - -MegatronTrainingJob.model_rebuild( - force=True, _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle} +from .job_protocol import ( + MegatronJob, + MegatronLoraTrainJob, + MegatronMergedTrainJob, + MegatronSyncJob, + MergedWeightTransferInitInfo, + MergedWeightTransferSpec, + dump_megatron_job, ) - logger = logging.getLogger(__name__) @@ 
-70,6 +61,10 @@ class MegatronService: _vllm_log_file: Any = field(default=None, repr=False) _vllm_host: str = "127.0.0.1" _vllm_port: int = 0 + _merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = field( + default=None, + repr=False, + ) @property def is_dedicated(self) -> bool: @@ -110,14 +105,11 @@ def _adapter_has_weights(self, lora_path: str) -> bool: adapter_path = os.path.join(lora_path, "adapter_model.safetensors") if not os.path.exists(adapter_path): return False - try: - with safe_open(adapter_path, framework="pt") as adapter_file: - for key in adapter_file.keys(): - tensor = adapter_file.get_tensor(key) - if torch.any(tensor != 0): - return True - except Exception: - return False + with safe_open(adapter_path, framework="pt") as adapter_file: + for key in adapter_file.keys(): + tensor = adapter_file.get_tensor(key) + if torch.any(tensor != 0): + return True return False def _create_identity_lora(self, lora_path: str) -> None: @@ -181,20 +173,22 @@ def _ensure_identity_lora(self, lora_path: str) -> None: return self._create_identity_lora(lora_path) - def _ensure_lora_adapter_config( - self, lora_path: str, *, source_path: str | None = None - ) -> None: + def _ensure_lora_adapter_config(self, lora_path: str) -> None: config_path = os.path.join(lora_path, "adapter_config.json") if os.path.exists(config_path): return os.makedirs(lora_path, exist_ok=True) - if source_path is not None: - source_config = os.path.join(source_path, "adapter_config.json") - if os.path.exists(source_config): - shutil.copy(source_config, config_path) - return self._default_lora_adapter_config().save_pretrained(lora_path) + def _build_merged_weight_transfer_spec(self, step: int) -> MergedWeightTransferSpec: + init_info = self._merged_weight_transfer_init_info + assert init_info is not None + return MergedWeightTransferSpec( + init_info=init_info, + vllm_base_url=self._vllm_base_url, + served_model_name=f"{self.model_name}@{step}", + ) + def 
_resolve_active_lora_path(self) -> str: lora_path = get_last_checkpoint_dir(self.output_dir) if lora_path is None: @@ -202,10 +196,43 @@ def _resolve_active_lora_path(self) -> str: self._latest_step = 0 else: self._latest_step = get_step_from_dir(self.output_dir) - self._ensure_identity_lora(lora_path) + if self.rollout_weights_mode == "lora": + self._ensure_identity_lora(lora_path) self._ensure_lora_adapter_config(lora_path) return lora_path + async def _set_served_model_name(self, step: int) -> None: + import httpx + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/art/set_served_model_name", + json={"name": f"{self.model_name}@{step}"}, + timeout=30.0, + ) + response.raise_for_status() + self._latest_step = step + + async def _init_merged_weight_transfer(self) -> None: + import httpx + + if self._merged_weight_transfer_init_info is not None: + return + assert len(self.config["trainer_gpu_ids"]) == 1 + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self._vllm_base_url}/get_world_size", + timeout=30.0, + ) + response.raise_for_status() + inference_world_size = int(response.json()["world_size"]) + self._merged_weight_transfer_init_info = MergedWeightTransferInitInfo( + master_address="127.0.0.1", + master_port=find_free_tcp_port(), + rank_offset=1, + world_size=inference_world_size + 1, + ) + async def _start_vllm_subprocess( self, lora_path: str, @@ -213,6 +240,7 @@ async def _start_vllm_subprocess( config: dev.OpenAIServerConfig | None, ) -> tuple[str, int]: import atexit + import httpx inference_gpu_ids = self.config["inference_gpu_ids"] @@ -232,8 +260,13 @@ async def _start_vllm_subprocess( if config and "engine_args" in config: engine_args.update(dict(config["engine_args"])) engine_args.setdefault("generation_config", "vllm") - engine_args["enable_lora"] = True - engine_args.setdefault("max_loras", 2) + if self.rollout_weights_mode == "merged": + 
engine_args["weight_transfer_config"] = {"backend": "nccl"} + engine_args.pop("enable_lora", None) + engine_args.pop("max_loras", None) + else: + engine_args["enable_lora"] = True + engine_args.setdefault("max_loras", 2) for key in ("model", "served_model_name", "enable_sleep_mode"): engine_args.pop(key, None) @@ -313,6 +346,77 @@ async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: response.raise_for_status() self._latest_step = step + async def _sync_dedicated_merged_weights( + self, + *, + lora_path: str, + step: int, + ) -> None: + await self._ensure_megatron_running() + await self._init_merged_weight_transfer() + job_path = self._write_job( + MegatronSyncJob( + kind="sync", + lora_path=lora_path, + merged_weight_transfer=self._build_merged_weight_transfer_spec(step), + ) + ) + async for _ in self._stream_training_log( + job_path=job_path, lora_path=lora_path + ): + pass + self._latest_step = step + + def _write_job(self, job: MegatronJob) -> str: + jobs_dir = "/tmp/megatron_training_jobs" + os.makedirs(jobs_dir, exist_ok=True) + for job_name in os.listdir(jobs_dir): + if job_name.endswith(".json"): + os.remove(os.path.join(jobs_dir, job_name)) + log_path = "/tmp/megatron_training_log.jsonl" + if os.path.exists(log_path): + os.remove(log_path) + if ( + job.kind != "sync" + and job.experimental_config.get("moe_routing_replay_bundle") is not None + ): + raise RuntimeError( + "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " + "MegatronService subprocess jobs must use moe_routing_replay_path." 
+ ) + job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json") + with open(job_path, "w", encoding="utf-8") as handle: + handle.write(dump_megatron_job(job)) + return job_path + + async def _stream_training_log( + self, + *, + job_path: str, + lora_path: str, + ) -> AsyncIterator[dict[str, float]]: + log_path = "/tmp/megatron_training_log.jsonl" + num_lines = 0 + while True: + await asyncio.sleep(0.1) + if not os.path.exists(log_path): + assert os.path.exists(job_path) + continue + with open(log_path, "a+", encoding="utf-8") as log_file: + log_file.seek(0) + lines = log_file.readlines()[num_lines:] + for line in lines: + line = line.strip() + if not line: + continue + if line == "all done": + self._merge_lora_adapter(lora_path) + os.remove(log_path) + return + num_lines += 1 + yield json.loads(line) + assert os.path.exists(job_path) + def _stop_vllm_subprocess(self) -> None: if self._vllm_process is not None: self._vllm_process.terminate() @@ -325,12 +429,13 @@ def _stop_vllm_subprocess(self) -> None: if self._vllm_log_file is not None: self._vllm_log_file.close() self._vllm_log_file = None + self._merged_weight_transfer_init_info = None def _stop_megatron_process(self) -> None: if self._megatron_process is None: return if self._megatron_process.returncode is None: - self._megatron_process.terminate() + os.killpg(os.getpgid(self._megatron_process.pid), signal.SIGTERM) self._megatron_process = None async def _add_lora_aliases( @@ -349,8 +454,10 @@ async def _add_lora_aliases( async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: if self.is_dedicated: - assert self.rollout_weights_mode == "lora" - await self._reload_adapter(checkpoint_dir, step) + if self.rollout_weights_mode == "merged": + await self._set_served_model_name(step) + else: + await self._reload_adapter(checkpoint_dir, step) return llm = await self.llm await llm.pause_generation() @@ -394,6 +501,7 @@ async def _ensure_megatron_running(self) -> None: 
command, cwd=str(project_root), env=launch_env, + start_new_session=True, ) async def start_openai_server( @@ -402,9 +510,14 @@ async def start_openai_server( lora_path = self._resolve_active_lora_path() if self.is_dedicated: - assert self.rollout_weights_mode == "lora" port = (config or {}).get("server_args", {}).get("port", 8000) - return await self._start_vllm_subprocess(lora_path, port, config) + location = await self._start_vllm_subprocess(lora_path, port, config) + if self.rollout_weights_mode == "merged": + await self._sync_dedicated_merged_weights( + lora_path=lora_path, + step=self._latest_step, + ) + return location lora_path_for_server = ( lora_path if self._adapter_has_weights(lora_path) else None @@ -442,73 +555,61 @@ async def train( verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: if self.is_dedicated: - assert self.rollout_weights_mode == "lora" await self._ensure_megatron_running() lora_path = self._resolve_active_lora_path() self._optimizer_state_path = self._get_optimizer_state_path() - - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) - for job_name in os.listdir(jobs_dir): - if job_name.endswith(".json"): - os.remove(os.path.join(jobs_dir, job_name)) - if _config.get("moe_routing_replay_bundle") is not None: - raise RuntimeError( - "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " - "MegatronService subprocess jobs must use moe_routing_replay_path." 
+ next_step = self._latest_step + 1 + if self.rollout_weights_mode == "merged": + await self._init_merged_weight_transfer() + job: MegatronJob = MegatronMergedTrainJob( + kind="train_merged", + lora_path=lora_path, + optimizer_state_path=self._optimizer_state_path, + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=_config, + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", + True, + ), + merged_weight_transfer=self._build_merged_weight_transfer_spec( + next_step + ), + ) + else: + job = MegatronLoraTrainJob( + kind="train_lora", + lora_path=lora_path, + optimizer_state_path=self._optimizer_state_path, + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=_config, + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", + True, + ), ) - job = MegatronTrainingJob( + job_path = self._write_job(job) + async for result in self._stream_training_log( + job_path=job_path, lora_path=lora_path, - optimizer_state_path=self._optimizer_state_path, - disk_packed_tensors=disk_packed_tensors, - config=config, - experimental_config=_config, - moe_routing_replay_path=_config.get("moe_routing_replay_path"), - moe_routing_replay_strict=_config.get( - "moe_routing_replay_strict", True - ), - ) - job_path = os.path.join( - jobs_dir, f"{datetime.datetime.now().isoformat()}.json" - ) - with open(job_path, "w", encoding="utf-8") as handle: - handle.write(job.model_dump_json()) + ): + yield result - num_lines = 0 - while True: - await asyncio.sleep(0.1) - try: - with open( - "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" - ) as log_file: - log_file.seek(0) - lines = log_file.readlines()[num_lines:] - for line in lines: - line = line.strip() - if not line: - continue - if line == "all done": - self._merge_lora_adapter(lora_path) - 
os.remove("/tmp/megatron_training_log.jsonl") - break - num_lines += 1 - yield json.loads(line) - else: - continue - break - except FileNotFoundError: - continue - - next_step = self._latest_step + 1 new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) os.makedirs(new_checkpoint_dir, exist_ok=True) shutil.copy( f"{lora_path}/adapter_model.safetensors", f"{new_checkpoint_dir}/adapter_model.safetensors", ) - self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) - await self._reload_adapter(new_checkpoint_dir, next_step) + self._ensure_lora_adapter_config(new_checkpoint_dir) + if self.rollout_weights_mode == "merged": + self._latest_step = next_step + else: + await self._reload_adapter(new_checkpoint_dir, next_step) return llm = await self.llm @@ -530,17 +631,8 @@ async def train( self._optimizer_state_path = self._get_optimizer_state_path() - jobs_dir = "/tmp/megatron_training_jobs" - os.makedirs(jobs_dir, exist_ok=True) - for job_name in os.listdir(jobs_dir): - if job_name.endswith(".json"): - os.remove(os.path.join(jobs_dir, job_name)) - if _config.get("moe_routing_replay_bundle") is not None: - raise RuntimeError( - "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " - "MegatronService subprocess jobs must use moe_routing_replay_path." 
- ) - job = MegatronTrainingJob( + job = MegatronLoraTrainJob( + kind="train_lora", lora_path=lora_path, optimizer_state_path=self._optimizer_state_path, disk_packed_tensors=disk_packed_tensors, @@ -549,30 +641,12 @@ async def train( moe_routing_replay_path=_config.get("moe_routing_replay_path"), moe_routing_replay_strict=_config.get("moe_routing_replay_strict", True), ) - job_path = os.path.join(jobs_dir, f"{datetime.datetime.now().isoformat()}.json") - with open(job_path, "w") as f: - f.write(job.model_dump_json()) - - num_lines = 0 - while True: - await asyncio.sleep(0.1) - try: - with open("/tmp/megatron_training_log.jsonl", "a+") as log_file: - log_file.seek(0) - lines = log_file.readlines()[num_lines:] - for line in lines: - if line := line.strip(): - if line == "all done": - self._merge_lora_adapter(lora_path) - os.remove("/tmp/megatron_training_log.jsonl") - break - num_lines += 1 - yield json.loads(line) - else: - continue - break - except FileNotFoundError: - continue + job_path = self._write_job(job) + async for result in self._stream_training_log( + job_path=job_path, + lora_path=lora_path, + ): + yield result next_step = self._latest_step + 1 new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) @@ -581,7 +655,7 @@ async def train( f"{lora_path}/adapter_model.safetensors", f"{new_checkpoint_dir}/adapter_model.safetensors", ) - self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) + self._ensure_lora_adapter_config(new_checkpoint_dir) wake_lock_path = "/tmp/megatron_vllm_waking" try: diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index cc86c126f..d3fbf1f1e 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -15,9 +15,11 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: _set_cache_dir("TRITON_CACHE_DIR", "~/.triton/cache") # isort: on +from concurrent.futures import ThreadPoolExecutor import gc import json import math +import re import shutil import time from 
typing import Any, Callable, cast @@ -27,6 +29,7 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_layer import TransformerLayer from pydantic import BaseModel, ConfigDict from safetensors.torch import load_file, save_file import torch @@ -36,7 +39,20 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: from art.loss import loss_fn, shift_tensor from art.megatron.finalize_grads import finalize_model_grads_extended from art.megatron.flex_attention import create_shared_prefix_attention_state -from art.megatron.lora import apply_lora_adapters +from art.megatron.job_protocol import ( + MegatronMergedTrainJob, + MergedWeightTransferInitInfo, + MergedWeightTransferSpec, + load_megatron_job, +) +from art.megatron.lora import ( + LoRA, + MLPExpertsLinearFC1LoRA, + MLPExpertsLinearFC2LoRA, + SelfAttentionLinearProjLoRA, + SelfAttentionLinearQKVLoRA, + apply_lora_adapters, +) from art.megatron.offload import ( OffloadState, clear_optimizer_state, @@ -57,22 +73,6 @@ def _set_cache_dir(env_var: str, default_path: str) -> None: DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507" -class TrainingJob(BaseModel): - lora_path: str - optimizer_state_path: str - disk_packed_tensors: DiskPackedTensors - config: types.TrainConfig - experimental_config: dev.TrainConfig - moe_routing_replay_path: str | None = None - moe_routing_replay_strict: bool = True - - -TrainingJob.model_rebuild( - force=True, - _types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle}, -) - - class TrainingRuntime(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -82,6 +82,8 @@ class TrainingRuntime(BaseModel): rank: int world_size: int moe_routing_replay_controller: MoeRoutingReplayController | None = None + 
merged_weight_transfer_group: Any | None = None + merged_weight_transfer_init_info: MergedWeightTransferInitInfo | None = None class TrainStepResult(BaseModel): @@ -305,6 +307,18 @@ def iter_modules(model_chunks: list[MegatronModule]) -> Any: yield module +def iter_named_modules(model_chunks: list[MegatronModule]) -> Any: + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + yield module_name, module + + +def _is_language_transformer_layer_name(module_name: str) -> bool: + while module_name.startswith("module."): + module_name = module_name.removeprefix("module.") + return module_name.startswith(("decoder.layers.", "language_model.decoder.layers.")) + + def load_adapter_into_model( model_chunks: list[MegatronModule], adapter_model: dict[str, torch.Tensor], @@ -320,6 +334,22 @@ def load_adapter_into_model( optimizer.reload_model_params() +def maybe_load_adapter_into_model( + model_chunks: list[MegatronModule], + adapter_model_path: str, + optimizer: Any | None = None, + *, + rank: int, +) -> dict[str, torch.Tensor]: + if not os.path.exists(adapter_model_path): + print0(rank, "No adapter model found at", adapter_model_path) + return {} + print0(rank, "Loading adapter model from", adapter_model_path) + adapter_model = load_file(adapter_model_path) + load_adapter_into_model(model_chunks, adapter_model, optimizer) + return adapter_model + + def collect_sharded_lora_state( model_chunks: list[MegatronModule], adapter_model: dict[str, torch.Tensor], @@ -582,6 +612,397 @@ def run_training_step( ) +def _is_art_adapter_param_name(name: str) -> bool: + return any( + segment in name + for segment in ( + ".lora.", + ".q_proj_lora.", + ".k_proj_lora.", + ".v_proj_lora.", + ".gate_lora.", + ".up_lora.", + ) + ) + + +def _unwrap_art_wrapper_name(name: str) -> str: + while name.startswith("module."): + name = name[len("module.") :] + for wrapped, unwrapped in ( + (".linear_proj.linear_proj.", ".linear_proj."), + (".linear_qkv.linear_qkv.", 
".linear_qkv."), + (".linear_fc1.linear_fc1.", ".linear_fc1."), + (".linear_fc2.linear_fc2.", ".linear_fc2."), + ): + name = name.replace(wrapped, unwrapped) + return name + + +def _mapping_hf_weights_exist(mapping: Any, hf_keys: set[str]) -> bool: + if getattr(mapping, "allow_hf_name_mismatch", False): + return True + hf_param = mapping.hf_param + if isinstance(hf_param, str): + return hf_param in hf_keys + assert isinstance(hf_param, dict) + return all(param in hf_keys for param in hf_param.values()) + + +def _lora_delta(lora: LoRA, expert_idx: int | None = None) -> torch.Tensor: + if lora.A_T.ndim == 3: + assert expert_idx is not None + a_t = lora.A_T[expert_idx] + b_t = lora.B_T[expert_idx] + else: + a_t = lora.A_T + b_t = lora.B_T + return (b_t.T @ a_t.T) * lora.scale + + +def _expert_index_from_hf_name(hf_name: str) -> int: + match = re.search(r"\.experts\.(\d+)\.", hf_name) + assert match is not None + return int(match.group(1)) + + +def _hf_name_has_indexed_expert(hf_name: str) -> bool: + return re.search(r"\.experts\.(\d+)\.", hf_name) is not None + + +def _stack_moe_fc1_deltas(handler: MLPExpertsLinearFC1LoRA) -> torch.Tensor: + return torch.stack( + [ + torch.cat( + [ + _lora_delta(handler.gate_lora, expert_idx), + _lora_delta(handler.up_lora, expert_idx), + ], + dim=0, + ) + for expert_idx in range(handler.gate_lora.num_local_experts) + ], + dim=0, + ) + + +def _stack_moe_fc2_deltas(handler: MLPExpertsLinearFC2LoRA) -> torch.Tensor: + return torch.stack( + [ + _lora_delta(handler.lora, expert_idx) + for expert_idx in range(handler.lora.num_local_experts) + ], + dim=0, + ) + + +def _merge_delta_into_weight( + hf_name: str, + base_weight: torch.Tensor, + delta: torch.Tensor, +) -> torch.Tensor: + delta = delta.to(device=base_weight.device, dtype=base_weight.dtype) + if tuple(base_weight.shape) == tuple(delta.shape): + return base_weight + delta + transposed = delta.transpose(-1, -2) + assert tuple(base_weight.shape) == tuple(transposed.shape), ( + 
f"{hf_name}: cannot merge delta {tuple(delta.shape)} into {tuple(base_weight.shape)}" + ) + return base_weight + transposed + + +def _build_art_merge_handlers( + model_chunks: list[MegatronModule], +) -> tuple[dict[str, Any], dict[str, Any]]: + exact_handlers: dict[str, Any] = {} + prefix_handlers: dict[str, Any] = {} + for module_name, module in iter_named_modules(model_chunks): + if not isinstance(module, TransformerLayer): + continue + if not _is_language_transformer_layer_name(module_name): + continue + prefix = f"language_model.decoder.layers.{module.layer_number - 1}" + linear_proj = getattr(module.self_attention, "linear_proj", None) + if isinstance(linear_proj, SelfAttentionLinearProjLoRA): + exact_handlers[f"{prefix}.self_attention.linear_proj.weight"] = linear_proj + linear_qkv = getattr(module.self_attention, "linear_qkv", None) + if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): + exact_handlers[f"{prefix}.self_attention.linear_qkv.weight"] = linear_qkv + experts = getattr(module.mlp, "experts", None) + if experts is None: + continue + if isinstance(experts.linear_fc1, MLPExpertsLinearFC1LoRA): + prefix_handlers[f"{prefix}.mlp.experts.linear_fc1.weight"] = ( + experts.linear_fc1 + ) + if isinstance(experts.linear_fc2, MLPExpertsLinearFC2LoRA): + prefix_handlers[f"{prefix}.mlp.experts.linear_fc2.weight"] = ( + experts.linear_fc2 + ) + return exact_handlers, prefix_handlers + + +def _merge_art_lora_into_hf_weights( + global_param_name: str, + converted_weights_dict: dict[str, torch.Tensor], + *, + exact_handlers: dict[str, Any], + prefix_handlers: dict[str, Any], +) -> dict[str, torch.Tensor]: + handler = exact_handlers.get(global_param_name) + if handler is None: + for prefix, prefix_handler in prefix_handlers.items(): + if global_param_name.startswith(prefix): + handler = prefix_handler + break + if handler is None: + return converted_weights_dict + if isinstance(handler, SelfAttentionLinearProjLoRA): + hf_name, base_weight = 
next(iter(converted_weights_dict.items())) + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + _lora_delta(handler.lora), + ) + return converted_weights_dict + if isinstance(handler, SelfAttentionLinearQKVLoRA): + deltas = { + "q_proj": _lora_delta(handler.q_proj_lora), + "k_proj": _lora_delta(handler.k_proj_lora), + "v_proj": _lora_delta(handler.v_proj_lora), + } + for hf_name, base_weight in list(converted_weights_dict.items()): + for projection, delta in deltas.items(): + if projection in hf_name: + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + delta, + ) + break + return converted_weights_dict + if isinstance(handler, MLPExpertsLinearFC1LoRA): + for hf_name, base_weight in list(converted_weights_dict.items()): + delta = ( + torch.cat( + [ + _lora_delta( + handler.gate_lora, _expert_index_from_hf_name(hf_name) + ), + _lora_delta( + handler.up_lora, _expert_index_from_hf_name(hf_name) + ), + ], + dim=0, + ) + if _hf_name_has_indexed_expert(hf_name) + else _stack_moe_fc1_deltas(handler) + ) + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + delta, + ) + return converted_weights_dict + assert isinstance(handler, MLPExpertsLinearFC2LoRA) + for hf_name, base_weight in list(converted_weights_dict.items()): + delta = ( + _lora_delta(handler.lora, _expert_index_from_hf_name(hf_name)) + if _hf_name_has_indexed_expert(hf_name) + else _stack_moe_fc2_deltas(handler) + ) + converted_weights_dict[hf_name] = _merge_delta_into_weight( + hf_name, + base_weight, + delta, + ) + return converted_weights_dict + + +def _build_art_conversion_tasks(runtime: TrainingRuntime) -> list[Any]: + from itertools import chain + + from megatron.bridge.models.conversion.model_bridge import ( + WeightConversionTask, + _megatron_local_name_to_global, + ) + from megatron.bridge.models.conversion.utils import ( + get_module_and_param_from_name, + persistent_buffers, + ) + + bridge = 
getattr(runtime.provider, "art_bridge", None) + assert bridge is not None + mapping_registry = bridge._model_bridge.mapping_registry() + hf_source = bridge.hf_pretrained.state.source + hf_keys = set(hf_source.get_all_keys()) + model_config = runtime.model[0].config + tasks: list[Any] = [] + for vp_stage, model in enumerate(runtime.model): + for local_name, _ in chain(model.named_parameters(), persistent_buffers(model)): + if "_extra_state" in local_name or _is_art_adapter_param_name(local_name): + continue + global_name = _megatron_local_name_to_global( + runtime.model, + model_config, + _unwrap_art_wrapper_name(local_name), + vp_stage, + ) + mapping = mapping_registry.megatron_to_hf_lookup(global_name) + if mapping is None or not _mapping_hf_weights_exist(mapping, hf_keys): + continue + local_module, local_weights = get_module_and_param_from_name( + runtime.model, + local_name, + vp_stage, + ) + if local_module is not None and not hasattr(local_module, "config"): + setattr(local_module, "config", model_config) + tasks.append( + WeightConversionTask( + pp_rank=0, + vp_stage=vp_stage, + param_name=local_name, + global_param_name=global_name, + megatron_module=local_module, + param_weight=local_weights, + mapping=mapping, + ) + ) + return tasks + + +def _iter_merged_vllm_weights(runtime: TrainingRuntime) -> Any: + # vLLM expects HF checkpoint names, but Megatron only has live trainer weights. + # Convert through Bridge here, then merge ART's LoRA deltas into those tensors. 
+ bridge = getattr(runtime.provider, "art_bridge", None) + assert bridge is not None + model_bridge = bridge._model_bridge + hf_state_dict = bridge.hf_pretrained.state + exact_handlers, prefix_handlers = _build_art_merge_handlers(runtime.model) + for task in _build_art_conversion_tasks(runtime): + converted_weights_dict = task.mapping.megatron_to_hf( + task.param_weight, + task.megatron_module, + ) + converted_weights_dict = model_bridge.maybe_modify_converted_hf_weight( + task, + converted_weights_dict, + hf_state_dict, + ) + converted_weights_dict = _merge_art_lora_into_hf_weights( + task.global_param_name, + converted_weights_dict, + exact_handlers=exact_handlers, + prefix_handlers=prefix_handlers, + ) + for hf_name, tensor in converted_weights_dict.items(): + yield hf_name, tensor + + +def _ensure_merged_weight_transfer_group( + runtime: TrainingRuntime, + spec: MergedWeightTransferSpec, +) -> None: + assert runtime.rank == 0 + assert runtime.world_size == 1 + if runtime.merged_weight_transfer_init_info == spec.init_info: + assert runtime.merged_weight_transfer_group is not None + return + import httpx + from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine + + def _remote_init() -> None: + response = httpx.post( + f"{spec.vllm_base_url}/init_weight_transfer_engine", + json={"init_info": spec.init_info.model_dump()}, + timeout=300.0, + ) + response.raise_for_status() + + with ThreadPoolExecutor(max_workers=1) as executor: + remote_future = executor.submit(_remote_init) + time.sleep(1.0) + runtime.merged_weight_transfer_group = NCCLWeightTransferEngine.trainer_init( + { + "master_address": spec.init_info.master_address, + "master_port": spec.init_info.master_port, + "world_size": spec.init_info.world_size, + } + ) + remote_future.result() + runtime.merged_weight_transfer_init_info = spec.init_info + + +def _sync_merged_weights_to_vllm( + runtime: TrainingRuntime, + spec: MergedWeightTransferSpec, + *, + pause_generation: bool, +) -> 
None: + assert runtime.rank == 0 + assert runtime.world_size == 1 + + import httpx + from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine + + _ensure_merged_weight_transfer_group(runtime, spec) + + def _send_weights() -> None: + NCCLWeightTransferEngine.trainer_send_weights( + _iter_merged_vllm_weights(runtime), + {"group": runtime.merged_weight_transfer_group}, + ) + + with httpx.Client() as client: + if pause_generation: + response = client.post( + f"{spec.vllm_base_url}/pause", + params={"mode": "wait"}, + timeout=300.0, + ) + response.raise_for_status() + try: + torch.cuda.synchronize() + names: list[str] = [] + dtype_names: list[str] = [] + shapes: list[list[int]] = [] + for name, tensor in _iter_merged_vllm_weights(runtime): + names.append(name) + dtype_names.append(str(tensor.dtype).removeprefix("torch.")) + shapes.append(list(tensor.shape)) + with ThreadPoolExecutor(max_workers=1) as executor: + send_future = executor.submit(_send_weights) + response = client.post( + f"{spec.vllm_base_url}/update_weights", + json={ + "update_info": { + "names": names, + "dtype_names": dtype_names, + "shapes": shapes, + "is_checkpoint_format": True, + } + }, + timeout=600.0, + ) + response.raise_for_status() + send_future.result() + response = client.post( + f"{spec.vllm_base_url}/art/set_served_model_name", + json={"name": spec.served_model_name}, + timeout=30.0, + ) + response.raise_for_status() + torch.cuda.synchronize() + finally: + if pause_generation: + response = client.post( + f"{spec.vllm_base_url}/resume", + timeout=30.0, + ) + response.raise_for_status() + + def _run_service_loop(runtime: TrainingRuntime) -> None: offload_state = OffloadState() offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) @@ -606,61 +1027,70 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: job_name = job_names[0] job_path = os.path.join(jobs_dir, job_name) with open(job_path, "rb") as handle: - job = 
TrainingJob.model_validate_json(handle.read()) - config = job.config - experimental_config = job.experimental_config - - configure_moe_routing_replay( - runtime, - replay_bundle_path=job.moe_routing_replay_path, - strict=job.moe_routing_replay_strict, - ) + job = load_megatron_job(handle.read()) + if job.kind != "sync": + config = job.config + experimental_config = job.experimental_config + configure_moe_routing_replay( + runtime, + replay_bundle_path=job.moe_routing_replay_path, + strict=job.moe_routing_replay_strict, + ) print0(runtime.rank, "Loaded job from", job_path) print0(runtime.rank, "Job:", job) adapter_model_path = f"{job.lora_path}/adapter_model.safetensors" - if not os.path.exists(adapter_model_path): - raise FileNotFoundError(f"No adapter model found at {adapter_model_path}") - print0(runtime.rank, "Loading adapter model from", adapter_model_path) - adapter_model = load_file(adapter_model_path) - load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer) - - optimizer_shard_path = os.path.join( - job.optimizer_state_path, - f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt", + adapter_model = maybe_load_adapter_into_model( + runtime.model, + adapter_model_path, + runtime.optimizer, + rank=runtime.rank, ) - if os.path.exists(optimizer_shard_path): - print("Loading optimizer state from", optimizer_shard_path) - runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path)) + + if job.kind == "sync": + _sync_merged_weights_to_vllm( + runtime, + job.merged_weight_transfer, + pause_generation=False, + ) else: - print( - "No optimizer state found at", - optimizer_shard_path, - "- resetting optimizer for new run", + optimizer_shard_path = os.path.join( + job.optimizer_state_path, + f"{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.pt", ) - clear_optimizer_state(runtime.optimizer) - runtime.optimizer.reload_model_params() + if os.path.exists(optimizer_shard_path): + print("Loading optimizer state from", optimizer_shard_path) + 
runtime.optimizer.load_state_dict(torch.load(optimizer_shard_path)) + else: + print( + "No optimizer state found at", + optimizer_shard_path, + "- resetting optimizer for new run", + ) + clear_optimizer_state(runtime.optimizer) + runtime.optimizer.reload_model_params() - print0( - runtime.rank, "Loading packed tensors from", job.disk_packed_tensors["dir"] - ) - packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) - template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0)) - zero_template = _zero_contribution_inputs(template) - num_sequences = job.disk_packed_tensors["num_sequences"] - global_grad_accumulation_sequences = config.grad_accumulation_sequences - num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences) - for step_index in range(num_steps): - micro_indices = build_micro_sample_indices( - step_index=step_index, - num_sequences=num_sequences, - global_grad_accumulation_sequences=global_grad_accumulation_sequences, - ) - micro_inputs = select_micro_inputs( - packed_tensors, micro_indices, zero_template + print0( + runtime.rank, + "Loading packed tensors from", + job.disk_packed_tensors["dir"], ) - try: + packed_tensors = packed_tensors_from_dir(**job.disk_packed_tensors) + template = _clone_packed_tensors(select_indexed_inputs(packed_tensors, 0)) + zero_template = _zero_contribution_inputs(template) + num_sequences = job.disk_packed_tensors["num_sequences"] + global_grad_accumulation_sequences = config.grad_accumulation_sequences + num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences) + for step_index in range(num_steps): + micro_indices = build_micro_sample_indices( + step_index=step_index, + num_sequences=num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + micro_inputs = select_micro_inputs( + packed_tensors, micro_indices, zero_template + ) step_result = run_training_step( model_chunks=runtime.model, optimizer=runtime.optimizer, @@ -673,58 
+1103,64 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: sample_index=micro_indices, moe_routing_replay_controller=runtime.moe_routing_replay_controller, ) - except Exception: - raise - print0( - runtime.rank, - "Correlation between old and new probabilities:", - step_result.probs_corr, - ) - - if runtime.rank == 0: - with open( - "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" - ) as log_file: - log_msg = json.dumps( - { - "loss": step_result.reduced_loss.item(), - "grad_norm": step_result.grad_norm, - "probs_corr": step_result.probs_corr, - } - ) - print("Logging", log_msg) - log_file.write(log_msg + "\n") - - sharded_state_dict, sharded_state_manifest = collect_sharded_lora_state( - runtime.model, - adapter_model, - ) - shard_path = os.path.join( - job.lora_path, - f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors", - ) - manifest_path = os.path.join( - job.lora_path, - f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json", - ) - print("Saving adapter shard to", shard_path) - save_file(sharded_state_dict, shard_path) - print("Saving adapter shard manifest to", manifest_path) - with open(manifest_path, "w", encoding="utf-8") as manifest_file: - json.dump(sharded_state_manifest, manifest_file, sort_keys=True) + print0( + runtime.rank, + "Correlation between old and new probabilities:", + step_result.probs_corr, + ) - print("Saving optimizer shard to", optimizer_shard_path) - os.makedirs(job.optimizer_state_path, exist_ok=True) - torch.save(runtime.optimizer.state_dict(), optimizer_shard_path) + if runtime.rank == 0: + with open( + "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" + ) as log_file: + log_msg = json.dumps( + { + "loss": step_result.reduced_loss.item(), + "grad_norm": step_result.grad_norm, + "probs_corr": step_result.probs_corr, + } + ) + print("Logging", log_msg) + log_file.write(log_msg + "\n") + + sharded_state_dict, sharded_state_manifest = 
collect_sharded_lora_state( + runtime.model, + adapter_model, + ) + shard_path = os.path.join( + job.lora_path, + f"adapter_model-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.safetensors", + ) + manifest_path = os.path.join( + job.lora_path, + f"adapter_manifest-{runtime.rank + 1:02d}-of-{runtime.world_size:02d}.json", + ) + print("Saving adapter shard to", shard_path) + save_file(sharded_state_dict, shard_path) + print("Saving adapter shard manifest to", manifest_path) + with open(manifest_path, "w", encoding="utf-8") as manifest_file: + json.dump(sharded_state_manifest, manifest_file, sort_keys=True) + + print("Saving optimizer shard to", optimizer_shard_path) + os.makedirs(job.optimizer_state_path, exist_ok=True) + torch.save(runtime.optimizer.state_dict(), optimizer_shard_path) + + if isinstance(job, MegatronMergedTrainJob): + _sync_merged_weights_to_vllm( + runtime, + job.merged_weight_transfer, + pause_generation=True, + ) offload_to_cpu(runtime.model, runtime.optimizer, runtime.rank, offload_state) - del packed_tensors - del template - del zero_template + if job.kind != "sync": + del packed_tensors + del template + del zero_template + if "micro_inputs" in locals(): + del micro_inputs del adapter_model - if "micro_inputs" in locals(): - del micro_inputs gc.collect() torch.cuda.empty_cache() @@ -735,7 +1171,8 @@ def _run_service_loop(runtime: TrainingRuntime) -> None: "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" ) as log_file: log_file.write("all done\n") - shutil.rmtree(job.disk_packed_tensors["dir"]) + if job.kind != "sync": + shutil.rmtree(job.disk_packed_tensors["dir"]) def main() -> None: diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py index cb55ce18a..d139deea0 100644 --- a/src/art/unsloth/service.py +++ b/src/art/unsloth/service.py @@ -6,7 +6,6 @@ import json import logging import os -import socket import subprocess import sys from typing import TYPE_CHECKING, Any, AsyncIterator, Literal, Protocol, cast @@ -34,6 
def find_free_tcp_port(host: str = "127.0.0.1") -> int:
    """Return a TCP port number that was unused on ``host`` when probed.

    An IPv4 stream socket is bound to port ``0`` so the operating system
    picks an unused port; the socket is closed again before returning.
    Because the port is released on return, another process can claim it
    before the caller binds it, so callers should use the port promptly.

    Args:
        host: IPv4 address to probe. Defaults to the loopback address,
            which is what this helper always used before the parameter
            existed.

    Returns:
        The port number chosen by the operating system.

    Raises:
        OSError: If ``host`` cannot be bound (for example, the address is
            not configured on this machine).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Port 0 asks the OS to choose any currently unused port.
        sock.bind((host, 0))
        # For AF_INET sockets getsockname() returns (address, port).
        return cast(int, sock.getsockname()[1])
2f9c08c99..ffb378a71 100644 --- a/tests/unit/test_megatron_dedicated.py +++ b/tests/unit/test_megatron_dedicated.py @@ -13,6 +13,10 @@ from art import TrainableModel, types from art.dev.model import InternalModelConfig from art.megatron.backend import MegatronBackend +from art.megatron.job_protocol import ( + MegatronMergedTrainJob, + MergedWeightTransferInitInfo, +) from art.megatron.service import MegatronService @@ -61,7 +65,9 @@ def __init__( monkeypatch.setattr( "art.megatron.backend.move_to_child_process", lambda *args, **kwargs: (_ for _ in ()).throw( - AssertionError("Dedicated Megatron service should not move to a child process") + AssertionError( + "Dedicated Megatron service should not move to a child process" + ) ), ) monkeypatch.setattr("art.megatron.service.MegatronService", FakeService) @@ -95,16 +101,20 @@ async def test_megatron_service_ensure_megatron_running_uses_trainer_gpus( seen: dict[str, Any] = {} - monkeypatch.setattr("art.megatron.service.subprocess.run", lambda *args, **kwargs: None) + monkeypatch.setattr( + "art.megatron.service.subprocess.run", lambda *args, **kwargs: None + ) async def fake_create_subprocess_shell( command: str, cwd: str, env: dict[str, str], + start_new_session: bool, ) -> Any: seen["command"] = command seen["cwd"] = cwd seen["env"] = env + seen["start_new_session"] = start_new_session return pytypes.SimpleNamespace(returncode=None) monkeypatch.setattr( @@ -119,6 +129,7 @@ async def fake_create_subprocess_shell( assert "--nproc_per_node 2" in seen["command"] assert seen["env"]["CUDA_VISIBLE_DEVICES"] == "0,1" assert seen["env"]["MODEL_IDENTIFIER"] == "Qwen/Qwen3-30B-A3B-Instruct-2507" + assert seen["start_new_session"] is True @pytest.mark.asyncio @@ -145,7 +156,7 @@ async def test_megatron_service_start_openai_server_dedicated_starts_subprocess( lambda _output_dir: str(checkpoint_dir), ) monkeypatch.setattr(service, "_ensure_identity_lora", lambda _path: None) - monkeypatch.setattr(service, 
@pytest.mark.asyncio
async def test_megatron_service_start_openai_server_merged_syncs_step_zero(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Starting the server in merged mode syncs the step-0 checkpoint once.

    The vLLM subprocess start and the merged-weight sync are replaced by
    fakes; the test checks the returned location and that the sync was
    requested for the latest checkpoint directory with ``step=0``.
    """
    step_zero_dir = tmp_path / "checkpoints" / "0000"
    step_zero_dir.mkdir(parents=True)

    merged_config = InternalModelConfig(
        trainer_gpu_ids=[0],
        inference_gpu_ids=[1],
        rollout_weights_mode="merged",
    )
    service = MegatronService(
        model_name="megatron-merged",
        base_model="Qwen/Qwen3-30B-A3B-Instruct-2507",
        config=merged_config,
        output_dir=str(tmp_path),
    )
    calls: list[tuple[str, int]] = []

    def fake_last_checkpoint_dir(_output_dir):
        return str(step_zero_dir)

    def fake_ensure_adapter_config(_path):
        return None

    def fake_start_vllm(lora_path, port, config):
        # Hands back an awaitable resolving to the (host, port) location.
        return asyncio.sleep(0, result=("127.0.0.1", port))

    def fake_sync(*, lora_path, step):
        # Record at call time, then hand back an awaitable like the original.
        calls.append((lora_path, step))
        return asyncio.sleep(0)

    monkeypatch.setattr(
        "art.megatron.service.get_last_checkpoint_dir",
        fake_last_checkpoint_dir,
    )
    monkeypatch.setattr(
        service, "_ensure_lora_adapter_config", fake_ensure_adapter_config
    )
    monkeypatch.setattr(service, "_start_vllm_subprocess", fake_start_vllm)
    monkeypatch.setattr(service, "_sync_dedicated_merged_weights", fake_sync)

    server_config = {"server_args": {"port": 8123}}
    location = await service.start_openai_server(server_config)

    assert location == ("127.0.0.1", 8123)
    assert calls == [(str(step_zero_dir), 0)]
@pytest.mark.asyncio
async def test_megatron_service_register_lora_for_step_merged_sets_served_name(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """In merged mode, registering a step only updates the served model name.

    ``_set_served_model_name`` is replaced by a fake that records the step it
    was called with; the test expects exactly one call, for step 3.
    """
    service = MegatronService(
        model_name="megatron-merged",
        base_model="Qwen/Qwen3-30B-A3B-Instruct-2507",
        config=InternalModelConfig(
            trainer_gpu_ids=[0],
            inference_gpu_ids=[1],
            rollout_weights_mode="merged",
        ),
        output_dir=str(tmp_path),
    )
    calls: list[int] = []

    monkeypatch.setattr(
        service,
        "_set_served_model_name",
        # Record the step at call time and return an awaitable.
        lambda step: calls.append(step) or asyncio.sleep(0),
    )

    await service.register_lora_for_step(3, "/tmp/checkpoints/3")

    assert calls == [3]


@pytest.mark.asyncio
async def test_megatron_service_write_job_uses_merged_job_kind(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A merged train job written by ``_write_job`` round-trips from disk.

    The job is written through the service, read back from the returned
    path, and checked for the ``train_merged`` kind and the served model
    name built for step 1. The written file is removed afterwards so the
    test does not leave a job file behind.
    """
    service = MegatronService(
        model_name="megatron-merged",
        base_model="Qwen/Qwen3-30B-A3B-Instruct-2507",
        config=InternalModelConfig(
            trainer_gpu_ids=[0],
            inference_gpu_ids=[1],
            rollout_weights_mode="merged",
        ),
        output_dir=str(tmp_path),
    )
    # Start from a clean slate: drop job files and the training log left by
    # earlier runs at these fixed paths.
    job_dir = Path("/tmp/megatron_training_jobs")
    log_path = Path("/tmp/megatron_training_log.jsonl")
    if job_dir.exists():
        for path in job_dir.glob("*.json"):
            path.unlink()
    if log_path.exists():
        log_path.unlink()

    service._merged_weight_transfer_init_info = MergedWeightTransferInitInfo(
        master_address="127.0.0.1",
        master_port=1234,
        rank_offset=1,
        world_size=2,
    )

    job_path = service._write_job(
        MegatronMergedTrainJob(
            kind="train_merged",
            lora_path="/tmp/checkpoint",
            optimizer_state_path="/tmp/optimizer",
            disk_packed_tensors={
                "dir": "/tmp/tensors",
                "num_sequences": 1,
                "sequence_length": 16,
            },
            config=types.TrainConfig(learning_rate=5e-5),
            experimental_config={},
            merged_weight_transfer=service._build_merged_weight_transfer_spec(1),
        )
    )

    try:
        with open(job_path, "r", encoding="utf-8") as handle:
            job = MegatronMergedTrainJob.model_validate_json(handle.read())
    finally:
        # Remove the job file this test created, even if reading it failed.
        # NOTE(review): presumably this lands in the shared jobs directory a
        # running trainer polls — confirm; either way it should not outlive
        # the test.
        Path(job_path).unlink(missing_ok=True)

    assert job.kind == "train_merged"
    assert job.merged_weight_transfer.served_model_name == "megatron-merged@1"
@pytest.mark.asyncio
async def test_megatron_service_train_merged_writes_merged_job_and_does_not_reload_adapter(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Training in merged mode writes a merged job and never reloads a LoRA.

    All service collaborators are replaced by fakes that append to a shared
    ``events`` list; the test then checks the order of events, the type and
    served model name of the written job, and the arguments passed to the
    training-log stream.
    """
    checkpoint_dir = tmp_path / "checkpoints" / "0000"
    checkpoint_dir.mkdir(parents=True)
    (checkpoint_dir / "adapter_model.safetensors").write_bytes(b"adapter")

    service = MegatronService(
        model_name="megatron-merged",
        base_model="Qwen/Qwen3-30B-A3B-Instruct-2507",
        config=InternalModelConfig(
            trainer_gpu_ids=[0],
            inference_gpu_ids=[1],
            rollout_weights_mode="merged",
        ),
        output_dir=str(tmp_path),
    )
    events: list[Any] = []

    def fake_ensure_running():
        events.append("ensure")
        return asyncio.sleep(0)

    def fake_resolve_lora_path():
        return str(checkpoint_dir)

    def fake_init_transfer():
        events.append("init")
        return asyncio.sleep(0)

    def fake_write_job(job):
        events.append(job)
        return "/tmp/job.json"

    async def fake_stream_training_log(*, job_path: str, lora_path: str):
        events.append(("stream", job_path, lora_path))
        yield {"loss": 1.0}

    def fake_ensure_adapter_config(_path):
        return None

    def forbid_reload(checkpoint_dir, step):
        # Any call means merged mode tried to hot-reload an adapter.
        raise AssertionError("merged mode should not hot-reload a LoRA adapter")

    replacements = (
        ("_ensure_megatron_running", fake_ensure_running),
        ("_resolve_active_lora_path", fake_resolve_lora_path),
        ("_init_merged_weight_transfer", fake_init_transfer),
        ("_write_job", fake_write_job),
        ("_stream_training_log", fake_stream_training_log),
        ("_ensure_lora_adapter_config", fake_ensure_adapter_config),
        ("_reload_adapter", forbid_reload),
    )
    for attribute, replacement in replacements:
        monkeypatch.setattr(service, attribute, replacement)

    service._merged_weight_transfer_init_info = MergedWeightTransferInitInfo(
        master_address="127.0.0.1",
        master_port=1234,
        rank_offset=1,
        world_size=2,
    )

    training_stream = service.train(
        {"dir": "/tmp/tensors", "num_sequences": 1, "sequence_length": 16},
        types.TrainConfig(learning_rate=5e-5),
        {},
    )
    results = [result async for result in training_stream]

    assert results == [{"loss": 1.0}]
    assert events[0:2] == ["ensure", "init"]
    written_job = events[2]
    assert isinstance(written_job, MegatronMergedTrainJob)
    assert (
        written_job.merged_weight_transfer.served_model_name == "megatron-merged@1"
    )
    assert events[3] == ("stream", "/tmp/job.json", str(checkpoint_dir))