From b16a5bc97a6be54b5335a72e991b8f8c164f3f85 Mon Sep 17 00:00:00 2001 From: Vivek Kalyan Date: Tue, 31 Mar 2026 13:34:15 -0700 Subject: [PATCH] feat: Add dedicated Megatron lora mode --- src/art/megatron/backend.py | 11 +- src/art/megatron/service.py | 280 ++++++++++++++++++++++++-- tests/unit/test_megatron_dedicated.py | 194 ++++++++++++++++++ 3 files changed, 470 insertions(+), 15 deletions(-) create mode 100644 tests/unit/test_megatron_dedicated.py diff --git a/src/art/megatron/backend.py b/src/art/megatron/backend.py index 1ebdc7a17..2227eecb7 100644 --- a/src/art/megatron/backend.py +++ b/src/art/megatron/backend.py @@ -1,3 +1,5 @@ +import os + from mp_actors import move_to_child_process from ..local.backend import LocalBackend @@ -17,6 +19,7 @@ def __init__( async def _get_service(self, model: TrainableModel) -> ModelService: from ..dev.get_model_config import get_model_config + from ..dev.validate import is_dedicated_mode, validate_dedicated_config from .service import MegatronService if model.name not in self._services: @@ -25,13 +28,19 @@ async def _get_service(self, model: TrainableModel) -> ModelService: output_dir=get_model_dir(model=model, art_path=self._path), config=model._internal_config, ) + validate_dedicated_config(config) + dedicated = is_dedicated_mode(config) + if dedicated: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + str(gpu_id) for gpu_id in config["trainer_gpu_ids"] + ) self._services[model.name] = MegatronService( model_name=model.name, base_model=model.base_model, config=config, output_dir=get_model_dir(model=model, art_path=self._path), ) - if not self._in_process: + if not dedicated and not self._in_process: self._services[model.name] = move_to_child_process( self._services[model.name], process_name="megatron-service", diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 9ac3c14be..5402a5d36 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -1,14 +1,16 @@ import asyncio -from dataclasses import dataclass +from dataclasses import dataclass, field import datetime from functools import cached_property import json +import logging import os from pathlib import Path import shlex import shutil import subprocess -from typing import Any, AsyncIterator +import sys +from typing import Any, AsyncIterator, Literal from peft.tuners.lora.config import LoraConfig from pydantic import BaseModel @@ -21,6 +23,7 @@ from .. import dev, types from ..dev.get_model_config import default_target_modules +from ..dev.validate import is_dedicated_mode from ..local.checkpoints import get_last_checkpoint_dir from ..preprocessing.pack import DiskPackedTensors from ..preprocessing.tokenize import SFTBatch @@ -49,6 +52,9 @@ class MegatronTrainingJob(BaseModel): ) +logger = logging.getLogger(__name__) + + @dataclass class MegatronService: model_name: str @@ -60,6 +66,24 @@ class MegatronService: _lora_id_counter: int = 1 _megatron_process: asyncio.subprocess.Process | None = None _optimizer_state_path: str | None = None + _vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg] + _vllm_log_file: Any = field(default=None, repr=False) + _vllm_host: str = "127.0.0.1" + _vllm_port: int = 0 + + @property + def is_dedicated(self) -> bool: + return is_dedicated_mode(self.config) + + @property + def rollout_weights_mode(self) -> Literal["lora", "merged"]: + mode = self.config.get("rollout_weights_mode", "lora") + assert mode in {"lora", "merged"} + return mode + + @property + def _vllm_base_url(self) -> str: + return f"http://{self._vllm_host}:{self._vllm_port}" def _next_lora_id(self) -> int: self._lora_id_counter += 1 @@ -171,6 +195,144 @@ def _ensure_lora_adapter_config( return self._default_lora_adapter_config().save_pretrained(lora_path) + def _resolve_active_lora_path(self) -> str: + lora_path = get_last_checkpoint_dir(self.output_dir) + if lora_path is None: + lora_path = get_step_checkpoint_dir(self.output_dir, 0) + self._latest_step = 0 + else: + self._latest_step = get_step_from_dir(self.output_dir) + self._ensure_identity_lora(lora_path) + self._ensure_lora_adapter_config(lora_path) + return lora_path + + async def _start_vllm_subprocess( + self, + lora_path: str, + port: int, + config: dev.OpenAIServerConfig | None, + ) -> tuple[str, int]: + import atexit + import httpx + + inference_gpu_ids = self.config["inference_gpu_ids"] + cuda_devices = ",".join(str(gpu_id) for gpu_id in inference_gpu_ids) + + server_args: dict[str, object] = { + "return_tokens_as_token_ids": True, + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + if config and "server_args" in config: + server_args.update(dict(config["server_args"])) + for key in ("port", "host", "lora_modules", "api_key"): + server_args.pop(key, None) + + engine_args = dict(self.config.get("engine_args", {})) + if config and "engine_args" in config: + engine_args.update(dict(config["engine_args"])) + engine_args.setdefault("generation_config", "vllm") + engine_args["enable_lora"] = True + engine_args.setdefault("max_loras", 2) + for key in ("model", "served_model_name", "enable_sleep_mode"): + engine_args.pop(key, None) + + cmd = [ + sys.executable, + "-m", + "art.vllm.dedicated_server", + f"--model={self.base_model}", + f"--port={port}", + f"--host={self._vllm_host}", + f"--cuda-visible-devices={cuda_devices}", + f"--lora-path={lora_path}", + f"--served-model-name={self.model_name}@{self._latest_step}", + f"--rollout-weights-mode={self.rollout_weights_mode}", + f"--engine-args-json={json.dumps(engine_args)}", + f"--server-args-json={json.dumps(server_args)}", + ] + + log_dir = os.path.join(self.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + self._vllm_log_file = open( + os.path.join(log_dir, "vllm-dedicated.log"), "w", buffering=1 + ) + self._vllm_process = subprocess.Popen( + cmd, + stdout=self._vllm_log_file, + stderr=subprocess.STDOUT, + bufsize=1, + ) + self._vllm_port = port + + timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 600)) + elapsed = 0.0 + async with httpx.AsyncClient() as client: + while elapsed < timeout: + if self._vllm_process.poll() is not None: + raise RuntimeError( + "vLLM subprocess exited with code " + f"{self._vllm_process.returncode}. " + f"Check logs at {log_dir}/vllm-dedicated.log" + ) + try: + response = await client.get( + f"{self._vllm_base_url}/v1/models", + timeout=5.0, + ) + if response.status_code == 200: + break + except (httpx.ConnectError, httpx.ReadTimeout): + pass + await asyncio.sleep(1.0) + elapsed += 1.0 + else: + self._stop_vllm_subprocess() + raise TimeoutError( + f"vLLM subprocess did not become ready within {timeout}s. " + f"Check logs at {log_dir}/vllm-dedicated.log" + ) + + atexit.register(self.close) + logger.info("vLLM subprocess ready on port %d (GPUs: %s)", port, cuda_devices) + return self._vllm_host, self._vllm_port + + async def _reload_adapter(self, checkpoint_path: str, step: int) -> None: + import httpx + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self._vllm_base_url}/v1/load_lora_adapter", + json={ + "lora_name": f"{self.model_name}@{step}", + "lora_path": checkpoint_path, + "load_inplace": True, + }, + timeout=60.0, + ) + response.raise_for_status() + self._latest_step = step + + def _stop_vllm_subprocess(self) -> None: + if self._vllm_process is not None: + self._vllm_process.terminate() + try: + self._vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self._vllm_process.kill() + self._vllm_process.wait() + self._vllm_process = None + if self._vllm_log_file is not None: + self._vllm_log_file.close() + self._vllm_log_file = None + + def _stop_megatron_process(self) -> None: + if self._megatron_process is None: + return + if self._megatron_process.returncode is None: + self._megatron_process.terminate() + self._megatron_process = None + async def _add_lora_aliases( self, llm: AsyncLLM, step: int, checkpoint_dir: str ) -> None: @@ -186,6 +348,10 @@ async def _add_lora_aliases( self._latest_step = step async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: + if self.is_dedicated: + assert self.rollout_weights_mode == "lora" + await self._reload_adapter(checkpoint_dir, step) + return llm = await self.llm await llm.pause_generation() await self._add_lora_aliases(llm, step, checkpoint_dir) @@ -209,29 +375,36 @@ async def _ensure_megatron_running(self) -> None: subprocess.run(["pkill", "-9", "megatron-service"], check=False) train_script = Path(__file__).parent / "train.py" project_root = Path(__file__).resolve().parents[3] - num_gpus = torch.cuda.device_count() - os.environ["MODEL_IDENTIFIER"] = self.base_model + launch_env = os.environ.copy() + if self.is_dedicated: + trainer_gpu_ids = self.config["trainer_gpu_ids"] + num_gpus = len(trainer_gpu_ids) + launch_env["CUDA_VISIBLE_DEVICES"] = ",".join( + str(gpu_id) for gpu_id in trainer_gpu_ids + ) + else: + num_gpus = torch.cuda.device_count() + launch_env["MODEL_IDENTIFIER"] = self.base_model command = ( - f"{setup_cmd}uv run --project {shlex.quote(str(project_root))} " - f"torchrun --nproc_per_node {num_gpus} {shlex.quote(str(train_script))}" + f"{setup_cmd}{shlex.quote(sys.executable)} -m torch.distributed.run " + f"--nproc_per_node {num_gpus} {shlex.quote(str(train_script))}" ) self._megatron_process = await asyncio.create_subprocess_shell( command, cwd=str(project_root), + env=launch_env, ) async def start_openai_server( self, config: dev.OpenAIServerConfig | None ) -> tuple[str, int]: - lora_path = get_last_checkpoint_dir(self.output_dir) - if lora_path is None: - lora_path = get_step_checkpoint_dir(self.output_dir, 0) - self._latest_step = 0 - else: - self._latest_step = get_step_from_dir(self.output_dir) - self._ensure_identity_lora(lora_path) - self._ensure_lora_adapter_config(lora_path) + lora_path = self._resolve_active_lora_path() + + if self.is_dedicated: + assert self.rollout_weights_mode == "lora" + port = (config or {}).get("server_args", {}).get("port", 8000) + return await self._start_vllm_subprocess(lora_path, port, config) lora_path_for_server = ( lora_path if self._adapter_has_weights(lora_path) else None @@ -250,8 +423,17 @@ async def start_openai_server( ) async def vllm_engine_is_sleeping(self) -> bool: + if self.is_dedicated: + return False return self._is_sleeping + async def aclose(self) -> None: + self.close() + + def close(self) -> None: + self._stop_vllm_subprocess() + self._stop_megatron_process() + async def train( self, disk_packed_tensors: DiskPackedTensors, @@ -259,6 +441,76 @@ async def train( _config: dev.TrainConfig, verbose: bool = False, ) -> AsyncIterator[dict[str, float]]: + if self.is_dedicated: + assert self.rollout_weights_mode == "lora" + await self._ensure_megatron_running() + + lora_path = self._resolve_active_lora_path() + self._optimizer_state_path = self._get_optimizer_state_path() + + jobs_dir = "/tmp/megatron_training_jobs" + os.makedirs(jobs_dir, exist_ok=True) + for job_name in os.listdir(jobs_dir): + if job_name.endswith(".json"): + os.remove(os.path.join(jobs_dir, job_name)) + if _config.get("moe_routing_replay_bundle") is not None: + raise RuntimeError( + "moe_routing_replay_bundle is only supported for in-process/runtime APIs; " + "MegatronService subprocess jobs must use moe_routing_replay_path." + ) + job = MegatronTrainingJob( + lora_path=lora_path, + optimizer_state_path=self._optimizer_state_path, + disk_packed_tensors=disk_packed_tensors, + config=config, + experimental_config=_config, + moe_routing_replay_path=_config.get("moe_routing_replay_path"), + moe_routing_replay_strict=_config.get( + "moe_routing_replay_strict", True + ), + ) + job_path = os.path.join( + jobs_dir, f"{datetime.datetime.now().isoformat()}.json" + ) + with open(job_path, "w", encoding="utf-8") as handle: + handle.write(job.model_dump_json()) + + num_lines = 0 + while True: + await asyncio.sleep(0.1) + try: + with open( + "/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8" + ) as log_file: + log_file.seek(0) + lines = log_file.readlines()[num_lines:] + for line in lines: + line = line.strip() + if not line: + continue + if line == "all done": + self._merge_lora_adapter(lora_path) + os.remove("/tmp/megatron_training_log.jsonl") + break + num_lines += 1 + yield json.loads(line) + else: + continue + break + except FileNotFoundError: + continue + + next_step = self._latest_step + 1 + new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) + os.makedirs(new_checkpoint_dir, exist_ok=True) + shutil.copy( + f"{lora_path}/adapter_model.safetensors", + f"{new_checkpoint_dir}/adapter_model.safetensors", + ) + self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path) + await self._reload_adapter(new_checkpoint_dir, next_step) + return + llm = await self.llm await llm.pause_generation() await llm.reset_prefix_cache() diff --git a/tests/unit/test_megatron_dedicated.py b/tests/unit/test_megatron_dedicated.py new file mode 100644 index 000000000..2f9c08c99 --- /dev/null +++ b/tests/unit/test_megatron_dedicated.py @@ -0,0 +1,194 @@ +import asyncio +import os +from pathlib import Path +import shlex +import sys +import types as pytypes +from typing import Any + +import pytest + +pytest.importorskip("vllm") + +from art import TrainableModel, types +from art.dev.model import InternalModelConfig +from art.megatron.backend import MegatronBackend +from art.megatron.service import MegatronService + + +@pytest.mark.asyncio +async def test_megatron_backend_dedicated_uses_trainer_gpus_without_child_process( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config = InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + rollout_weights_mode="lora", + ) + model = TrainableModel( + name="megatron-dedicated", + project="unit-tests", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + base_path=str(tmp_path), + _internal_config=config, + ) + backend = MegatronBackend(path=str(tmp_path)) + validated: dict[str, Any] = {} + + class FakeService: + def __init__( + self, + *, + model_name: str, + base_model: str, + config: InternalModelConfig, + output_dir: str, + ) -> None: + self.model_name = model_name + self.base_model = base_model + self.config = config + self.output_dir = output_dir + + monkeypatch.setattr( + "art.dev.get_model_config.get_model_config", + lambda *args, **kwargs: config, + ) + monkeypatch.setattr( + "art.dev.validate.validate_dedicated_config", + lambda cfg: validated.setdefault("config", cfg), + ) + monkeypatch.setattr( + "art.megatron.backend.move_to_child_process", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("Dedicated Megatron service should not move to a child process") + ), + ) + monkeypatch.setattr("art.megatron.service.MegatronService", FakeService) + + service = await backend._get_service(model) + + assert isinstance(service, FakeService) + assert validated["config"] is config + assert os.environ["CUDA_VISIBLE_DEVICES"] == "0" + + +@pytest.mark.asyncio +async def test_megatron_service_ensure_megatron_running_uses_trainer_gpus( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = MegatronService( + model_name="megatron-dedicated", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + config=InternalModelConfig( + trainer_gpu_ids=[0, 1], + inference_gpu_ids=[2], + rollout_weights_mode="lora", + ), + output_dir=str(tmp_path), + ) + megatron_module = pytypes.ModuleType("megatron") + megatron_bridge_module = pytypes.ModuleType("megatron.bridge") + monkeypatch.setitem(sys.modules, "megatron", megatron_module) + monkeypatch.setitem(sys.modules, "megatron.bridge", megatron_bridge_module) + + seen: dict[str, Any] = {} + + monkeypatch.setattr("art.megatron.service.subprocess.run", lambda *args, **kwargs: None) + + async def fake_create_subprocess_shell( + command: str, + cwd: str, + env: dict[str, str], + ) -> Any: + seen["command"] = command + seen["cwd"] = cwd + seen["env"] = env + return pytypes.SimpleNamespace(returncode=None) + + monkeypatch.setattr( + "art.megatron.service.asyncio.create_subprocess_shell", + fake_create_subprocess_shell, + ) + + await service._ensure_megatron_running() + + assert shlex.quote(sys.executable) in seen["command"] + assert "torch.distributed.run" in seen["command"] + assert "--nproc_per_node 2" in seen["command"] + assert seen["env"]["CUDA_VISIBLE_DEVICES"] == "0,1" + assert seen["env"]["MODEL_IDENTIFIER"] == "Qwen/Qwen3-30B-A3B-Instruct-2507" + + +@pytest.mark.asyncio +async def test_megatron_service_start_openai_server_dedicated_starts_subprocess( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + checkpoint_dir = tmp_path / "checkpoints" / "0000" + checkpoint_dir.mkdir(parents=True) + service = MegatronService( + model_name="megatron-dedicated", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + rollout_weights_mode="lora", + ), + output_dir=str(tmp_path), + ) + seen: dict[str, Any] = {} + + monkeypatch.setattr( + "art.megatron.service.get_last_checkpoint_dir", + lambda _output_dir: str(checkpoint_dir), + ) + monkeypatch.setattr(service, "_ensure_identity_lora", lambda _path: None) + monkeypatch.setattr(service, "_ensure_lora_adapter_config", lambda _path, source_path=None: None) + + async def fake_start_vllm_subprocess( + lora_path: str, + port: int, + config: dict[str, Any] | None, + ) -> tuple[str, int]: + seen["lora_path"] = lora_path + seen["port"] = port + seen["config"] = config + return ("127.0.0.1", port) + + monkeypatch.setattr(service, "_start_vllm_subprocess", fake_start_vllm_subprocess) + + location = await service.start_openai_server({"server_args": {"port": 8123}}) + + assert location == ("127.0.0.1", 8123) + assert seen["lora_path"] == str(checkpoint_dir) + assert seen["port"] == 8123 + + +@pytest.mark.asyncio +async def test_megatron_service_register_lora_for_step_dedicated_reloads_adapter( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + service = MegatronService( + model_name="megatron-dedicated", + base_model="Qwen/Qwen3-30B-A3B-Instruct-2507", + config=InternalModelConfig( + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + rollout_weights_mode="lora", + ), + output_dir=str(tmp_path), + ) + seen: list[tuple[str, int]] = [] + + monkeypatch.setattr( + service, + "_reload_adapter", + lambda checkpoint_dir, step: seen.append((checkpoint_dir, step)) or asyncio.sleep(0), + ) + + await service.register_lora_for_step(3, "/tmp/checkpoints/3") + + assert seen == [("/tmp/checkpoints/3", 3)]