11 changes: 10 additions & 1 deletion src/art/megatron/backend.py
@@ -1,3 +1,5 @@
import os

from mp_actors import move_to_child_process

from ..local.backend import LocalBackend
@@ -17,6 +19,7 @@ def __init__(

async def _get_service(self, model: TrainableModel) -> ModelService:
from ..dev.get_model_config import get_model_config
from ..dev.validate import is_dedicated_mode, validate_dedicated_config
from .service import MegatronService

if model.name not in self._services:
@@ -25,13 +28,19 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
output_dir=get_model_dir(model=model, art_path=self._path),
config=model._internal_config,
)
validate_dedicated_config(config)
dedicated = is_dedicated_mode(config)
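# In dedicated mode, pin this process to the trainer GPUs before the service (and Megatron) start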
if dedicated:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(gpu_id) for gpu_id in config["trainer_gpu_ids"]
)
self._services[model.name] = MegatronService(
model_name=model.name,
base_model=model.base_model,
config=config,
output_dir=get_model_dir(model=model, art_path=self._path),
)
if not self._in_process:
if not dedicated and not self._in_process:
self._services[model.name] = move_to_child_process(
self._services[model.name],
process_name="megatron-service",
280 changes: 266 additions & 14 deletions src/art/megatron/service.py
@@ -1,14 +1,16 @@
import asyncio
from dataclasses import dataclass
from dataclasses import dataclass, field
import datetime
from functools import cached_property
import json
import logging
import os
from pathlib import Path
import shlex
import shutil
import subprocess
from typing import Any, AsyncIterator
import sys
from typing import Any, AsyncIterator, Literal

from peft.tuners.lora.config import LoraConfig
from pydantic import BaseModel
@@ -21,6 +23,7 @@

from .. import dev, types
from ..dev.get_model_config import default_target_modules
from ..dev.validate import is_dedicated_mode
from ..local.checkpoints import get_last_checkpoint_dir
from ..preprocessing.pack import DiskPackedTensors
from ..preprocessing.tokenize import SFTBatch
@@ -49,6 +52,9 @@ class MegatronTrainingJob(BaseModel):
)


logger = logging.getLogger(__name__)


@dataclass
class MegatronService:
model_name: str
@@ -60,6 +66,24 @@ class MegatronService:
_lora_id_counter: int = 1
_megatron_process: asyncio.subprocess.Process | None = None
_optimizer_state_path: str | None = None
_vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg]
_vllm_log_file: Any = field(default=None, repr=False)
_vllm_host: str = "127.0.0.1"
_vllm_port: int = 0

@property
def is_dedicated(self) -> bool:
return is_dedicated_mode(self.config)

@property
def rollout_weights_mode(self) -> Literal["lora", "merged"]:
mode = self.config.get("rollout_weights_mode", "lora")
assert mode in {"lora", "merged"}
return mode

@property
def _vllm_base_url(self) -> str:
return f"http://{self._vllm_host}:{self._vllm_port}"

def _next_lora_id(self) -> int:
self._lora_id_counter += 1
@@ -171,6 +195,144 @@ def _ensure_lora_adapter_config(
return
self._default_lora_adapter_config().save_pretrained(lora_path)

def _resolve_active_lora_path(self) -> str:
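# Locate the latest LoRA checkpoint (or fall back to a fresh step-0 adapter) and make sure its config exists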
lora_path = get_last_checkpoint_dir(self.output_dir)
if lora_path is None:
lora_path = get_step_checkpoint_dir(self.output_dir, 0)
self._latest_step = 0
else:
self._latest_step = get_step_from_dir(self.output_dir)
self._ensure_identity_lora(lora_path)
self._ensure_lora_adapter_config(lora_path)
return lora_path

async def _start_vllm_subprocess(
self,
lora_path: str,
port: int,
config: dev.OpenAIServerConfig | None,
) -> tuple[str, int]:
import atexit
import httpx

inference_gpu_ids = self.config["inference_gpu_ids"]
cuda_devices = ",".join(str(gpu_id) for gpu_id in inference_gpu_ids)

server_args: dict[str, object] = {
"return_tokens_as_token_ids": True,
"enable_auto_tool_choice": True,
"tool_call_parser": "hermes",
}
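# Caller-supplied server args may override the defaults, but keys the launcher itself controls are stripped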
if config and "server_args" in config:
server_args.update(dict(config["server_args"]))
for key in ("port", "host", "lora_modules", "api_key"):
server_args.pop(key, None)

engine_args = dict(self.config.get("engine_args", {}))
if config and "engine_args" in config:
engine_args.update(dict(config["engine_args"]))
engine_args.setdefault("generation_config", "vllm")
engine_args["enable_lora"] = True
engine_args.setdefault("max_loras", 2)
for key in ("model", "served_model_name", "enable_sleep_mode"):
engine_args.pop(key, None)

cmd = [
sys.executable,
"-m",
"art.vllm.dedicated_server",
f"--model={self.base_model}",
f"--port={port}",
f"--host={self._vllm_host}",
f"--cuda-visible-devices={cuda_devices}",
f"--lora-path={lora_path}",
f"--served-model-name={self.model_name}@{self._latest_step}",
f"--rollout-weights-mode={self.rollout_weights_mode}",
f"--engine-args-json={json.dumps(engine_args)}",
f"--server-args-json={json.dumps(server_args)}",
]

log_dir = os.path.join(self.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
self._vllm_log_file = open(
os.path.join(log_dir, "vllm-dedicated.log"), "w", buffering=1
)
self._vllm_process = subprocess.Popen(
cmd,
stdout=self._vllm_log_file,
stderr=subprocess.STDOUT,
bufsize=1,
)
self._vllm_port = port

timeout = float(os.environ.get("ART_DEDICATED_VLLM_TIMEOUT", 600))
elapsed = 0.0
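# Poll the server's /v1/models endpoint until it responds, failing fast if the subprocess exits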
async with httpx.AsyncClient() as client:
while elapsed < timeout:
if self._vllm_process.poll() is not None:
raise RuntimeError(
"vLLM subprocess exited with code "
f"{self._vllm_process.returncode}. "
f"Check logs at {log_dir}/vllm-dedicated.log"
)
try:
response = await client.get(
f"{self._vllm_base_url}/v1/models",
timeout=5.0,
)
if response.status_code == 200:
break
except (httpx.ConnectError, httpx.ReadTimeout):
pass
await asyncio.sleep(1.0)
elapsed += 1.0
else:
self._stop_vllm_subprocess()
raise TimeoutError(
f"vLLM subprocess did not become ready within {timeout}s. "
f"Check logs at {log_dir}/vllm-dedicated.log"
)

atexit.register(self.close)
logger.info("vLLM subprocess ready on port %d (GPUs: %s)", port, cuda_devices)
return self._vllm_host, self._vllm_port

async def _reload_adapter(self, checkpoint_path: str, step: int) -> None:
import httpx

async with httpx.AsyncClient() as client:
response = await client.post(
f"{self._vllm_base_url}/v1/load_lora_adapter",
json={
"lora_name": f"{self.model_name}@{step}",
"lora_path": checkpoint_path,
"load_inplace": True,
},
timeout=60.0,
)
response.raise_for_status()
self._latest_step = step

def _stop_vllm_subprocess(self) -> None:
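# Terminate the dedicated vLLM process (escalating to kill on timeout) and close its log file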
if self._vllm_process is not None:
self._vllm_process.terminate()
try:
self._vllm_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self._vllm_process.kill()
self._vllm_process.wait()
self._vllm_process = None
if self._vllm_log_file is not None:
self._vllm_log_file.close()
self._vllm_log_file = None

def _stop_megatron_process(self) -> None:
if self._megatron_process is None:
return
if self._megatron_process.returncode is None:
self._megatron_process.terminate()
self._megatron_process = None

async def _add_lora_aliases(
self, llm: AsyncLLM, step: int, checkpoint_dir: str
) -> None:
@@ -186,6 +348,10 @@ async def _add_lora_aliases(
self._latest_step = step

async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
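# Dedicated mode hot-reloads the adapter into the standalone vLLM server instead of the in-process engine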
if self.is_dedicated:
assert self.rollout_weights_mode == "lora"
await self._reload_adapter(checkpoint_dir, step)
return
llm = await self.llm
await llm.pause_generation()
await self._add_lora_aliases(llm, step, checkpoint_dir)
@@ -209,29 +375,36 @@ async def _ensure_megatron_running(self) -> None:
subprocess.run(["pkill", "-9", "megatron-service"], check=False)
train_script = Path(__file__).parent / "train.py"
project_root = Path(__file__).resolve().parents[3]
num_gpus = torch.cuda.device_count()
os.environ["MODEL_IDENTIFIER"] = self.base_model
launch_env = os.environ.copy()
if self.is_dedicated:
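# Restrict the Megatron trainer to the configured trainer GPUs; otherwise use every visible GPU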
trainer_gpu_ids = self.config["trainer_gpu_ids"]
num_gpus = len(trainer_gpu_ids)
launch_env["CUDA_VISIBLE_DEVICES"] = ",".join(
str(gpu_id) for gpu_id in trainer_gpu_ids
)
else:
num_gpus = torch.cuda.device_count()
launch_env["MODEL_IDENTIFIER"] = self.base_model

command = (
f"{setup_cmd}uv run --project {shlex.quote(str(project_root))} "
f"torchrun --nproc_per_node {num_gpus} {shlex.quote(str(train_script))}"
f"{setup_cmd}{shlex.quote(sys.executable)} -m torch.distributed.run "
f"--nproc_per_node {num_gpus} {shlex.quote(str(train_script))}"
)
self._megatron_process = await asyncio.create_subprocess_shell(
command,
cwd=str(project_root),
env=launch_env,
)

async def start_openai_server(
self, config: dev.OpenAIServerConfig | None
) -> tuple[str, int]:
lora_path = get_last_checkpoint_dir(self.output_dir)
if lora_path is None:
lora_path = get_step_checkpoint_dir(self.output_dir, 0)
self._latest_step = 0
else:
self._latest_step = get_step_from_dir(self.output_dir)
self._ensure_identity_lora(lora_path)
self._ensure_lora_adapter_config(lora_path)
lora_path = self._resolve_active_lora_path()

if self.is_dedicated:
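# Dedicated mode serves rollouts from a separate vLLM process pinned to the inference GPUs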
assert self.rollout_weights_mode == "lora"
port = (config or {}).get("server_args", {}).get("port", 8000)
return await self._start_vllm_subprocess(lora_path, port, config)

lora_path_for_server = (
lora_path if self._adapter_has_weights(lora_path) else None
@@ -250,15 +423,94 @@ async def start_openai_server(
)

async def vllm_engine_is_sleeping(self) -> bool:
if self.is_dedicated:
return False
return self._is_sleeping

async def aclose(self) -> None:
self.close()

def close(self) -> None:
self._stop_vllm_subprocess()
self._stop_megatron_process()

async def train(
self,
disk_packed_tensors: DiskPackedTensors,
config: types.TrainConfig,
_config: dev.TrainConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
if self.is_dedicated:
assert self.rollout_weights_mode == "lora"
await self._ensure_megatron_running()

lora_path = self._resolve_active_lora_path()
self._optimizer_state_path = self._get_optimizer_state_path()

jobs_dir = "/tmp/megatron_training_jobs"
os.makedirs(jobs_dir, exist_ok=True)
for job_name in os.listdir(jobs_dir):
if job_name.endswith(".json"):
os.remove(os.path.join(jobs_dir, job_name))
if _config.get("moe_routing_replay_bundle") is not None:
raise RuntimeError(
"moe_routing_replay_bundle is only supported for in-process/runtime APIs; "
"MegatronService subprocess jobs must use moe_routing_replay_path."
)
job = MegatronTrainingJob(
lora_path=lora_path,
optimizer_state_path=self._optimizer_state_path,
disk_packed_tensors=disk_packed_tensors,
config=config,
experimental_config=_config,
moe_routing_replay_path=_config.get("moe_routing_replay_path"),
moe_routing_replay_strict=_config.get(
"moe_routing_replay_strict", True
),
)
job_path = os.path.join(
jobs_dir, f"{datetime.datetime.now().isoformat()}.json"
)
with open(job_path, "w", encoding="utf-8") as handle:
handle.write(job.model_dump_json())

num_lines = 0
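# Tail the shared JSONL training log, yielding each metrics line until the "all done" sentinel appears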
while True:
await asyncio.sleep(0.1)
try:
with open(
"/tmp/megatron_training_log.jsonl", "a+", encoding="utf-8"
) as log_file:
log_file.seek(0)
lines = log_file.readlines()[num_lines:]
for line in lines:
# Count every consumed line (including blanks) so it is not re-read on the next pass
num_lines += 1
line = line.strip()
if not line:
continue
if line == "all done":
self._merge_lora_adapter(lora_path)
os.remove("/tmp/megatron_training_log.jsonl")
break
yield json.loads(line)
else:
continue
break
except FileNotFoundError:
continue

next_step = self._latest_step + 1
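# Copy the trained adapter into a new step checkpoint and hot-reload it into the dedicated vLLM server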
new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step)
os.makedirs(new_checkpoint_dir, exist_ok=True)
shutil.copy(
f"{lora_path}/adapter_model.safetensors",
f"{new_checkpoint_dir}/adapter_model.safetensors",
)
self._ensure_lora_adapter_config(new_checkpoint_dir, source_path=lora_path)
await self._reload_adapter(new_checkpoint_dir, next_step)
return

llm = await self.llm
await llm.pause_generation()
await llm.reset_prefix_cache()