Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 82 additions & 77 deletions sagemaker-train/tests/integ/train/test_mtrl_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
import pytest
import logging

os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2")
os.environ.setdefault("SAGEMAKER_REGION", "us-west-2")

import boto3
from sagemaker.core.helper.session_helper import Session
from sagemaker.train.evaluate import MultiTurnRLEvaluator
Expand All @@ -36,22 +33,30 @@
# Timeout for evaluation pipeline execution (4 hours)
EVALUATION_TIMEOUT_SECONDS = 14400

# Resolve current account ID for account-agnostic paths
_REGION = "us-west-2"
_ACCOUNT_ID = boto3.client("sts", region_name=_REGION).get_caller_identity()["Account"]

# Test configuration — uses current account for all resource paths
TEST_CONFIG = {
#"base_model": "huggingface-vlm-qwen3-6-27b",
"base_model": "openai-reasoning-gpt-oss-20b",
"agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{_ACCOUNT_ID}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
"dataset": f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet",
"s3_output_path": f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/output-artifacts/",
"mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
"model_package_group": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg",
"role": f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin",
"region": _REGION,
}


def _get_test_config():
    """Assemble the integration-test settings for the calling AWS account.

    The account ID is resolved through STS each time this function is called,
    so nothing is looked up until a test actually asks for the configuration.

    Returns:
        dict: The base model name, the ARNs and S3 URIs derived from
        ``_REGION`` and the caller's account ID, plus the ``region`` and
        ``account_id`` values themselves.
    """
    session = boto3.Session(region_name=_REGION)
    sts_client = session.client("sts")
    account_id = sts_client.get_caller_identity()["Account"]

    # Populate key by key, in the same order the settings are documented above.
    config = {}
    config["base_model"] = "openai-reasoning-gpt-oss-20b"
    config["agent_arn"] = f"arn:aws:bedrock-agentcore:{_REGION}:{account_id}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS"
    config["dataset"] = f"s3://sagemaker-rft-{account_id}/prompts/gsm8k_small/prompts.parquet"
    config["s3_output_path"] = f"s3://sagemaker-{_REGION}-{account_id}/model-evaluation/output-artifacts/"
    config["mlflow_resource_arn"] = f"arn:aws:sagemaker:{_REGION}:{account_id}:mlflow-app/app-TTAUWUNMUHH6"
    config["model_package_group"] = f"arn:aws:sagemaker:{_REGION}:{account_id}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg"
    config["role"] = f"arn:aws:iam::{account_id}:role/Admin"
    config["region"] = _REGION
    config["account_id"] = account_id
    return config


@pytest.fixture(scope="module")
def test_config():
    """Provide the test configuration dict to tests in this module.

    Delegates to ``_get_test_config()``, so the configuration is built when
    the fixture is first requested rather than when the module is imported.
    """
    config = _get_test_config()
    return config


def _ensure_model_package_group_exists(sm_client, group_name):
Expand Down Expand Up @@ -84,35 +89,35 @@ def _ensure_model_package_exists(sm_client, group_name, base_model_name):


@pytest.fixture(scope="module")
def sagemaker_session():
def sagemaker_session_mtrl():
"""Create a SageMaker session with explicit region for CI environments."""
boto_session = boto3.Session(region_name=TEST_CONFIG["region"])
boto_session = boto3.Session(region_name=_REGION)
return Session(boto_session=boto_session)


@pytest.fixture(scope="module")
def mtrl_trainer(sagemaker_session):
def mtrl_trainer(sagemaker_session_mtrl, test_config):
"""Create a lightweight MultiTurnRLTrainer-like object for evaluator tests.

Instead of going through the full constructor (which validates remote
resources), we build a minimal object with the attributes the evaluator
needs. This makes the test account-agnostic — it creates the required
resources (model package group + model package) on the fly.
"""
sm_client = sagemaker_session.boto_session.client("sagemaker")
sm_client = sagemaker_session_mtrl.boto_session.client("sagemaker")
group_name = "mtrl-integ-test-evaluator"
_ensure_model_package_group_exists(sm_client, group_name)
model_package_arn = _ensure_model_package_exists(
sm_client, group_name, TEST_CONFIG["base_model"]
sm_client, group_name, test_config["base_model"]
)

