Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

class RouterType(str, Enum):
SGLANG = "sglang"
DYNAMO = "dynamo"


class SGLangGatewayRouterConfig(CoreModel):
Expand Down Expand Up @@ -45,8 +46,15 @@ class SGLangServiceRouterConfig(CoreModel):

class ReplicaGroupRouterConfig(CoreModel):
type: Annotated[
Literal["sglang"],
Field(description="The router implementation for this replica group."),
Literal["sglang", "dynamo"],
Field(
description=(
"The router implementation for this replica group. "
"`sglang` runs the SGLang router and dstack syncs worker URLs to it. "
"`dynamo` runs the NVIDIA Dynamo frontend, which discovers workers "
"itself via etcd/NATS."
),
),
] = "sglang"


Expand Down
26 changes: 26 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from dstack._internal.core.models.repos import AnyRunRepoData
from dstack._internal.core.models.resources import Memory, ResourcesSpec
from dstack._internal.core.models.routers import RouterType
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint
from dstack._internal.utils import common as common_utils
Expand Down Expand Up @@ -603,6 +604,31 @@ def _merged_profile(cls, values) -> Dict:
values["merged_profile"] = merged_profile
return values

@root_validator
def _validate_dynamo_no_retry(cls, values) -> Dict:
    """Reject ``retry`` for services that declare a Dynamo router replica group.

    Dynamo workers cache the router's internal IP at provisioning time.
    A retry would produce a new router and likely a new internal_ip,
    leaving workers bound to a router that no longer exists.
    """
    profile = values.get("merged_profile")
    configuration = values.get("configuration")
    # The check only matters when a retry policy is actually configured
    # and the run is a service (only services have replica groups).
    retry_configured = profile is not None and profile.retry is not None
    if retry_configured and isinstance(configuration, ServiceConfiguration):
        has_dynamo_router = any(
            group.router is not None and group.router.type == RouterType.DYNAMO
            for group in configuration.replica_groups
        )
        if has_dynamo_router:
            raise ValueError(
                "Retry cannot be configured for services with a Dynamo "
                "router replica group. The router's address must remain "
                "stable for the life of the run; allowing retry would "
                "leave workers bound to a router that no longer exists. "
                "Remove `retry` from the profile/configuration and "
                "re-apply."
            )
    return values


