Skip to content
Draft
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ requires-python = ">=3.10"
"import-linter~=2.10",
"pytest-deadfixtures~=3.1",
"taplo~=0.9.3",
"gymnasium~=1.2",
]
rl = ["gymnasium~=1.2"]
docs = [
"sphinx~=8.1",
"nvidia-sphinx-theme~=0.0.8",
Expand Down
1 change: 1 addition & 0 deletions src/cloudai/_core/test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class TestRun:
post_test: Optional[TestScenario] = None
reports: Set[Type[ReportGenerationStrategy]] = field(default_factory=set)
extra_srun_args: str | None = None
current_env_params: dict[str, Any] = field(default_factory=dict)

def __hash__(self) -> int:
return hash(self.name + self.test.name + str(self.iterations) + str(self.current_iteration))
Expand Down
60 changes: 59 additions & 1 deletion src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import signal
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Protocol, TypeGuard, runtime_checkable
from unittest.mock import Mock

import toml
Expand Down Expand Up @@ -118,6 +118,60 @@ def prepare_installation(
return installables, installer


@runtime_checkable
class CustomTrainingLoopAgent(Protocol):
"""
Agent that drives its own training loop and skips the ``handle_dse_job`` step loop.

Set ``HAS_CUSTOM_TRAINING_LOOP = True`` on the agent class to opt in. Used by
agents (e.g. RLlib-based) whose training loops are not modelled as a sequence
of independent ``select_action`` / ``env.step`` calls.
"""

HAS_CUSTOM_TRAINING_LOOP: bool

def train(self) -> None: ...


def _has_custom_training_loop(agent: object) -> TypeGuard[CustomTrainingLoopAgent]:
"""
Narrow ``agent`` to :class:`CustomTrainingLoopAgent` when it opts into the dispatch path.

Returning :class:`TypeGuard` (instead of plain ``bool``) lets the type checker
treat this predicate like ``isinstance``: callers inside the truthy branch see
``agent`` as a :class:`CustomTrainingLoopAgent`, so ``agent.train()`` type-checks
without ``getattr`` or ``cast``.
"""
return bool(getattr(agent, "HAS_CUSTOM_TRAINING_LOOP", False))


def _run_custom_training_loop(agent: CustomTrainingLoopAgent, agent_type: str) -> int:
"""
Drive an agent's self-contained training loop and return a process-style exit code.

``shutdown()`` runs inside its own ``try/except`` so a faulty teardown cannot
suppress the exit code from ``train()`` nor propagate out of this helper:
``handle_dse_job`` relies on the returned ``rc`` to accumulate ``err |= rc``
and continue with the remaining test runs.
"""
logging.info(f"Agent {agent_type} drives its own training loop; delegating to agent.train().")
rc = 0
try:
agent.train()
except Exception:
logging.exception(f"Custom training loop failed for agent {agent_type}.")
rc = 1
finally:
shutdown = getattr(agent, "shutdown", None)
if callable(shutdown):
try:
shutdown()
except Exception:
logging.exception(f"Shutdown failed for agent {agent_type}.")
rc = 1
return rc


def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
registry = Registry()

Expand Down Expand Up @@ -157,6 +211,10 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:

agent = agent_class(env, agent_config)

if _has_custom_training_loop(agent):
err |= _run_custom_training_loop(agent, agent_type)
continue

for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
Expand Down
2 changes: 2 additions & 0 deletions src/cloudai/configurator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from .base_gym import BaseGym
from .cloudai_gym import CloudAIGymEnv, TrajectoryEntry
from .grid_search import GridSearchAgent
from .gymnasium_adapter import GymnasiumAdapter

__all__ = [
"BaseAgent",
"BaseGym",
"CloudAIGymEnv",
"GridSearchAgent",
"GymnasiumAdapter",
"TrajectoryEntry",
]
55 changes: 50 additions & 5 deletions src/cloudai/configurator/cloudai_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
import dataclasses
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from cloudai.core import METRIC_ERROR, BaseRunner, Registry, TestRun
from cloudai.util.lazy_imports import lazy

from .base_agent import RewardOverrides
from .base_gym import BaseGym
from .env_params import CsvSink, EnvParamsObserver, StepObserver


@dataclasses.dataclass(frozen=True)
Expand All @@ -36,6 +37,7 @@ class TrajectoryEntry:
action: dict[str, Any]
reward: float
observation: list
env_params: dict[str, Any] = dataclasses.field(default_factory=dict)


class CloudAIGymEnv(BaseGym):
Expand All @@ -61,8 +63,27 @@ def __init__(self, test_run: TestRun, runner: BaseRunner, rewards: RewardOverrid
self.max_steps = test_run.test.agent_steps
self.reward_function = Registry().get_reward_function(test_run.test.agent_reward_function)
self.trajectory: dict[int, list[TrajectoryEntry]] = {}
self.observers: List[StepObserver] = self._build_observers()
super().__init__()

def _build_observers(self) -> List[StepObserver]:
"""
Construct the per-step observers implied by the TestDefinition.

Workloads opt in to env_params via a TOML ``[env_params.<name>]`` block;
an empty mapping yields no observers and zero overhead.
"""
observers: List[StepObserver] = []
if self.test_run.test.env_params:
seed = int((self.test_run.test.agent_config or {}).get("random_seed", 0))
sink = CsvSink(self._env_csv_path())
observers.append(EnvParamsObserver(self.test_run.test.env_params, sink, seed))
return observers

def _env_csv_path(self) -> Path:
"""``env.csv`` lives alongside ``trajectory.csv`` so a plain ``merge`` joins them."""
return self.trajectory_file_path.parent / "env.csv"

def define_action_space(self) -> Dict[str, list[Any]]:
return self.test_run.param_space

Expand All @@ -76,9 +97,10 @@ def define_observation_space(self) -> list:
Define the observation space for the environment.

Returns:
list: The observation space.
list: One float slot per agent metric (at least one), giving the correct shape
for adapters that derive ``gymnasium.spaces.Box`` from this output.
"""
return [0.0]
return [0.0] * max(len(self.test_run.test.agent_metrics), 1)

def reset(
self,
Expand All @@ -100,7 +122,7 @@ def reset(
if seed is not None:
lazy.np.random.seed(seed)
self.test_run.current_iteration = 0
observation = [0.0]
observation = self.define_observation_space()
info = {}
return observation, info

Expand All @@ -120,6 +142,9 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
"""
self.test_run = self.test_run.apply_params_set(action)

for observer in self.observers:
observer.before_step(self.test_run)

cached_result = self.get_cached_trajectory_result(action)
if cached_result is not None:
logging.info(
Expand All @@ -133,8 +158,11 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
action=action,
reward=cached_result.reward,
observation=cached_result.observation,
env_params=dict(self.test_run.current_env_params),
)
)
for observer in self.observers:
observer.after_step(self.test_run, cached_result.observation, cached_result.reward)
return cached_result.observation, cached_result.reward, False, {}

if not self.test_run.test.constraint_check(self.test_run, self.runner.system):
Expand Down Expand Up @@ -170,9 +198,13 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
action=action,
reward=reward,
observation=observation,
env_params=dict(self.test_run.current_env_params),
)
)

for observer in self.observers:
observer.after_step(self.test_run, observation, reward)

return observation, reward, False, {}

def render(self, mode: str = "human"):
Expand Down Expand Up @@ -251,8 +283,21 @@ def current_trajectory(self) -> list[TrajectoryEntry]:
return self.trajectory.setdefault(self.test_run.current_iteration, [])

def get_cached_trajectory_result(self, action: Any) -> TrajectoryEntry | None:
"""
Return a cached entry only when the full trial identity matches.

Trial identity is ``(action, env_params)``: env-randomized parameters
change the workload's behaviour, so a trial repeating the same action
under a different ``env_params`` sample must miss and re-run. Empty
env_params on both sides is the back-compat path for workloads that
do not declare any ``[env_params.*]`` block.
"""
current_env_params = getattr(self.test_run, "current_env_params", {}) or {}
for entry in self.current_trajectory:
if self._values_match_exact(entry.action, action):
if not self._values_match_exact(entry.action, action):
continue
entry_env = getattr(entry, "env_params", {}) or {}
if self._values_match_exact(entry_env, current_env_params):
return entry

return None
Expand Down
Loading