Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def hyperparameters(self):
override_params = _get_evaluation_override_params(
hub_content_name=hub_content_name,
hub_name="SageMakerPublicHub",
evaluation_type="AgentRFTEvaluation",
evaluation_type="MTRLEvaluation",
region=self.region,
session=boto_session,
)
Expand All @@ -328,7 +328,9 @@ def hyperparameters(self):
f"JumpStart hub."
)

spec = _extract_eval_override_options(override_params, return_full_spec=True)
spec = _extract_eval_override_options(
override_params, param_names=list(override_params.keys()), return_full_spec=True
)
self._hyperparameters = FineTuningOptions(spec)
return self._hyperparameters

Expand Down
12 changes: 6 additions & 6 deletions sagemaker-train/tests/integ/train/test_benchmark_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1",
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
"model_package_group_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
"region": "us-west-2",
}
Expand All @@ -57,7 +57,7 @@
"base_model_id": "meta-textgeneration-llama-3-2-1b-instruct",
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
"region": "us-west-2",
}

Expand Down Expand Up @@ -124,7 +124,7 @@ def test_benchmark_evaluation_full_flow(self):
benchmark=Benchmark.MMLU,
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
model_package_group=TEST_CONFIG["model_package_group_arn"],
base_eval_name="integ-test-gen-qa-eval",
)
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_benchmark_evaluator_validation(self):
benchmark="invalid_benchmark",
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
)

# Test invalid MLflow ARN format
Expand All @@ -265,7 +265,7 @@ def test_benchmark_subtasks_validation(self):
benchmark=Benchmark.MMLU,
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
subtasks="abstract_algebra",
model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/test",
)
Expand All @@ -277,7 +277,7 @@ def test_benchmark_subtasks_validation(self):
benchmark=Benchmark.MMLU,
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
subtasks=["invalid"],
model_package_group="arn:aws:sagemaker:us-west-2:123456789012:model-package-group/test",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"model_package_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1",
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
"model_package_group_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
"evaluate_base_model": False,
"region": "us-west-2",
Expand All @@ -60,7 +60,7 @@
"evaluator_arn": "arn:aws:sagemaker:us-west-2:729646638167:hub-content/sdktest/JsonDoc/eval-lambda-test/0.0.1",
"dataset_s3_uri": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/zc_test.jsonl",
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
"region": "us-west-2",
}

Expand Down Expand Up @@ -111,7 +111,7 @@ def test_custom_scorer_evaluation_full_flow(self):
dataset=TEST_CONFIG["dataset_s3_uri"],
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
)

Expand Down Expand Up @@ -228,7 +228,7 @@ def test_custom_scorer_evaluator_validation(self):
evaluator=123, # Invalid type (not string, enum, or object)
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
dataset=TEST_CONFIG["dataset_s3_uri"],
)

Expand Down Expand Up @@ -268,7 +268,7 @@ def test_custom_scorer_with_builtin_metric(self):
dataset=TEST_CONFIG["dataset_s3_uri"],
model=TEST_CONFIG["model_package_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
evaluate_base_model=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"builtin_metrics": ["Completeness", "Faithfulness"],
"custom_metrics_json": json.dumps([CUSTOM_METRIC_DICT]),
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/base-model-fix-test/",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
"evaluate_base_model": True, # This is the key difference - testing base model evaluation
"region": "us-west-2",
}
Expand Down Expand Up @@ -115,6 +115,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
custom_metrics=TEST_CONFIG["custom_metrics_json"],
s3_output_path=TEST_CONFIG["s3_output_path"],
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
)

# Verify evaluator configuration
Expand Down Expand Up @@ -271,6 +272,7 @@ def test_base_model_false_still_works(self):
builtin_metrics=TEST_CONFIG["builtin_metrics"],
s3_output_path=TEST_CONFIG["s3_output_path"],
evaluate_base_model=False, # Only evaluate custom model
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
)

# Verify evaluator configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
"builtin_metrics": ["Completeness", "Faithfulness"],
"custom_metrics_json": json.dumps([CUSTOM_METRIC_DICT]),
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/model-customization/eval/",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-W7FOBBXZANVX",
"mlflow_tracking_server_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
# "model_package_group_arn": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/sdk-test-finetuned-models",
"evaluate_base_model": False,
"region": "us-west-2",
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_llm_as_judge_evaluation_full_flow(self):
dataset=TEST_CONFIG["dataset_s3_uri"],
builtin_metrics=TEST_CONFIG["builtin_metrics"],
custom_metrics=TEST_CONFIG["custom_metrics_json"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
s3_output_path=TEST_CONFIG["s3_output_path"],
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
)
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_llm_as_judge_builtin_metrics_prefix_handling(self):
evaluator_model=TEST_CONFIG["evaluator_model"],
dataset=TEST_CONFIG["dataset_s3_uri"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
builtin_metrics=["Builtin.Correctness", "Builtin.Helpfulness"],
)
assert evaluator_with_prefix.builtin_metrics == ["Builtin.Correctness", "Builtin.Helpfulness"]
Expand All @@ -247,7 +247,7 @@ def test_llm_as_judge_builtin_metrics_prefix_handling(self):
evaluator_model=TEST_CONFIG["evaluator_model"],
dataset=TEST_CONFIG["dataset_s3_uri"],
s3_output_path=TEST_CONFIG["s3_output_path"],
# mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
builtin_metrics=["Correctness", "Helpfulness"],
)
assert evaluator_without_prefix.builtin_metrics == ["Correctness", "Helpfulness"]
Expand Down
2 changes: 1 addition & 1 deletion sagemaker-train/tests/integ/train/test_mtrl_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{_ACCOUNT_ID}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
"dataset": f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet",
"s3_output_path": f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/output-artifacts/",
"mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU",
"mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
"model_package_group": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg",
"role": f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin",
"region": _REGION,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def handler(event, context):
),
"mlflow_resource_arn": os.environ.get(
"MTRL_3P_MLFLOW_ARN",
f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU",
f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6",
),
"role": os.environ.get(
"MTRL_3P_ROLE",
Expand Down Expand Up @@ -262,6 +262,7 @@ def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn):
logger.info(f"Started 3P agent base model evaluation: {execution.arn}")
logger.info(f"Status: {execution.status.overall_status}")