trainer = object.__new__(MultiTurnRLTrainer)
trainer._model_name = TEST_CONFIG["base_model"]
trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:aws:hub-content/SageMakerPublicHub/Model/{TEST_CONFIG['base_model']}/1.0.0"
trainer.agent_env = TEST_CONFIG["agent_arn"]
trainer._model_name = test_config["base_model"]
trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:aws:hub-content/SageMakerPublicHub/Model/{test_config['base_model']}/1.0.0"
trainer.agent_env = test_config["agent_arn"]
trainer.bedrock_agentcore_qualifier = "DEFAULT"
trainer.output_model_package_group = TEST_CONFIG["model_package_group"]
trainer.sagemaker_session = sagemaker_session
trainer.output_model_package_group = test_config["model_package_group"]
trainer.sagemaker_session = sagemaker_session_mtrl

# Use the real model package ARN from the account
class _FakeJob:
Expand All @@ -131,27 +136,27 @@ class _FakeJob:
class TestMTRLEvaluatorJobConfigDocument:
"""Tests validating the JobConfigDocument field naming for GA API contract."""

def test_bedrock_agent_config_fields(self, mtrl_trainer):
def test_bedrock_agent_config_fields(self, mtrl_trainer, test_config):
"""Verify BedrockAgentCoreConfig uses AgentRuntimeArn and Qualifier."""
evaluator = MultiTurnRLEvaluator(
model=mtrl_trainer,
dataset=TEST_CONFIG["dataset"],
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-bedrock/',
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
role=TEST_CONFIG["role"],
region=TEST_CONFIG["region"],
agent_config=TEST_CONFIG["agent_arn"],
dataset=test_config["dataset"],
s3_output_path=f'{test_config["s3_output_path"]}integ-fields-bedrock/',
mlflow_resource_arn=test_config["mlflow_resource_arn"],
role=test_config["role"],
region=test_config["region"],
agent_config=test_config["agent_arn"],
agent_qualifier="PROD",
)

evaluator._resolve_trainer_defaults()
evaluator._resolve_agent_arn()

ctx = evaluator._build_template_context(
aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID,
"role_arn": TEST_CONFIG["role"]},
aws_context={"region": test_config["region"], "account_id": test_config["account_id"],
"role_arn": test_config["role"]},
artifacts={},
model_package_group_arn=TEST_CONFIG["model_package_group"],
model_package_group_arn=test_config["model_package_group"],
)

doc = json.loads(ctx["job_config_document_ft_str"])
Expand All @@ -166,27 +171,27 @@ def test_bedrock_agent_config_fields(self, mtrl_trainer):
assert "AgentArn" not in agent_cfg.get("BedrockAgentCoreConfig", {})
assert "BedrockAgentCoreQualifier" not in agent_cfg.get("BedrockAgentCoreConfig", {})

def test_lambda_agent_config_fields(self, mtrl_trainer):
def test_lambda_agent_config_fields(self, mtrl_trainer, test_config):
"""Verify Lambda agent uses CustomAgentLambdaConfig (not LambdaConfig)."""
lambda_arn = "arn:aws:lambda:us-east-1:060795915353:function:SageMaker-agent-adapter-gsm8k"
evaluator = MultiTurnRLEvaluator(
model=mtrl_trainer,
dataset=TEST_CONFIG["dataset"],
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-lambda/',
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
role=TEST_CONFIG["role"],
region=TEST_CONFIG["region"],
dataset=test_config["dataset"],
s3_output_path=f'{test_config["s3_output_path"]}integ-fields-lambda/',
mlflow_resource_arn=test_config["mlflow_resource_arn"],
role=test_config["role"],
region=test_config["region"],
agent_config=lambda_arn,
)

evaluator._resolve_trainer_defaults()
evaluator._resolve_agent_arn()

ctx = evaluator._build_template_context(
aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID,
"role_arn": TEST_CONFIG["role"]},
aws_context={"region": test_config["region"], "account_id": test_config["account_id"],
"role_arn": test_config["role"]},
artifacts={},
model_package_group_arn=TEST_CONFIG["model_package_group"],
model_package_group_arn=test_config["model_package_group"],
)

