diff --git a/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py b/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py index 136429a5f1..4ce1c409e4 100644 --- a/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py +++ b/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py @@ -22,9 +22,6 @@ import pytest import logging -os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2") -os.environ.setdefault("SAGEMAKER_REGION", "us-west-2") - import boto3 from sagemaker.core.helper.session_helper import Session from sagemaker.train.evaluate import MultiTurnRLEvaluator @@ -36,22 +33,30 @@ # Timeout for evaluation pipeline execution (4 hours) EVALUATION_TIMEOUT_SECONDS = 14400 -# Resolve current account ID for account-agnostic paths _REGION = "us-west-2" -_ACCOUNT_ID = boto3.client("sts", region_name=_REGION).get_caller_identity()["Account"] - -# Test configuration — uses current account for all resource paths -TEST_CONFIG = { - #"base_model": "huggingface-vlm-qwen3-6-27b", - "base_model": "openai-reasoning-gpt-oss-20b", - "agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{_ACCOUNT_ID}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS", - "dataset": f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet", - "s3_output_path": f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/output-artifacts/", - "mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6", - "model_package_group": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg", - "role": f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin", - "region": _REGION, -} + + +def _get_test_config(): + """Build test configuration lazily (only when tests actually run).""" + boto_session = boto3.Session(region_name=_REGION) + account_id = boto_session.client("sts").get_caller_identity()["Account"] + return { + "base_model": "openai-reasoning-gpt-oss-20b", + "agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{account_id}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS", + "dataset": f"s3://sagemaker-rft-{account_id}/prompts/gsm8k_small/prompts.parquet", + "s3_output_path": f"s3://sagemaker-{_REGION}-{account_id}/model-evaluation/output-artifacts/", + "mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{account_id}:mlflow-app/app-TTAUWUNMUHH6", + "model_package_group": f"arn:aws:sagemaker:{_REGION}:{account_id}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg", + "role": f"arn:aws:iam::{account_id}:role/Admin", + "region": _REGION, + "account_id": account_id, + } + + +@pytest.fixture(scope="module") +def test_config(): + """Lazily resolve test configuration (avoids module-level API calls).""" + return _get_test_config() def _ensure_model_package_group_exists(sm_client, group_name): @@ -84,14 +89,14 @@ def _ensure_model_package_exists(sm_client, group_name, base_model_name): @pytest.fixture(scope="module") -def sagemaker_session(): +def sagemaker_session_mtrl(): """Create a SageMaker session with explicit region for CI environments.""" - boto_session = boto3.Session(region_name=TEST_CONFIG["region"]) + boto_session = boto3.Session(region_name=_REGION) return Session(boto_session=boto_session) @pytest.fixture(scope="module") -def mtrl_trainer(sagemaker_session): +def mtrl_trainer(sagemaker_session_mtrl, test_config): """Create a lightweight MultiTurnRLTrainer-like object for evaluator tests. Instead of going through the full constructor (which validates remote @@ -99,20 +104,20 @@ def mtrl_trainer(sagemaker_session): needs. This makes the test account-agnostic — it creates the required resources (model package group + model package) on the fly. """ - sm_client = sagemaker_session.boto_session.client("sagemaker") + sm_client = sagemaker_session_mtrl.boto_session.client("sagemaker") group_name = "mtrl-integ-test-evaluator" _ensure_model_package_group_exists(sm_client, group_name) model_package_arn = _ensure_model_package_exists( - sm_client, group_name, TEST_CONFIG["base_model"] + sm_client, group_name, test_config["base_model"] ) trainer = object.__new__(MultiTurnRLTrainer) - trainer._model_name = TEST_CONFIG["base_model"] - trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:aws:hub-content/SageMakerPublicHub/Model/{TEST_CONFIG['base_model']}/1.0.0" - trainer.agent_env = TEST_CONFIG["agent_arn"] + trainer._model_name = test_config["base_model"] + trainer._model_arn = f"arn:aws:sagemaker:{_REGION}:aws:hub-content/SageMakerPublicHub/Model/{test_config['base_model']}/1.0.0" + trainer.agent_env = test_config["agent_arn"] trainer.bedrock_agentcore_qualifier = "DEFAULT" - trainer.output_model_package_group = TEST_CONFIG["model_package_group"] - trainer.sagemaker_session = sagemaker_session + trainer.output_model_package_group = test_config["model_package_group"] + trainer.sagemaker_session = sagemaker_session_mtrl # Use the real model package ARN from the account class _FakeJob: @@ -131,16 +136,16 @@ class _FakeJob: class TestMTRLEvaluatorJobConfigDocument: """Tests validating the JobConfigDocument field naming for GA API contract.""" - def test_bedrock_agent_config_fields(self, mtrl_trainer): + def test_bedrock_agent_config_fields(self, mtrl_trainer, test_config): """Verify BedrockAgentCoreConfig uses AgentRuntimeArn and Qualifier.""" evaluator = MultiTurnRLEvaluator( model=mtrl_trainer, - dataset=TEST_CONFIG["dataset"], - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-bedrock/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], - agent_config=TEST_CONFIG["agent_arn"], + dataset=test_config["dataset"], + s3_output_path=f'{test_config["s3_output_path"]}integ-fields-bedrock/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], + agent_config=test_config["agent_arn"], agent_qualifier="PROD", ) @@ -148,10 +153,10 @@ def test_bedrock_agent_config_fields(self, mtrl_trainer): evaluator._resolve_agent_arn() ctx = evaluator._build_template_context( - aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID, - "role_arn": TEST_CONFIG["role"]}, + aws_context={"region": test_config["region"], "account_id": test_config["account_id"], + "role_arn": test_config["role"]}, artifacts={}, - model_package_group_arn=TEST_CONFIG["model_package_group"], + model_package_group_arn=test_config["model_package_group"], ) doc = json.loads(ctx["job_config_document_ft_str"]) @@ -166,16 +171,16 @@ def test_bedrock_agent_config_fields(self, mtrl_trainer): assert "AgentArn" not in agent_cfg.get("BedrockAgentCoreConfig", {}) assert "BedrockAgentCoreQualifier" not in agent_cfg.get("BedrockAgentCoreConfig", {}) - def test_lambda_agent_config_fields(self, mtrl_trainer): + def test_lambda_agent_config_fields(self, mtrl_trainer, test_config): """Verify Lambda agent uses CustomAgentLambdaConfig (not LambdaConfig).""" lambda_arn = "arn:aws:lambda:us-east-1:060795915353:function:SageMaker-agent-adapter-gsm8k" evaluator = MultiTurnRLEvaluator( model=mtrl_trainer, - dataset=TEST_CONFIG["dataset"], - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-lambda/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], + dataset=test_config["dataset"], + s3_output_path=f'{test_config["s3_output_path"]}integ-fields-lambda/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], agent_config=lambda_arn, ) @@ -183,10 +188,10 @@ def test_lambda_agent_config_fields(self, mtrl_trainer): evaluator._resolve_agent_arn() ctx = evaluator._build_template_context( - aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID, - "role_arn": TEST_CONFIG["role"]}, + aws_context={"region": test_config["region"], "account_id": test_config["account_id"], + "role_arn": test_config["role"]}, artifacts={}, - model_package_group_arn=TEST_CONFIG["model_package_group"], + model_package_group_arn=test_config["model_package_group"], ) doc = json.loads(ctx["job_config_document_ft_str"]) @@ -198,26 +203,26 @@ def test_lambda_agent_config_fields(self, mtrl_trainer): # Ensure old field name is NOT present assert "LambdaConfig" not in agent_cfg - def test_model_package_config_fields(self, mtrl_trainer): + def test_model_package_config_fields(self, mtrl_trainer, test_config): """Verify ModelPackageConfig uses InputModelPackageArn only (no OutputModelPackageGroupArn for eval).""" evaluator = MultiTurnRLEvaluator( model=mtrl_trainer, - dataset=TEST_CONFIG["dataset"], - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-fields-mpc/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], - agent_config=TEST_CONFIG["agent_arn"], + dataset=test_config["dataset"], + s3_output_path=f'{test_config["s3_output_path"]}integ-fields-mpc/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], + agent_config=test_config["agent_arn"], ) evaluator._resolve_trainer_defaults() evaluator._resolve_agent_arn() ctx = evaluator._build_template_context( - aws_context={"region": TEST_CONFIG["region"], "account_id": _ACCOUNT_ID, - "role_arn": TEST_CONFIG["role"]}, + aws_context={"region": test_config["region"], "account_id": test_config["account_id"], + "role_arn": test_config["role"]}, artifacts={}, - model_package_group_arn=TEST_CONFIG["model_package_group"], + model_package_group_arn=test_config["model_package_group"], ) doc = json.loads(ctx["job_config_document_ft_str"]) @@ -239,41 +244,41 @@ class TestMTRLEvaluatorIntegration: in accounts with the feature flag enabled (e.g., 742774200982). """ - def test_evaluator_construction_with_trainer(self, mtrl_trainer): + def test_evaluator_construction_with_trainer(self, mtrl_trainer, test_config): """Test that MultiTurnRLEvaluator can be constructed from a trainer.""" evaluator = MultiTurnRLEvaluator( model=mtrl_trainer, - dataset=TEST_CONFIG["dataset"], - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-construct/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], - agent_config=TEST_CONFIG["agent_arn"], + dataset=test_config["dataset"], + s3_output_path=f'{test_config["s3_output_path"]}integ-construct/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], + agent_config=test_config["agent_arn"], ) assert evaluator is not None assert evaluator.model is mtrl_trainer - assert evaluator.dataset == TEST_CONFIG["dataset"] - assert evaluator.region == TEST_CONFIG["region"] + assert evaluator.dataset == test_config["dataset"] + assert evaluator.region == test_config["region"] - def test_evaluator_construction_with_base_model(self): + def test_evaluator_construction_with_base_model(self, test_config): """Test that MultiTurnRLEvaluator can be constructed from a base model string.""" evaluator = MultiTurnRLEvaluator( - model=TEST_CONFIG["base_model"], - dataset=TEST_CONFIG["dataset"], - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}integ-base/', - agent_config=TEST_CONFIG["agent_arn"], - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], + model=test_config["base_model"], + dataset=test_config["dataset"], + s3_output_path=f'{test_config["s3_output_path"]}integ-base/', + agent_config=test_config["agent_arn"], + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], ) assert evaluator is not None - assert evaluator.model == TEST_CONFIG["base_model"] + assert evaluator.model == test_config["base_model"] - def test_get_all_mtrl_evaluations(self): + def test_get_all_mtrl_evaluations(self, test_config): """Test listing all MTRL evaluation executions.""" - all_execs = MultiTurnRLEvaluator.get_all(region=TEST_CONFIG["region"]) + all_execs = MultiTurnRLEvaluator.get_all(region=test_config["region"]) if hasattr(all_execs, '__iter__'): all_execs = list(all_execs) diff --git a/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py b/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py index ef8f30a8fe..8314096fca 100644 --- a/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py +++ b/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py @@ -38,19 +38,13 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(name)s - %(message)s") logger = logging.getLogger(__name__) -os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2") -os.environ.setdefault("SAGEMAKER_REGION", "us-west-2") - # Timeout for evaluation pipeline execution (4 hours) EVALUATION_TIMEOUT_SECONDS = 14400 -# Resolve current account ID for account-agnostic paths _REGION = "us-west-2" -_ACCOUNT_ID = boto3.client("sts", region_name=_REGION).get_caller_identity()["Account"] # Lambda configuration LAMBDA_FUNCTION_NAME = "SageMaker-AgentConnector-Lambda-MTRL-integ-test" -LAMBDA_ROLE = f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin" LAMBDA_RUNTIME = "python3.12" LAMBDA_TIMEOUT = 120 LAMBDA_REGION = _REGION @@ -149,26 +143,37 @@ def handler(event, context): ''' # Test configuration for 3P agent evaluation. -TEST_CONFIG = { - "base_model": "openai-reasoning-gpt-oss-20b", - "dataset": os.environ.get( - "MTRL_3P_DATASET", - f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet", - ), - "s3_output_path": os.environ.get( - "MTRL_3P_S3_OUTPUT", - f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/3p-agent-integ/", - ), - "mlflow_resource_arn": os.environ.get( - "MTRL_3P_MLFLOW_ARN", - f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6", - ), - "role": os.environ.get( - "MTRL_3P_ROLE", - f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin", - ), - "region": os.environ.get("MTRL_3P_REGION", _REGION), -} +def _get_3p_test_config(): + """Build test configuration lazily (only when tests actually run).""" + boto_session = boto3.Session(region_name=_REGION) + account_id = boto_session.client("sts").get_caller_identity()["Account"] + return { + "base_model": "openai-reasoning-gpt-oss-20b", + "dataset": os.environ.get( + "MTRL_3P_DATASET", + f"s3://sagemaker-rft-{account_id}/prompts/gsm8k_small/prompts.parquet", + ), + "s3_output_path": os.environ.get( + "MTRL_3P_S3_OUTPUT", + f"s3://sagemaker-{_REGION}-{account_id}/model-evaluation/3p-agent-integ/", + ), + "mlflow_resource_arn": os.environ.get( + "MTRL_3P_MLFLOW_ARN", + f"arn:aws:sagemaker:{_REGION}:{account_id}:mlflow-app/app-TTAUWUNMUHH6", + ), + "role": os.environ.get( + "MTRL_3P_ROLE", + f"arn:aws:iam::{account_id}:role/Admin", + ), + "region": os.environ.get("MTRL_3P_REGION", _REGION), + "account_id": account_id, + } + + +@pytest.fixture(scope="module") +def test_config(): + """Lazily resolve test configuration.""" + return _get_3p_test_config() def _create_lambda_zip(source_code: str) -> bytes: @@ -179,11 +184,13 @@ def _create_lambda_zip(source_code: str) -> bytes: return buf.getvalue() -def _ensure_lambda_exists() -> str: +def _ensure_lambda_exists(account_id) -> str: """Create the Lambda function if it doesn't exist, return its ARN.""" from botocore.exceptions import ClientError - client = boto3.client("lambda", region_name=LAMBDA_REGION) + boto_session = boto3.Session(region_name=LAMBDA_REGION) + client = boto_session.client("lambda") + lambda_role = f"arn:aws:iam::{account_id}:role/Admin" try: resp = client.get_function(FunctionName=LAMBDA_FUNCTION_NAME) @@ -204,7 +211,7 @@ def _ensure_lambda_exists() -> str: resp = client.create_function( FunctionName=LAMBDA_FUNCTION_NAME, Runtime=LAMBDA_RUNTIME, - Role=LAMBDA_ROLE, + Role=lambda_role, Handler="lambda_function.handler", Code={"ZipFile": zip_bytes}, Timeout=LAMBDA_TIMEOUT, @@ -228,15 +235,15 @@ def _ensure_lambda_exists() -> str: @pytest.fixture(scope="module") -def lambda_agent_arn(): +def lambda_agent_arn(test_config): """Ensure the 3P agent Lambda exists and return its ARN.""" - return _ensure_lambda_exists() + return _ensure_lambda_exists(test_config["account_id"]) class TestMTRLEvaluator3PAgentIntegration: """Integration tests for MultiTurnRLEvaluator with Lambda-based 3P agent.""" - def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn): + def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn, test_config): """Test evaluating a base model using a Lambda ARN as agent_config. This is the primary 3P integration pattern: customer provides a @@ -244,13 +251,13 @@ def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn): and the evaluator runs rollouts against it. """ evaluator = MultiTurnRLEvaluator( - model=TEST_CONFIG["base_model"], - dataset=TEST_CONFIG["dataset"], + model=test_config["base_model"], + dataset=test_config["dataset"], agent_config=lambda_agent_arn, - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}lambda-base-model/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], + s3_output_path=f'{test_config["s3_output_path"]}lambda-base-model/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], accept_eula=True, ) @@ -263,7 +270,7 @@ def test_evaluate_base_model_with_lambda_agent(self, lambda_agent_arn): logger.info(f"Status: {execution.status.overall_status}") @pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually") - def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn): + def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn, test_config): """Test evaluating using an CustomAgentLambda object as agent_config. Validates that the evaluator accepts CustomAgentLambda instances (not @@ -272,13 +279,13 @@ def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn): agent = CustomAgentLambda(lambda_arn=lambda_agent_arn) evaluator = MultiTurnRLEvaluator( - model=TEST_CONFIG["base_model"], - dataset=TEST_CONFIG["dataset"], + model=test_config["base_model"], + dataset=test_config["dataset"], agent_config=agent, - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}lambda-object/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], + s3_output_path=f'{test_config["s3_output_path"]}lambda-object/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], accept_eula=True, ) @@ -289,20 +296,20 @@ def test_evaluate_base_model_with_agent_lambda_object(self, lambda_agent_arn): logger.info(f"Started CustomAgentLambda object evaluation: {execution.arn}") @pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually") - def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn): + def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn, test_config): """Test full end-to-end: start evaluation and wait for completion. This test validates the complete lifecycle including wait() using the standard sagemaker-core pipeline execution path. """ evaluator = MultiTurnRLEvaluator( - model=TEST_CONFIG["base_model"], - dataset=TEST_CONFIG["dataset"], + model=test_config["base_model"], + dataset=test_config["dataset"], agent_config=lambda_agent_arn, - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}lambda-e2e/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], + s3_output_path=f'{test_config["s3_output_path"]}lambda-e2e/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], accept_eula=True, ) @@ -319,20 +326,20 @@ def test_evaluate_with_lambda_agent_wait_for_completion(self, lambda_agent_arn): logger.error(f"Failure reason: {execution.status.failure_reason}") @pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually") - def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn): + def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn, test_config): """Test that 3P agent evaluations are discoverable via get_all. Validates that evaluations started with Lambda agents show up in the standard get_all() discovery path (pipeline tagging works). """ evaluator = MultiTurnRLEvaluator( - model=TEST_CONFIG["base_model"], - dataset=TEST_CONFIG["dataset"], + model=test_config["base_model"], + dataset=test_config["dataset"], agent_config=lambda_agent_arn, - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}lambda-discovery/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], + s3_output_path=f'{test_config["s3_output_path"]}lambda-discovery/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], accept_eula=True, ) @@ -345,7 +352,7 @@ def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn): # Verify it's discoverable via get_all found = False - for ex in MultiTurnRLEvaluator.get_all(region=TEST_CONFIG["region"]): + for ex in MultiTurnRLEvaluator.get_all(region=test_config["region"]): if ex.arn == started_arn: found = True break @@ -359,7 +366,7 @@ def test_evaluate_lambda_agent_discoverable_via_get_all(self, lambda_agent_arn): @pytest.mark.skip(reason="Quota limited (1 concurrent eval job) - run manually") - def test_evaluate_with_attached_trainer(self, lambda_agent_arn): + def test_evaluate_with_attached_trainer(self, lambda_agent_arn, test_config): """Test evaluating a fine-tuned model by attaching to an existing training job.""" from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer @@ -369,12 +376,12 @@ def test_evaluate_with_attached_trainer(self, lambda_agent_arn): evaluator = MultiTurnRLEvaluator( model=attached_job, - dataset=TEST_CONFIG["dataset"], + dataset=test_config["dataset"], agent_config=lambda_agent_arn, - s3_output_path=f'{TEST_CONFIG["s3_output_path"]}attached-trainer/', - mlflow_resource_arn=TEST_CONFIG["mlflow_resource_arn"], - role=TEST_CONFIG["role"], - region=TEST_CONFIG["region"], + s3_output_path=f'{test_config["s3_output_path"]}attached-trainer/', + mlflow_resource_arn=test_config["mlflow_resource_arn"], + role=test_config["role"], + region=test_config["region"], accept_eula=True, ) diff --git a/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py b/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py index 92bebe6a38..20b9ae1e1f 100644 --- a/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py @@ -28,10 +28,6 @@ import boto3 -os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2") -os.environ.setdefault("SAGEMAKER_REGION", "us-west-2") -os.environ.setdefault("AWS_REGION", "us-west-2") - from sagemaker.train.multi_turn_rl_trainer import MultiTurnRLTrainer from sagemaker.train.evaluate import MultiTurnRLEvaluator @@ -45,7 +41,8 @@ def _get_account_id(): """Get current AWS account ID via STS.""" - return boto3.client("sts", region_name=_REGION).get_caller_identity()["Account"] + boto_session = boto3.Session(region_name=_REGION) + return boto_session.client("sts").get_caller_identity()["Account"] # ============================================================ # Per-account resource configuration diff --git a/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py b/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py index bdac7a7323..86da4f5109 100644 --- a/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py @@ -20,10 +20,6 @@ import os import time -os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2") -os.environ.setdefault("SAGEMAKER_REGION", "us-west-2") -os.environ.setdefault("AWS_REGION", "us-west-2") - import boto3 import pytest from sagemaker.core.helper.session_helper import Session @@ -31,40 +27,55 @@ from sagemaker.train.agent_rft_job import AgentRFTJob _REGION = "us-west-2" -_ACCOUNT_ID = boto3.client("sts", region_name=_REGION).get_caller_identity()["Account"] +_ACCOUNT_ID = None # Resolved lazily in fixtures + + +def _get_account_id(): + """Resolve account ID lazily.""" + global _ACCOUNT_ID + if _ACCOUNT_ID is None: + boto_session = boto3.Session(region_name=_REGION) + _ACCOUNT_ID = boto_session.client("sts").get_caller_identity()["Account"] + return _ACCOUNT_ID AGENT_RUNTIME_ID = "sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS" -ROLE_ARN = f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin" -MLFLOW_ARN = f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-TTAUWUNMUHH6" -S3_INPUT_PATH = f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet" -S3_OUTPUT_PATH = f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/mtrl-trainer-integ/" -LAMBDA_ARN = f"arn:aws:lambda:{_REGION}:{_ACCOUNT_ID}:function:SageMaker-AgentConnector-Lambda-MTRL-integ-test" BASE_MODEL = "openai-reasoning-gpt-oss-20b" -EXISTING_JOB_NAME="openai-reasoning-gpt-oss-20b-mtrl-20260602005937" +EXISTING_JOB_NAME = "openai-reasoning-gpt-oss-20b-mtrl-20260602005937" @pytest.fixture(scope="module") def sagemaker_session(): - os.environ.setdefault("AWS_DEFAULT_REGION", _REGION) - os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = f"https://mlflow.sagemaker.{_REGION}.app.aws" boto_session = boto3.Session(region_name=_REGION) session = Session(boto_session=boto_session) yield session +@pytest.fixture(scope="module") +def test_resources(): + """Resolve account-specific resource ARNs lazily.""" + account_id = _get_account_id() + return { + "role_arn": f"arn:aws:iam::{account_id}:role/Admin", + "mlflow_arn": f"arn:aws:sagemaker:{_REGION}:{account_id}:mlflow-app/app-TTAUWUNMUHH6", + "s3_input_path": f"s3://sagemaker-rft-{account_id}/prompts/gsm8k_small/prompts.parquet", + "s3_output_path": f"s3://sagemaker-{_REGION}-{account_id}/model-evaluation/mtrl-trainer-integ/", + "lambda_arn": f"arn:aws:lambda:{_REGION}:{account_id}:function:SageMaker-AgentConnector-Lambda-MTRL-integ-test", + } + + @pytest.mark.skip(reason="GPU resource intensive — run manually") class TestMultiTurnRLTrainerBedrockAgent: """Test MTRL training with Bedrock AgentCore runtime.""" - def test_train_and_wait(self, sagemaker_session): + def test_train_and_wait(self, sagemaker_session, test_resources): """Test complete MTRL workflow with Bedrock AgentCore agent.""" trainer = MultiTurnRLTrainer( model=BASE_MODEL, agent_env=AGENT_RUNTIME_ID, - training_dataset=S3_INPUT_PATH, - mlflow_app_arn=MLFLOW_ARN, - s3_output_path=S3_OUTPUT_PATH, - role=ROLE_ARN, + training_dataset=test_resources["s3_input_path"], + mlflow_app_arn=test_resources["mlflow_arn"], + s3_output_path=test_resources["s3_output_path"], + role=test_resources["role_arn"], accept_eula=True, sagemaker_session=sagemaker_session, ) @@ -81,14 +92,14 @@ def test_train_and_wait(self, sagemaker_session): assert job.output_model_package_arn is not None assert job.s3_output_path is not None - def test_train_and_stop(self, sagemaker_session): + def test_train_and_stop(self, sagemaker_session, test_resources): """Test creating and stopping an MTRL job.""" trainer = MultiTurnRLTrainer( model=BASE_MODEL, agent_env=AGENT_RUNTIME_ID, - training_dataset=S3_INPUT_PATH, - mlflow_app_arn=MLFLOW_ARN, - role=ROLE_ARN, + training_dataset=test_resources["s3_input_path"], + mlflow_app_arn=test_resources["mlflow_arn"], + role=test_resources["role_arn"], accept_eula=True, sagemaker_session=sagemaker_session, ) @@ -109,16 +120,16 @@ def test_train_and_stop(self, sagemaker_session): class TestMultiTurnRLTrainerLambdaAgent: """Test MTRL training with Lambda agent.""" - def test_train_with_lambda_arn(self, sagemaker_session): + def test_train_with_lambda_arn(self, sagemaker_session, test_resources): """Test MTRL workflow using an existing Lambda ARN as agent.""" trainer = MultiTurnRLTrainer( model=BASE_MODEL, - agent_env=LAMBDA_ARN, - training_dataset=S3_INPUT_PATH, - mlflow_app_arn=MLFLOW_ARN, - s3_output_path=S3_OUTPUT_PATH, + agent_env=test_resources["lambda_arn"], + training_dataset=test_resources["s3_input_path"], + mlflow_app_arn=test_resources["mlflow_arn"], + s3_output_path=test_resources["s3_output_path"], accept_eula=True, - role=ROLE_ARN, + role=test_resources["role_arn"], sagemaker_session=sagemaker_session, ) trainer.hyperparameters.global_batch_size = 32