@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn):
"""Test evaluating using a CustomAgentLambda object as agent_config.

Expand All @@ -287,6 +288,7 @@ def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn):
assert execution.arn is not None
logger.info(f"Started CustomAgentLambda object evaluation: {execution.arn}")

@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn):
"""Test full end-to-end: start evaluation and wait for completion.

Expand Down Expand Up @@ -316,6 +318,7 @@ def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn):
if execution.status.overall_status == "Failed":
logger.error(f"Failure reason: {execution.status.failure_reason}")

@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn):
"""Test that 3P agent evaluations are discoverable via get_all.

Expand Down Expand Up @@ -355,6 +358,7 @@ def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn):



@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
def test_evaluate_with_attached_trainer(self, lambda_agent_arn):
"""Test evaluating a fine-tuned model by attaching to an existing training job."""
from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer
Expand Down
37 changes: 35 additions & 2 deletions sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def _get_account_id():
# PROD — Main account (729646638167)
"729646638167": {
"env_name": "PROD",
"existing_job_name": "openai-reasoning-gpt-oss-20b-mtrl-20260602150414",
"existing_job_name": "openai-reasoning-gpt-oss-20b-mtrl-20260602215955",
"base_model": "openai-reasoning-gpt-oss-20b",
"agent_core_arn": "arn:aws:bedrock-agentcore:us-west-2:729646638167:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS",
"dataset": "s3://sagemaker-rft-729646638167/prompts/gsm8k_small/prompts.parquet",
"s3_output_path": "s3://sagemaker-us-west-2-729646638167/mtrl-integ/eval-output/",
"mlflow_resource_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-ZG6FYITNGMMU",
"mlflow_resource_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-TTAUWUNMUHH6",
"model_package_group": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg",
"role": "arn:aws:iam::729646638167:role/Admin",
},
Expand Down Expand Up @@ -187,6 +187,7 @@ def test_evaluate_finetuned_model(self, attached_trainer, config):
f"reason: {execution.status.failure_reason}"
)

@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
def test_evaluate_base_model(self, config):
"""Evaluate the base model only — submit and wait for completion."""
evaluator = MultiTurnRLEvaluator(
Expand Down Expand Up @@ -247,3 +248,35 @@ def test_evaluate_comparison(self, attached_trainer, config):
f"[{config['env_name']}] Comparison eval failed with status: {status}, "
f"reason: {execution.status.failure_reason}"
)

@pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually")
def test_evaluate_with_hyperparam_override(self, attached_trainer, config):
    """Submit an evaluation with overridden hyperparameters and require success.

    Two hyperparameters are changed on the evaluator before the job is
    submitted. The test then blocks until the job finishes (bounded by
    ``EVAL_TIMEOUT``) and fails unless the final status is ``Succeeded``.
    """
    # Keyword arguments are assembled in the same order they were passed
    # originally so that config lookups happen in an unchanged sequence.
    evaluator_kwargs = {
        "model": attached_trainer,
        "dataset": config["dataset"],
        "s3_output_path": f'{config["s3_output_path"]}hyperparam-override/',
        "mlflow_resource_arn": config["mlflow_resource_arn"],
        "role": config["role"],
        "region": _REGION,
    }
    evaluator = MultiTurnRLEvaluator(**evaluator_kwargs)

    # Override MTRL-specific hyperparams. ``evaluator.hyperparameters`` is
    # looked up once per override, exactly as two separate assignments would.
    overrides = (
        ("sampling_max_tokens", 1024),
        ("eval_group_size", 4),
    )
    for option_name, option_value in overrides:
        setattr(evaluator.hyperparameters, option_name, option_value)

    job = evaluator.evaluate()

    assert job is not None
    assert job.arn is not None
    logger.info(f"[{config['env_name']}] Started hyperparam override eval: {job.arn}")

    job.wait(timeout=EVAL_TIMEOUT)

    # Read the status once and reuse it for both the log line and the assert.
    final_status = job.status.overall_status
    logger.info(f"[{config['env_name']}] Hyperparam override eval completed: {final_status}")

    # The failure reason is only fetched when the assertion actually fails.
    assert final_status == "Succeeded", (
        f"[{config['env_name']}] Hyperparam override eval failed with status: {final_status}, "
        f"reason: {job.status.failure_reason}"
    )
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

AGENT_RUNTIME_ID = "sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS"
ROLE_ARN = f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin"
MLFLOW_ARN = f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU"
MLFLOW_ARN = f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6"
S3_INPUT_PATH = f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet"
S3_OUTPUT_PATH = f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/mtrl-trainer-integ/"
LAMBDA_ARN = f"arn:aws:lambda:{_REGION}:{_ACCOUNT_ID}:function:SageMaker-AgentConnector-Lambda-MTRL-integ-test"
Expand Down
Loading