21 commits
c2039fc
Refresh shared training refactor on top of ART main
Kovbo Mar 28, 2026
19c906b
Rename Megatron merge helper
Kovbo Mar 28, 2026
9d75910
Deduplicate local and shared training logic
Kovbo Mar 28, 2026
6d0d2ae
Fix Megatron rope theta compatibility
Kovbo Mar 28, 2026
9c474c9
Remove Megatron rope theta workaround
Kovbo Mar 28, 2026
2fa8ffb
Align Unsloth SFT weight decay defaults
Kovbo Mar 28, 2026
8cb71cc
remove apex from no-build-isolation-package
Kovbo Mar 28, 2026
3a679cb
update install script
Kovbo Mar 28, 2026
9e90c7d
Fix Megatron job finalization ordering
Kovbo Mar 30, 2026
511d72c
Share Megatron worker loop
Kovbo Mar 30, 2026
2e64da0
Default Megatron grad accumulation by DP size
Kovbo Apr 1, 2026
0cee7cf
Collapse Megatron shared API into train module
Kovbo Apr 1, 2026
911c082
Remove Megatron shared shim
Kovbo Apr 1, 2026
0fa9a2b
Collapse Unsloth shared API into train module
Kovbo Apr 1, 2026
f6cd445
Lighten Megatron orchestration imports
Kovbo Apr 1, 2026
ff28081
Merge branch 'main' of github.com:OpenPipe/ART into feat/shared-train…
Kovbo Apr 2, 2026
3116a1b
Merge branch 'feat/shared-training-code' of github.com:OpenPipe/ART i…
Kovbo Apr 2, 2026
d08f2ad
fix: normalize SFT loss by token count before backward pass
Kovbo Apr 2, 2026
21dd5a3
Revert "fix: normalize SFT loss by token count before backward pass"
Kovbo Apr 3, 2026
d68ae3d
Support Megatron SFT in local backend
Kovbo Apr 3, 2026
f8fee63
refactor: extract create_identity_lora as standalone function
Kovbo Apr 3, 2026
16 changes: 6 additions & 10 deletions scripts/setup.sh
@@ -53,18 +53,14 @@ else
     echo "Skipping git reset/clean (GIT_RESET_CLEAN is not true). Preserving synced working tree."
 fi
 
-# Install astral-uv
-if ! command -v uv >/dev/null 2>&1; then
-    if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
-        echo "Failed to install uv." >&2
-        exit 1
-    fi
-    export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
+# Install astral-uv (standalone version)
+# Always prepend standalone install path so it takes precedence over system/conda uv
+export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
+if ! curl -LsSf https://astral.sh/uv/install.sh | sh; then
+    echo "Failed to install uv." >&2
+    exit 1
 fi
 
-# Update uv
-uv self update
-
 # Sync the dependencies
 if [ "${INSTALL_EXTRAS:-false}" = "true" ]; then
     uv sync --all-extras
105 changes: 105 additions & 0 deletions src/art/_backend_training.py
@@ -0,0 +1,105 @@
from collections.abc import Iterable
import time
from typing import Literal

from . import dev
from .metrics_taxonomy import (
    average_metric_samples,
    build_training_summary_metrics,
    summarize_trajectory_groups,
)
from .trajectories import TrajectoryGroup
from .types import TrainConfig


def build_rl_train_configs(
    *,
    learning_rate: float,
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    mask_prob_ratio: bool = False,
    ppo: bool = False,
    precalculate_logprobs: bool = False,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    max_negative_advantage_importance_sampling_weight: float | None = None,
    kimi_k2_tau: float | None = None,
    kl_penalty_coef: float = 0.0,
    allow_training_without_logprobs: bool | None = None,
    plot_tensors: bool | None = None,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool | None = None,
    logprob_calculation_chunk_size: int | None = None,
    num_trajectories_learning_rate_multiplier_power: float | None = None,
    kl_ref_adapter_path: str | None = None,
) -> tuple[TrainConfig, dev.TrainConfig]:
    config = TrainConfig(
        learning_rate=learning_rate,
        kl_penalty_coef=kl_penalty_coef,
    )
    dev_config: dev.TrainConfig = {
        "advantage_balance": advantage_balance,
        "importance_sampling_level": importance_sampling_level,
        "kl_penalty_coef": kl_penalty_coef,
        "mask_prob_ratio": mask_prob_ratio,
        "ppo": ppo,
        "precalculate_logprobs": precalculate_logprobs,
        "scale_rewards": scale_rewards,
    }

    if allow_training_without_logprobs is not None:
        dev_config["allow_training_without_logprobs"] = allow_training_without_logprobs
    if plot_tensors is not None:
        dev_config["plot_tensors"] = plot_tensors
    if truncated_importance_sampling is not None:
        dev_config["truncated_importance_sampling"] = truncated_importance_sampling
    if scale_learning_rate_by_reward_std_dev is not None:
        dev_config["scale_learning_rate_by_reward_std_dev"] = (
            scale_learning_rate_by_reward_std_dev
        )
    if logprob_calculation_chunk_size is not None:
        dev_config["logprob_calculation_chunk_size"] = logprob_calculation_chunk_size
    if num_trajectories_learning_rate_multiplier_power is not None:
        dev_config["num_trajectories_learning_rate_multiplier_power"] = (
            num_trajectories_learning_rate_multiplier_power
        )
    if epsilon is not None:
        dev_config["epsilon"] = epsilon
    if epsilon_high is not None:
        dev_config["epsilon_high"] = epsilon_high
    if max_negative_advantage_importance_sampling_weight is not None:
        dev_config["max_negative_advantage_importance_sampling_weight"] = (
            max_negative_advantage_importance_sampling_weight
        )
    if kimi_k2_tau is not None:
        dev_config["kimi_k2_tau"] = kimi_k2_tau
    if kl_ref_adapter_path is not None:
        dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path

    return config, dev_config


def aggregate_rl_training_metrics(
    *,
    training_metrics: list[dict[str, float]],
    trajectory_groups: Iterable[TrajectoryGroup],
    trainer_started: float,
) -> dict[str, float]:
    groups_list = list(trajectory_groups)
    avg_metrics = average_metric_samples(training_metrics)
    summary = summarize_trajectory_groups(groups_list)
    avg_metrics.setdefault("time/step_trainer_s", time.monotonic() - trainer_started)
    avg_metrics.update(
        {
            key: value
            for key, value in build_training_summary_metrics(
                summary,
                include_trainable_groups=True,
            ).items()
            if key not in avg_metrics
        }
    )
    return avg_metrics
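
For review convenience, a minimal sketch of how the two new helpers compose; the values and the empty trajectory-group list are illustrative, not from the PR:

import time

from art._backend_training import (
    aggregate_rl_training_metrics,
    build_rl_train_configs,
)

# Build the paired TrainConfig / dev.TrainConfig from explicit kwargs.
config, dev_config = build_rl_train_configs(
    learning_rate=1e-5,
    ppo=True,  # illustrative; mirrors loss_fn == "ppo" in LocalBackend.train
)

trainer_started = time.monotonic()
training_metrics: list[dict[str, float]] = []
# ... run training, appending one metrics dict per step ...

# Average per-step metrics, then merge in trajectory-group summary metrics
# without overwriting keys the trainer already reported.
avg_metrics = aggregate_rl_training_metrics(
    training_metrics=training_metrics,
    trajectory_groups=[],  # normally the step's TrajectoryGroup objects
    trainer_started=trainer_started,
)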
91 changes: 38 additions & 53 deletions src/art/local/backend.py
@@ -43,11 +43,14 @@
 from mp_actors import close_proxy, move_to_child_process
 
 from .. import dev
+from .._backend_training import (
+    aggregate_rl_training_metrics,
+    build_rl_train_configs,
+)
 from ..backend import AnyTrainableModel, Backend
 from ..costs import build_cost_calculator, get_model_pricing
 from ..metrics_taxonomy import (
     TRAIN_GRADIENT_STEPS_KEY,
-    average_metric_samples,
     build_training_summary_metrics,
     summarize_trajectory_groups,
 )
@@ -642,45 +645,36 @@ async def train( # type: ignore[override]
         if adam_params is not None:
             raise ValueError("LocalBackend requires adam_params=None.")
 
-        # Build config objects from explicit kwargs
-        config = TrainConfig(
-            learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
-        )
-        dev_config: dev.TrainConfig = {
-            "advantage_balance": advantage_balance,
-            "allow_training_without_logprobs": allow_training_without_logprobs,
-            "importance_sampling_level": importance_sampling_level,
-            "kl_penalty_coef": kl_penalty_coef,
-            "mask_prob_ratio": mask_prob_ratio,
-            "plot_tensors": plot_tensors,
-            "ppo": loss_fn == "ppo",
-            "precalculate_logprobs": precalculate_logprobs,
-            "scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
-            "scale_rewards": scale_rewards,
-            "logprob_calculation_chunk_size": logprob_calculation_chunk_size,
-            "num_trajectories_learning_rate_multiplier_power": num_trajectories_learning_rate_multiplier_power,
-        }
-        # Only include optional fields if they're set
-        if epsilon is not None:
-            dev_config["epsilon"] = epsilon
-        if epsilon_high is not None:
-            dev_config["epsilon_high"] = epsilon_high
-        if max_negative_advantage_importance_sampling_weight is not None:
-            dev_config["max_negative_advantage_importance_sampling_weight"] = (
-                max_negative_advantage_importance_sampling_weight
-            )
-        if kimi_k2_tau is not None:
-            dev_config["kimi_k2_tau"] = kimi_k2_tau
-        if truncated_importance_sampling is not None:
-            dev_config["truncated_importance_sampling"] = truncated_importance_sampling
-        if kl_ref_adapter_path is not None:
-            dev_config["kl_ref_adapter_path"] = kl_ref_adapter_path
-        elif kl_penalty_reference_step is not None:
-            ref_checkpoint_dir = get_step_checkpoint_dir(
+        resolved_kl_ref_adapter_path = kl_ref_adapter_path
+        if (
+            resolved_kl_ref_adapter_path is None
+            and kl_penalty_reference_step is not None
+        ):
+            resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
                 get_model_dir(model=model, art_path=self._path),
                 kl_penalty_reference_step,
             )
-            dev_config["kl_ref_adapter_path"] = ref_checkpoint_dir
+        config, dev_config = build_rl_train_configs(
+            learning_rate=learning_rate,
+            advantage_balance=advantage_balance,
+            scale_rewards=scale_rewards,
+            importance_sampling_level=importance_sampling_level,
+            mask_prob_ratio=mask_prob_ratio,
+            ppo=loss_fn == "ppo",
+            precalculate_logprobs=precalculate_logprobs,
+            epsilon=epsilon,
+            epsilon_high=epsilon_high,
+            max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
+            kimi_k2_tau=kimi_k2_tau,
+            kl_penalty_coef=kl_penalty_coef,
+            allow_training_without_logprobs=allow_training_without_logprobs,
+            plot_tensors=plot_tensors,
+            truncated_importance_sampling=truncated_importance_sampling,
+            scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
+            logprob_calculation_chunk_size=logprob_calculation_chunk_size,
+            num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
+            kl_ref_adapter_path=resolved_kl_ref_adapter_path,
+        )
 
         # Collect metrics from training
         training_metrics: list[dict[str, float]] = []
@@ -690,21 +684,10 @@ async def train( # type: ignore[override]
         ):
             training_metrics.append(metrics)
 
-        # Aggregate metrics
-        avg_metrics = average_metric_samples(training_metrics)
-        summary = summarize_trajectory_groups(groups_list)
-        avg_metrics.setdefault(
-            "time/step_trainer_s", time.monotonic() - trainer_started
-        )
-        avg_metrics.update(
-            {
-                key: value
-                for key, value in build_training_summary_metrics(
-                    summary,
-                    include_trainable_groups=True,
-                ).items()
-                if key not in avg_metrics
-            }
+        avg_metrics = aggregate_rl_training_metrics(
+            training_metrics=training_metrics,
+            trajectory_groups=groups_list,
+            trainer_started=trainer_started,
         )
 
         # Get step and checkpoint path
@@ -822,7 +805,9 @@ async def _train_model(
             packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors"
         )
         # Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train())
-        grad_accumulation_sequences = max(1, int(config.grad_accumulation_sequences))
+        grad_accumulation_sequences = max(
+            1, int(config.grad_accumulation_sequences or 1)
+        )
         estimated_gradient_steps = math.ceil(
             disk_packed_tensors["num_sequences"] / grad_accumulation_sequences
         )
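
One behavioral note on the hunk above: `config.grad_accumulation_sequences` may be unset (`None`), and `int(None)` raises `TypeError`, so the new `or 1` guard degrades to one sequence per gradient step. A self-contained sketch of the resulting arithmetic (the helper name is illustrative, not from the PR):

import math

def estimate_gradient_steps(
    num_sequences: int, grad_accumulation_sequences: int | None
) -> int:
    # None (or 0) degrades to 1 sequence per gradient step.
    per_step = max(1, int(grad_accumulation_sequences or 1))
    return math.ceil(num_sequences / per_step)

assert estimate_gradient_steps(10, None) == 10
assert estimate_gradient_steps(10, 4) == 3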
2 changes: 1 addition & 1 deletion src/art/loss.py
@@ -8,7 +8,7 @@
 from . import dev
 
 if TYPE_CHECKING:
-    from art.unsloth.service import TrainInputs
+    from art.preprocessing.inputs import TrainInputs
 
 
 class Loss(BaseModel):
58 changes: 58 additions & 0 deletions src/art/megatron/client.py
@@ -0,0 +1,58 @@
import asyncio
import datetime
import json
import os
from typing import Any, AsyncIterator

from .jobs import DEFAULT_JOBS_DIR, MegatronJob
from .merge import merge_lora_adapter

DEFAULT_TRAINING_LOG_DIR = "/tmp/megatron_training_logs"


def create_megatron_job_paths() -> tuple[str, str]:
    timestamp = datetime.datetime.now().isoformat()
    os.makedirs(DEFAULT_JOBS_DIR, exist_ok=True)
    os.makedirs(DEFAULT_TRAINING_LOG_DIR, exist_ok=True)
    return (
        os.path.join(DEFAULT_JOBS_DIR, f"{timestamp}.json"),
        os.path.join(DEFAULT_TRAINING_LOG_DIR, f"{timestamp}.jsonl"),
    )


def write_megatron_job(job: MegatronJob, *, job_path: str) -> None:
    os.makedirs(os.path.dirname(job_path), exist_ok=True)
    with open(job_path, "w", encoding="utf-8") as handle:
        handle.write(job.model_dump_json())


async def stream_megatron_job(
    job: MegatronJob,
    *,
    job_path: str,
    poll_interval: float = 0.1,
) -> AsyncIterator[dict[str, Any]]:
    num_lines = 0
    try:
        while True:
            await asyncio.sleep(poll_interval)
            try:
                with open(job.log_path, "a+", encoding="utf-8") as log_file:
                    log_file.seek(0)
                    lines = log_file.readlines()[num_lines:]
            except FileNotFoundError:
                continue

            for line in lines:
                # Count every line, including blanks, so the readlines slice
                # offset stays correct on the next poll.
                num_lines += 1
                if not (line := line.strip()):
                    continue
                if line == "all done":
                    merge_lora_adapter(job.lora_path)
                    return
                yield json.loads(line)
    finally:
        if os.path.exists(job_path):
            os.remove(job_path)
        if os.path.exists(job.log_path):
            os.remove(job.log_path)
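
A minimal driver sketch for these helpers, assuming a separate Megatron worker appends JSON-lines metrics to `job.log_path` and writes the literal line `all done` when finished (which also triggers the LoRA merge inside the stream loop); paths and hyperparameters below are illustrative:

import asyncio

from art.megatron.client import (
    create_megatron_job_paths,
    stream_megatron_job,
    write_megatron_job,
)
from art.megatron.jobs import MegatronSFTTrainingJob

async def run() -> None:
    job_path, log_path = create_megatron_job_paths()
    job = MegatronSFTTrainingJob(
        lora_path="/tmp/lora",  # illustrative paths
        optimizer_state_path="/tmp/opt",
        sft_data_dir="/tmp/sft",
        num_batches=2,
        learning_rates=[1e-5, 1e-5],
        log_path=log_path,
    )
    write_megatron_job(job, job_path=job_path)
    # Yields one metrics dict per logged line until the worker signals completion.
    async for metrics in stream_megatron_job(job, job_path=job_path):
        print(metrics)

asyncio.run(run())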
36 changes: 36 additions & 0 deletions src/art/megatron/jobs.py
@@ -0,0 +1,36 @@
from typing import Any, Literal

from pydantic import BaseModel

from .. import types
from ..preprocessing.pack import DiskPackedTensors

DEFAULT_TRAINING_LOG_PATH = "/tmp/megatron_training_log.jsonl"
DEFAULT_JOBS_DIR = "/tmp/megatron_training_jobs"
DEFAULT_VLLM_WAKE_LOCK_PATH = "/tmp/megatron_vllm_waking"


class MegatronTrainingJob(BaseModel):
    lora_path: str
    optimizer_state_path: str
    disk_packed_tensors: DiskPackedTensors
    config: types.TrainConfig
    experimental_config: dict[str, Any]
    moe_routing_replay_path: str | None = None
    moe_routing_replay_strict: bool = True
    log_path: str = DEFAULT_TRAINING_LOG_PATH


class MegatronSFTTrainingJob(BaseModel):
    job_type: Literal["sft"] = "sft"
    lora_path: str
    optimizer_state_path: str
    sft_data_dir: str
    num_batches: int
    learning_rates: list[float]
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    log_path: str = DEFAULT_TRAINING_LOG_PATH


MegatronJob = MegatronTrainingJob | MegatronSFTTrainingJob
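
Since `MegatronTrainingJob` carries no `job_type` field, `MegatronJob` is a plain (non-discriminated) union, so a worker reading a job file back would match it structurally. A sketch using pydantic's `TypeAdapter` (the file name is illustrative):

from pydantic import TypeAdapter

from art.megatron.jobs import DEFAULT_JOBS_DIR, MegatronJob

adapter = TypeAdapter(MegatronJob)
with open(f"{DEFAULT_JOBS_DIR}/example.json", encoding="utf-8") as handle:
    # Returns whichever of the two job models the JSON validates against.
    job = adapter.validate_json(handle.read())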