class ServiceModelSpec(CoreModel):
name: str
Expand Down
18 changes: 9 additions & 9 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ServiceConfig,
)
from dstack._internal.proxy.lib import models
from dstack._internal.proxy.lib.const import SGLANG_WHITELISTED_PATHS
from dstack._internal.proxy.lib.const import ROUTER_WHITELISTED_PATHS
from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.proxy.lib.services.service_connection import (
Expand Down Expand Up @@ -344,7 +344,7 @@ async def get_nginx_service_config(
) -> ServiceConfig:
limit_req_zones: list[LimitReqZoneConfig] = []
locations: list[LocationConfig] = []
is_sglang = (
is_router = (
service.router is not None and service.router.type == RouterType.SGLANG
) or service.has_router_replica
sglang_limits: dict[str, LimitReqConfig] = {}
Expand All @@ -361,8 +361,8 @@ async def get_nginx_service_config(
limit_req_zones.append(
LimitReqZoneConfig(name=zone_name, key=key, rpm=round(rate_limit.rps * 60))
)
if is_sglang:
for path in SGLANG_WHITELISTED_PATHS:
if is_router:
for path in ROUTER_WHITELISTED_PATHS:
if rate_limit.prefix == path or path.startswith(rate_limit.prefix):
# Use the longest prefix if multiple prefixes match the same path
current_prefix_len = len(rate_limit.prefix)
Expand All @@ -381,9 +381,9 @@ async def get_nginx_service_config(
)
)

# Add SGLang whitelisted paths as locations
if is_sglang:
for path in SGLANG_WHITELISTED_PATHS:
# Add router whitelisted paths as locations
if is_router:
for path in ROUTER_WHITELISTED_PATHS:
# Use prefix match for paths that end with a slash and exact match for paths that don't
if path.endswith("/"):
locations.append(LocationConfig(prefix=path, limit_req=sglang_limits.get(path)))
Expand All @@ -392,8 +392,8 @@ async def get_nginx_service_config(
LocationConfig(prefix=f"= {path}", limit_req=sglang_limits.get(path))
)

# Don't auto-add / location for SGLang routers (catch-all 403 handles it)
if not any(location.prefix == "/" for location in locations) and not is_sglang:
# Don't auto-add / location for router-based services (catch-all 403 handles it)
if not any(location.prefix == "/" for location in locations) and not is_router:
locations.append(LocationConfig(prefix="/", limit_req=None))
return ServiceConfig(
domain=service.domain_safe,
Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/proxy/lib/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
Shared constants for proxy components (gateway + in-server proxy).
"""

SGLANG_WHITELISTED_PATHS: tuple[str, ...] = (
# Inference endpoints exposed by the in-replica HTTP router. Applies to both
# SGLang's router and Dynamo's `dynamo.frontend` — they share the
# OpenAI-compatible endpoint surface.
ROUTER_WHITELISTED_PATHS: tuple[str, ...] = (
"/generate",
"/v1/",
"/chat/completions",
Expand Down
123 changes: 117 additions & 6 deletions src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few unit tests for these changes could be useful

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I was referring to the changes in this file, the pipeline, that aren't yet covered. Pipelines are important to cover, because a broken pipeline can potentially affect many jobs on the server

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Iterable, Literal, Optional, Sequence, Union

import httpx
from sqlalchemy import and_, exists, func, or_, select, update
from sqlalchemy import and_, exists, false, func, or_, select, true, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only

Expand All @@ -23,6 +23,7 @@
from dstack._internal.core.models.metrics import Metric
from dstack._internal.core.models.profiles import StartupOrder
from dstack._internal.core.models.repos import RemoteRepoCreds
from dstack._internal.core.models.routers import RouterType
from dstack._internal.core.models.runs import (
ClusterInfo,
ImagePullProgress,
Expand Down Expand Up @@ -102,6 +103,11 @@
from dstack._internal.server.services.runner import client
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
from dstack._internal.server.services.runs import is_job_ready, run_model_to_run
from dstack._internal.server.services.runs.replicas import (
RouterEnvStatus,
get_router_env_for_job,
get_router_replica_group,
)
from dstack._internal.server.services.secrets import get_project_secrets_mapping
from dstack._internal.server.services.storage import get_default_storage
from dstack._internal.server.utils import sentry_utils
Expand All @@ -114,6 +120,8 @@

JOB_STATUSES_WITH_MIN_PROCESSING_INTERVAL = [JobStatus.PROVISIONING, JobStatus.PULLING]

ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS = 30 * 60

JOB_DISCONNECTED_RETRY_TIMEOUT = timedelta(minutes=2)
"""`The minimum time before terminating active job in case of connectivity issues."""

Expand Down Expand Up @@ -368,6 +376,12 @@ class _StartupContext:
volumes: list[Volume]
secrets: dict[str, str]
repo_creds: Optional[RemoteRepoCreds]
router_env: Optional[Dict[str, str]] = None
"""Dynamo-specific env (e.g. DSTACK_ROUTER_INTERNAL_IP) computed from the
router replica's state. Passed through to RunnerClient.submit_job, which
merges it into a deep-copied job_spec.env so the shared job_spec is not
mutated. None for SGLang services, non-router runs, and the router
replica itself."""


async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_ProcessContext]:
Expand All @@ -384,10 +398,18 @@ async def _load_process_context(item: JobRunningPipelineItem) -> Optional[_Proce
job_submissions=[job_model_to_job_submission(job_model)],
)
else:
# PROVISIONING/PULLING jobs need same-replica siblings for cluster coordination.
# All sibling access is replica-scoped, so only load jobs for this replica.
# PROVISIONING/PULLING jobs need same-replica siblings for cluster
# coordination, plus — when the run has a router replica group —
# the router replica's job (cross-replica) so the env-injection
# gate in _prepare_startup_context can read its status / IP.
# _fetch_run_model handles both: same-replica jobs always, plus
# all non-terminated jobs when one exists.
run_spec = RunSpec.__response__.parse_raw(job_model.run.run_spec)
run_model = await _fetch_run_model(
session=session, run_id=job_model.run_id, replica_num=item.replica_num
session=session,
run_id=job_model.run_id,
replica_num=item.replica_num,
run_spec=run_spec,
)
run = run_model_to_run(run_model, include_sensitive=True)
job = find_job(run.jobs, job_model.replica_num, job_model.job_num)
Expand Down Expand Up @@ -477,6 +499,58 @@ async def _prepare_startup_context(
)
return None

# If this run has a router replica group and this job is a worker, gate
# startup on the router replica's state. The helper returns None for the
# router itself and for runs without a router group, so this whole block
# is a no-op in those cases.
router_env_outcome = get_router_env_for_job(
run_model=context.run_model,
run_spec=context.run.run_spec,
job_model=context.job_model,
)
if router_env_outcome is RouterEnvStatus.FAILED:
# Router has reached a terminal state — the worker cannot recover by
# waiting. Terminate it now with a clear reason instead of letting it
# idle until the run-level reconciler tears the whole run down.
_terminate_job(
job_model=context.job_model,
job_update_map=result.job_update_map,
termination_reason=JobTerminationReason.TERMINATED_BY_SERVER,
termination_reason_message=(
"Router replica is in a terminal state; cannot provision worker "
"without a running router."
),
)
return None
if router_env_outcome is RouterEnvStatus.NOT_PROVISIONED:
# Router is alive but its internal_ip is not yet known. Defer this
# worker — the next pipeline tick will re-check. Bound the wait so a
# router that is genuinely stuck can't burn worker instance-hours
# forever; see ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS.
waited_seconds = (get_current_datetime() - context.job_model.submitted_at).total_seconds()
if waited_seconds > ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS:
_terminate_job(
job_model=context.job_model,
job_update_map=result.job_update_map,
termination_reason=JobTerminationReason.TERMINATED_BY_SERVER,
termination_reason_message=(
f"Router replica did not acquire an internal IP within "
f"{ROUTER_PROVISIONING_WAIT_TIMEOUT_SECONDS}s; terminating worker."
),
)
return None
logger.debug(
"%s: waiting for router replica to be provisioned",
fmt(context.job_model),
)
return None
# Past the enum branches, router_env_outcome is either None or a Dict.
# We don't mutate job_spec.env here — RunnerClient.submit_job merges it
# into a deep-copied spec, mirroring how instance_env is handled.
router_env: Optional[Dict[str, str]] = (
router_env_outcome if isinstance(router_env_outcome, dict) else None
)

cluster_info = _get_cluster_info(
jobs=context.run.jobs,
replica_num=context.job.job_spec.replica_num,
Expand Down Expand Up @@ -520,6 +594,7 @@ async def _prepare_startup_context(
volumes=volumes,
secrets=secrets,
repo_creds=repo_creds,
router_env=router_env,
)


Expand All @@ -534,6 +609,7 @@ async def _refetch_locked_job_model(
)
.options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
.options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak))
.options(joinedload(JobModel.run).load_only(RunModel.id, RunModel.run_spec))
.execution_options(populate_existing=True)
)
return res.unique().scalar_one_or_none()
Expand All @@ -543,13 +619,22 @@ async def _fetch_run_model(
session: AsyncSession,
run_id: uuid.UUID,
replica_num: Optional[int] = None,
run_spec: Optional[RunSpec] = None,
) -> RunModel:
"""Fetch run model with related project, user, repo, and fleet.

Args:
replica_num: If None, skip loading jobs (for RUNNING jobs that don't need siblings).
If set, load only latest-submission jobs for that replica (for PROVISIONING/PULLING
jobs that need same-replica siblings for cluster coordination).
jobs that need same-replica siblings for cluster coordination). When the run has
a Dynamo router replica group, all non-terminated latest-submission jobs for the
run are loaded so find_router_job can identify the router by replica-group
membership.
run_spec: Required whenever `replica_num` is set. Used only to detect
whether the run has a Dynamo router replica group. The caller is
expected to parse it once from the eager-loaded JobModel.run
(see _refetch_locked_job_model) so we don't issue a separate
query for it here.
"""
query = (
select(RunModel)
Expand All @@ -560,14 +645,27 @@ async def _fetch_run_model(
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
)
if replica_num is not None:
assert run_spec is not None, "run_spec must be provided when replica_num is set"
router_group = get_router_replica_group(run_spec)
is_dynamo = (
router_group is not None
and router_group.router is not None
and router_group.router.type == RouterType.DYNAMO
)

latest_submissions_sq = (
select(
JobModel.run_id.label("run_id"),
JobModel.replica_num.label("replica_num"),
JobModel.job_num.label("job_num"),
func.max(JobModel.submission_num).label("max_submission_num"),
)
.where(JobModel.run_id == run_id, JobModel.replica_num == replica_num)
.where(
JobModel.run_id == run_id,
# For Service with Dynamo router: load all replicas. For Non-Dynamo: only the worker's
# own replica.
true() if is_dynamo else JobModel.replica_num == replica_num,
)
.group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num)
.subquery()
)
Expand All @@ -581,6 +679,15 @@ async def _fetch_run_model(
job_alias.replica_num == latest_submissions_sq.c.replica_num,
job_alias.job_num == latest_submissions_sq.c.job_num,
job_alias.submission_num == latest_submissions_sq.c.max_submission_num,
# For Dynamo runs, drop terminated rows so accumulated
# scale-down history doesn't bloat the load. Non-Dynamo
# runs are already restricted to the worker's own
# replica above, so this filter is a no-op for them.
or_(
false() if is_dynamo else true(),
~job_alias.status.in_(JobStatus.finished_statuses())
& (job_alias.status != JobStatus.TERMINATING),
),
),
)
.options(contains_eager(RunModel.jobs, alias=job_alias))
Expand Down Expand Up @@ -690,6 +797,7 @@ async def _process_provisioning_status(
file_archives=file_archives,
secrets=startup_context.secrets,
repo_credentials=startup_context.repo_creds,
router_env=startup_context.router_env,
success_if_not_available=False,
)
if submit_result is not False:
Expand Down Expand Up @@ -800,6 +908,7 @@ async def _process_pulling_status(
file_archives=file_archives,
secrets=startup_context.secrets,
repo_credentials=startup_context.repo_creds,
router_env=startup_context.router_env,
success_if_not_available=True,
)
if submit_result is not False:
Expand Down Expand Up @@ -1408,6 +1517,7 @@ def _submit_job_to_runner(
file_archives: Iterable[tuple[uuid.UUID, bytes]],
secrets: Dict[str, str],
repo_credentials: Optional[RemoteRepoCreds],
router_env: Optional[Dict[str, str]],
success_if_not_available: bool,
) -> Union[_SubmitJobToRunnerResult, Literal[False]]:
logger.debug("%s: submitting job spec", fmt(job_model))
Expand Down Expand Up @@ -1435,6 +1545,7 @@ def _submit_job_to_runner(
secrets={},
repo_credentials=repo_credentials,
instance_env=instance_env,
router_env=router_env,
)
for archive_id, archive in file_archives:
logger.debug("%s: uploading file archive: %s", fmt(job_model), archive_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
from dstack._internal.server.services.runs.router_worker_sync import (
run_model_has_router_replica_group,
run_model_has_sglang_router_replica_group,
sync_router_workers_for_run_model,
)
from dstack._internal.server.utils import sentry_utils
Expand Down Expand Up @@ -212,7 +212,7 @@ async def process(self, item: ServiceRouterWorkerSyncPipelineItem) -> None:
run_model.deleted
or run_model.status.is_finished()
or run_model.status != RunStatus.RUNNING
or not run_model_has_router_replica_group(run_model)
or not run_model_has_sglang_router_replica_group(run_model)
):
early_cleanup_update_map: _SyncRowUpdateMap = {"deleted": True}
set_processed_update_map_fields(early_cleanup_update_map)
Expand Down
Loading
Loading