Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions src/art/megatron/job_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Annotated, Literal, TypeAlias

from pydantic import BaseModel, Field, TypeAdapter

from art import dev, types
from art.megatron.routing_replay import MoeRoutingReplayBundle
from art.preprocessing.pack import DiskPackedTensors


class MergedWeightTransferInitInfo(BaseModel):
master_address: str
master_port: int
rank_offset: int
world_size: int


class MergedWeightTransferSpec(BaseModel):
init_info: MergedWeightTransferInitInfo
vllm_base_url: str
served_model_name: str


class MegatronSyncJob(BaseModel):
kind: Literal["sync"]
lora_path: str
merged_weight_transfer: MergedWeightTransferSpec


class _MegatronTrainJobBase(BaseModel):
lora_path: str
optimizer_state_path: str
disk_packed_tensors: DiskPackedTensors
config: types.TrainConfig
experimental_config: dev.TrainConfig
moe_routing_replay_path: str | None = None
moe_routing_replay_strict: bool = True


class MegatronLoraTrainJob(_MegatronTrainJobBase):
kind: Literal["train_lora"]


class MegatronMergedTrainJob(_MegatronTrainJobBase):
kind: Literal["train_merged"]
merged_weight_transfer: MergedWeightTransferSpec


MegatronLoraTrainJob.model_rebuild(
force=True,
_types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
)
MegatronMergedTrainJob.model_rebuild(
force=True,
_types_namespace={"MoeRoutingReplayBundle": MoeRoutingReplayBundle},
)

MegatronJob: TypeAlias = Annotated[
MegatronSyncJob | MegatronLoraTrainJob | MegatronMergedTrainJob,
Field(discriminator="kind"),
]


def dump_megatron_job(job: MegatronJob) -> str:
return TypeAdapter(MegatronJob).dump_json(job).decode()


def load_megatron_job(raw: str | bytes) -> MegatronJob:
return TypeAdapter(MegatronJob).validate_json(raw)
1 change: 1 addition & 0 deletions src/art/megatron/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_provider(
)
)
provider = bridge.to_megatron_provider()
setattr(provider, "art_bridge", bridge)
base_layer_spec = provider.transformer_layer_spec

def _flex_attention_layer_spec(
Expand Down
Loading
Loading