doc = json.loads(ctx["job_config_document_ft_str"])
Expand All @@ -198,26 +203,26 @@ def test_lambda_agent_config_fields(self, mtrl_trainer):
# Ensure old field name is NOT present
assert "LambdaConfig" not in agent_cfg

def test_model_package_config_fields(self, mtrl_trainer):
def test_model_package_config_fields(self, mtrl_trainer, test_config):
"""Verify ModelPackageConfig uses InputModelPackageArn only (no OutputModelPackageGroupArn for eval)."""
evaluator = MultiTurnRLEvaluator(
model=mtrl_trainer,
dataset=TEST_CONFIG["dataset"],
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-mpc/',
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
role=TEST_CONFIG["role"],
region=TEST_CONFIG["region"],
agent_config=TEST_CONFIG["agent_arn"],
dataset=test_config["dataset"],
s3_output_path=f'{test_config["s3_output_path"]}integ-fields-mpc/',
mlflow_resource_arn=test_config["mlflow_resource_arn"],
role=test_config["role"],
region=test_config["region"],
agent_config=test_config["agent_arn"],
)

evaluator._resolve_trainer_defaults()
evaluator._resolve_agent_arn()

ctx = evaluator._build_template_context(
aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID,
"role_arn": TEST_CONFIG["role"]},
aws_context={"region": test_config["region"], "account_id": test_config["account_id"],
"role_arn": test_config["role"]},
artifacts={},
model_package_group_arn=TEST_CONFIG["model_package_group"],
model_package_group_arn=test_config["model_package_group"],
)

doc = json.loads(ctx["job_config_document_ft_str"])
Expand All @@ -239,41 +244,41 @@ class TestMTRLEvaluatorIntegration:
in accounts with the feature flag enabled (e.g., 742774200982).
"""

def test_evaluator_construction_with_trainer(self, mtrl_trainer):
def test_evaluator_construction_with_trainer(self, mtrl_trainer, test_config):
"""Test that MultiTurnRLEvaluator can be constructed from a trainer."""
evaluator = MultiTurnRLEvaluator(
model=mtrl_trainer,
dataset=TEST_CONFIG["dataset"],
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-construct/',
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
role=TEST_CONFIG["role"],
region=TEST_CONFIG["region"],
agent_config=TEST_CONFIG["agent_arn"],
dataset=test_config["dataset"],
s3_output_path=f'{test_config["s3_output_path"]}integ-construct/',
mlflow_resource_arn=test_config["mlflow_resource_arn"],
role=test_config["role"],
region=test_config["region"],
agent_config=test_config["agent_arn"],
)

assert evaluator is not None
assert evaluator.model is mtrl_trainer
assert evaluator.dataset == TEST_CONFIG["dataset"]
assert evaluator.region == TEST_CONFIG["region"]
assert evaluator.dataset == test_config["dataset"]
assert evaluator.region == test_config["region"]

def test_evaluator_construction_with_base_model(self):
def test_evaluator_construction_with_base_model(self, test_config):
"""Test that MultiTurnRLEvaluator can be constructed from a base model string."""
evaluator = MultiTurnRLEvaluator(
model=TEST_CONFIG["base_model"],
dataset=TEST_CONFIG["dataset"],
s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-base/',
agent_config=TEST_CONFIG["agent_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"],
role=TEST_CONFIG["role"],
region=TEST_CONFIG["region"],
model=test_config["base_model"],
dataset=test_config["dataset"],
s3_output_path=f'{test_config["s3_output_path"]}integ-base/',
agent_config=test_config["agent_arn"],
mlflow_resource_arn=test_config["mlflow_resource_arn"],
role=test_config["role"],
region=test_config["region"],
)

assert evaluator is not None
assert evaluator.model == TEST_CONFIG["base_model"]
assert evaluator.model == test_config["base_model"]

def test_get_all_mtrl_evaluations(self):
def test_get_all_mtrl_evaluations(self, test_config):
"""Test listing all MTRL evaluation executions."""
all_execs = MultiTurnRLEvaluator.get_all(region=TEST_CONFIG["region"])
all_execs = MultiTurnRLEvaluator.get_all(region=test_config["region"])

if hasattr(all_execs, '__iter__'):
all_execs = list(all_execs)
Expand Down
Loading
Loading