From 0f37d919d3e8d2a41e36bf117b7c73229e47bedd Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 13:19:43 -0700 Subject: [PATCH 01/11] fix: bypass SageMakerClient singleton for cross-region model package resolution The SageMakerClient singleton caches the first region it is initialized with and ignores subsequent region parameters. This causes Nova integ tests (which run in us-east-1) to fail when the singleton was already created with us-west-2 by an earlier test in the same process. Errors observed: - ModelPackageGroup arn:aws:sagemaker:us-west-2:784379639078:model-package-group/sdk-test-finetuned-models does not exist - DescribeModelPackage: ARN should be scoped to correct region: us-west-2 Fix: use session.boto_session.client("sagemaker") directly instead of ModelPackageGroup.get() / ModelPackage.get() in the three call sites that resolve model package resources. This respects the session's actual region without depending on the singleton's cached state. --- .../train/common_utils/finetune_utils.py | 30 ++++++++++++------- .../train/common_utils/model_resolution.py | 19 +++++++----- .../train/evaluate/base_evaluator.py | 24 +++++++-------- 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index dd5bfa87b1..56558c4b31 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -338,13 +338,17 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ # It's already an ARN return model_package_group_name_or_arn else: - # It's a name, resolve to ARN - model_package_group = ModelPackageGroup.get( - model_package_group_name=model_package_group_name_or_arn, - session=sagemaker_session.boto_session, - region=sagemaker_session.boto_session.region_name + # It's a name, resolve to ARN using the session's client directly + # to respect the session's region (avoids SageMakerClient singleton + # which may cache a different region). + sm_client = sagemaker_session.boto_session.client( + "sagemaker", + region_name=sagemaker_session.boto_session.region_name ) - return model_package_group.model_package_group_arn + response = sm_client.describe_model_package_group( + ModelPackageGroupName=model_package_group_name_or_arn + ) + return response["ModelPackageGroupArn"] else: # It's a ModelPackageGroup object return model_package_group_name_or_arn.model_package_group_arn @@ -581,12 +585,16 @@ def _resolve_model_and_name(model, sagemaker_session=None): if isinstance(model, str): # Check if it's a model package ARN if model.startswith("arn:aws:sagemaker:") and ":model-package/" in model: - # Get ModelPackage object from ARN - model_package = ModelPackage.get( - model_package_name=model, - session=sagemaker_session.boto_session if sagemaker_session else None, - region=sagemaker_session.boto_session.region_name if sagemaker_session else None + # Get ModelPackage object from ARN using the session's boto client directly + # to avoid SageMakerClient singleton which may cache a different region. + from sagemaker.core.utils.code_injection.codec import transform + sm_client = sagemaker_session.boto_session.client( + "sagemaker", + region_name=sagemaker_session.boto_session.region_name ) + response = sm_client.describe_model_package(ModelPackageName=model) + transformed_response = transform(response, "DescribeModelPackageOutput") + model_package = ModelPackage(**transformed_response) model_name = _resolve_model_name(model_package) # Validate region availability if region_name: diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index 8e2bee6971..596526338b 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -290,20 +290,23 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo: # Validate ARN format self._validate_model_package_arn(model_package_arn) - # Use sagemaker.core ModelPackage.get() to retrieve model package information from sagemaker.core.resources import ModelPackage + from sagemaker.core.utils.code_injection.codec import transform import logging logger = logging.getLogger(__name__) - # Get the model package using sagemaker.core - model_package = ModelPackage.get( - model_package_name=model_package_arn, - session=session.boto_session, - region=session.boto_session.region_name - ) + # Use the session's boto client directly to avoid SageMakerClient singleton + # which may cache a different region. + region = session.boto_session.region_name + sm_client = session.boto_session.client("sagemaker", region_name=region) + response = sm_client.describe_model_package(ModelPackageName=model_package_arn) + + logger.info(f"Retrieved ModelPackage in region: {region}") - logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}") + # Deserialize and create ModelPackage object + transformed_response = transform(response, "DescribeModelPackageOutput") + model_package = ModelPackage(**transformed_response) # Now use the existing _resolve_model_package_object method to extract base model info return self._resolve_model_package_object(model_package) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index 7d7fac006d..5b79c08e68 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -222,21 +222,17 @@ def _validate_and_resolve_model_package_group(cls, v, values): if hasattr(session, 'boto_region_name') else 'us-west-2') - # Fetch the object - obj = ModelPackageGroup.get( - model_package_group_name=v, - region=region - ) - - # Extract ARN - if hasattr(obj, 'model_package_group_arn'): - arn = obj.model_package_group_arn - _logger.info(f"Resolved model package group name '{v}' to ARN: {arn}") - return arn + # Resolve directly via boto3 client to avoid SageMakerClient singleton + # which may cache a different region. + import boto3 as _boto3 + if session and hasattr(session, 'boto_session'): + sm_client = session.boto_session.client("sagemaker", region_name=region) else: - raise ValueError( - f"ModelPackageGroup object for name '{v}' does not have model_package_group_arn attribute" - ) + sm_client = _boto3.client("sagemaker", region_name=region) + response = sm_client.describe_model_package_group(ModelPackageGroupName=v) + arn = response["ModelPackageGroupArn"] + _logger.info(f"Resolved model package group name '{v}' to ARN: {arn}") + return arn except Exception as e: raise ValueError( From 37d081adf890036a07a86a72c12b307d58a817ae Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 14:43:36 -0700 Subject: [PATCH 02/11] test: update unit tests --- .../train/common_utils/test_finetune_utils.py | 46 ++-- .../common_utils/test_model_resolution.py | 197 ++++++++++-------- .../train/evaluate/test_base_evaluator.py | 27 +-- 3 files changed, 157 insertions(+), 113 deletions(-) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 701f8ccd51..a3c9b4a8a1 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -168,17 +168,21 @@ def test__validate_model_package_group_requirement_without_group_name(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_model_package_group_requirement("string-model", None) - @patch('sagemaker.core.resources.ModelPackageGroup.get') - def test__resolve_model_package_group_arn_with_name(self, mock_get): + def test__resolve_model_package_group_arn_with_name(self): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" - mock_group = Mock() - mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" - mock_get.return_value = mock_group + mock_sm_client = Mock() + mock_sm_client.describe_model_package_group.return_value = { + "ModelPackageGroupArn": "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + } + mock_session.boto_session.client.return_value = mock_sm_client result = _resolve_model_package_group_arn("test-group", mock_session) - assert result == mock_group.model_package_group_arn + assert result == "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + mock_sm_client.describe_model_package_group.assert_called_once_with( + ModelPackageGroupName="test-group" + ) def test__resolve_model_package_group_arn_with_arn(self): mock_session = Mock() @@ -362,23 +366,35 @@ def test__validate_and_resolve_model_package_group_missing_both(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_and_resolve_model_package_group("string-model", None) - @patch('sagemaker.core.resources.ModelPackage.get') - def test__resolve_model_and_name_with_model_package_arn(self, mock_get): + def test__resolve_model_and_name_with_model_package_arn(self): mock_session = Mock() mock_session.boto_region_name = "us-east-1" # Set valid region - mock_model_package = Mock(spec=ModelPackage) + mock_session.boto_session.region_name = "us-east-1" + mock_sm_client = Mock() + mock_session.boto_session.client.return_value = mock_sm_client + + # Mock describe_model_package response + mock_sm_client.describe_model_package.return_value = { + "ModelPackageArn": "arn:aws:sagemaker:us-east-1:123456789012:model-package/test", + } + + # Mock transform and ModelPackage constructor mock_container = Mock() mock_base_model = Mock() mock_base_model.hub_content_name = "test-model" mock_container.base_model = mock_base_model - mock_model_package.inference_specification = Mock() - mock_model_package.inference_specification.containers = [mock_container] - mock_get.return_value = mock_model_package + mock_inference_spec = Mock() + mock_inference_spec.containers = [mock_container] - model, name = _resolve_model_and_name("arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session) + mock_model_package = Mock(spec=ModelPackage) + mock_model_package.inference_specification = mock_inference_spec - assert model == mock_model_package - assert name == "test-model" + with patch('sagemaker.core.utils.code_injection.codec.transform', return_value={}): + with patch('sagemaker.train.common_utils.finetune_utils.ModelPackage', return_value=mock_model_package): + model, name = _resolve_model_and_name("arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session) + + assert model == mock_model_package + assert name == "test-model" def test__resolve_model_and_name_with_string(self): model, name = _resolve_model_and_name("test-model") diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py index 8a57dc2d28..ddaeb56331 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py @@ -307,11 +307,10 @@ def test_resolve_package_fallback_name(self): class TestResolveModelPackageArn: """Tests for _resolve_model_package_arn method.""" - @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_success(self, mock_validate, mock_get_session, mock_model_package_class): - """Test successful ARN resolution using ModelPackage.get().""" + def test_resolve_arn_success(self, mock_validate, mock_get_session): + """Test successful ARN resolution.""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" # Mock session @@ -319,40 +318,52 @@ def test_resolve_arn_success(self, mock_validate, mock_get_session, mock_model_p mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock ModelPackage.get() return value - mock_package = MagicMock() - mock_package.model_package_arn = arn - - # Mock inference specification with hub_content_arn - mock_container = MagicMock() - mock_base_model = MagicMock() - mock_base_model.hub_content_name = 'base-model' - mock_base_model.hub_content_version = '1.0' - mock_base_model.hub_content_arn = 'arn:aws:sagemaker:us-west-2:aws:hub-content/base' - mock_container.base_model = mock_base_model - - mock_package.inference_specification = MagicMock() - mock_package.inference_specification.containers = [mock_container] - - mock_model_package_class.get.return_value = mock_package - - resolver = _ModelResolver() - result = resolver._resolve_model_package_arn(arn) - - assert result.base_model_name == "base-model" - assert result.hub_content_name == "base-model" - assert result.source_model_package_arn == arn - assert result.model_type == _ModelType.FINE_TUNED - mock_model_package_class.get.assert_called_once_with( - model_package_name=arn, - session=mock_session.boto_session, - region='us-west-2' - ) + # Mock boto client + mock_sm_client = MagicMock() + mock_session.boto_session.client.return_value = mock_sm_client + mock_sm_client.describe_model_package.return_value = { + "ModelPackageArn": arn, + "InferenceSpecification": { + "Containers": [{ + "BaseModel": { + "HubContentName": "base-model", + "HubContentVersion": "1.0", + "HubContentArn": "arn:aws:sagemaker:us-west-2:aws:hub-content/base" + } + }] + } + } + + # Mock transform to return a proper ModelPackage-like object + with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: + mock_transformed = {} + mock_transform.return_value = mock_transformed + + with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: + mock_package = MagicMock() + mock_package.model_package_arn = arn + mock_container = MagicMock() + mock_base_model = MagicMock() + mock_base_model.hub_content_name = 'base-model' + mock_base_model.hub_content_version = '1.0' + mock_base_model.hub_content_arn = 'arn:aws:sagemaker:us-west-2:aws:hub-content/base' + mock_container.base_model = mock_base_model + mock_package.inference_specification = MagicMock() + mock_package.inference_specification.containers = [mock_container] + mock_mp_class.return_value = mock_package + + resolver = _ModelResolver() + result = resolver._resolve_model_package_arn(arn) + + assert result.base_model_name == "base-model" + assert result.hub_content_name == "base-model" + assert result.source_model_package_arn == arn + assert result.model_type == _ModelType.FINE_TUNED + mock_sm_client.describe_model_package.assert_called_once_with(ModelPackageName=arn) - @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_session, mock_model_package_class): + def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_session): """Test ARN resolution when HubContentArn needs to be constructed.""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" @@ -361,35 +372,38 @@ def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_ses mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock ModelPackage without hub_content_arn (needs to be constructed) - mock_package = MagicMock() - mock_package.model_package_arn = arn - - mock_container = MagicMock() - mock_base_model = MagicMock() - mock_base_model.hub_content_name = 'base-model' - mock_base_model.hub_content_version = '1.0' - mock_base_model.hub_content_arn = None # Not provided, needs construction - mock_container.base_model = mock_base_model - - mock_package.inference_specification = MagicMock() - mock_package.inference_specification.containers = [mock_container] + # Mock boto client + mock_sm_client = MagicMock() + mock_session.boto_session.client.return_value = mock_sm_client + mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} - mock_model_package_class.get.return_value = mock_package - - resolver = _ModelResolver() - result = resolver._resolve_model_package_arn(arn) - - # Should construct ARN from region and hub content name/version - expected_arn = "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/base-model/1.0" - assert result.base_model_arn == expected_arn - assert result.base_model_name == "base-model" - assert result.hub_content_name == "base-model" + with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: + mock_transform.return_value = {} + + with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: + mock_package = MagicMock() + mock_package.model_package_arn = arn + mock_container = MagicMock() + mock_base_model = MagicMock() + mock_base_model.hub_content_name = 'base-model' + mock_base_model.hub_content_version = '1.0' + mock_base_model.hub_content_arn = None # Not provided, needs construction + mock_container.base_model = mock_base_model + mock_package.inference_specification = MagicMock() + mock_package.inference_specification.containers = [mock_container] + mock_mp_class.return_value = mock_package + + resolver = _ModelResolver() + result = resolver._resolve_model_package_arn(arn) + + expected_arn = "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/base-model/1.0" + assert result.base_model_arn == expected_arn + assert result.base_model_name == "base-model" + assert result.hub_content_name == "base-model" - @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session, mock_model_package_class): + def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session): """Test error when InferenceSpecification is missing.""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" @@ -398,22 +412,28 @@ def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session, mo mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock ModelPackage without inference_specification - mock_package = MagicMock() - mock_package.model_package_arn = arn - mock_package.inference_specification = None - - mock_model_package_class.get.return_value = mock_package - - resolver = _ModelResolver() + # Mock boto client + mock_sm_client = MagicMock() + mock_session.boto_session.client.return_value = mock_sm_client + mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} - with pytest.raises(ValueError, match="NotSupported.*does not have an inference_specification"): - resolver._resolve_model_package_arn(arn) + with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: + mock_transform.return_value = {} + + with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: + mock_package = MagicMock() + mock_package.model_package_arn = arn + mock_package.inference_specification = None + mock_mp_class.return_value = mock_package + + resolver = _ModelResolver() + + with pytest.raises(ValueError, match="NotSupported.*does not have an inference_specification"): + resolver._resolve_model_package_arn(arn) - @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session, mock_model_package_class): + def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session): """Test error when BaseModel is missing.""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" @@ -422,22 +442,27 @@ def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session, mock_m mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock ModelPackage with container but no base_model - mock_package = MagicMock() - mock_package.model_package_arn = arn - - mock_container = MagicMock() - mock_container.base_model = None - - mock_package.inference_specification = MagicMock() - mock_package.inference_specification.containers = [mock_container] - - mock_model_package_class.get.return_value = mock_package - - resolver = _ModelResolver() + # Mock boto client + mock_sm_client = MagicMock() + mock_session.boto_session.client.return_value = mock_sm_client + mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} - with pytest.raises(ValueError, match="NotSupported.*does not have base_model metadata"): - resolver._resolve_model_package_arn(arn) + with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: + mock_transform.return_value = {} + + with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: + mock_package = MagicMock() + mock_package.model_package_arn = arn + mock_container = MagicMock() + mock_container.base_model = None + mock_package.inference_specification = MagicMock() + mock_package.inference_specification.containers = [mock_container] + mock_mp_class.return_value = mock_package + + resolver = _ModelResolver() + + with pytest.raises(ValueError, match="NotSupported.*does not have base_model metadata"): + resolver._resolve_model_package_arn(arn) class TestValidateModelPackageArn: diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py index 1f690be8a9..13060ebe0a 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py @@ -311,15 +311,16 @@ def test_model_package_group_arn_valid(self, mock_resolve, mock_session, mock_mo assert evaluator.model_package_group == DEFAULT_MODEL_PACKAGE_GROUP_ARN @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - @patch("sagemaker.core.resources.ModelPackageGroup.get") - def test_model_package_group_name_resolution(self, mock_mpg_get, mock_resolve, mock_session, mock_model_info): + def test_model_package_group_name_resolution(self, mock_resolve, mock_session, mock_model_info): """Test model package group name resolution to ARN.""" mock_resolve.return_value = mock_model_info - # Mock ModelPackageGroup.get to return an object with ARN - mock_mpg = MagicMock() - mock_mpg.model_package_group_arn = DEFAULT_MODEL_PACKAGE_GROUP_ARN - mock_mpg_get.return_value = mock_mpg + # Mock the boto client's describe_model_package_group call + mock_sm_client = MagicMock() + mock_sm_client.describe_model_package_group.return_value = { + "ModelPackageGroupArn": DEFAULT_MODEL_PACKAGE_GROUP_ARN + } + mock_session.boto_session.client.return_value = mock_sm_client evaluator = BaseEvaluator( model=DEFAULT_MODEL, @@ -331,9 +332,8 @@ def test_model_package_group_name_resolution(self, mock_mpg_get, mock_resolve, m ) assert evaluator.model_package_group == DEFAULT_MODEL_PACKAGE_GROUP_ARN - mock_mpg_get.assert_called_once_with( - model_package_group_name="my-package", - region=DEFAULT_REGION, + mock_sm_client.describe_model_package_group.assert_called_once_with( + ModelPackageGroupName="my-package" ) @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") @@ -355,11 +355,14 @@ def test_model_package_group_object_resolution(self, mock_resolve, mock_session, assert evaluator.model_package_group == DEFAULT_MODEL_PACKAGE_GROUP_ARN @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - @patch("sagemaker.core.resources.ModelPackageGroup.get") - def test_model_package_group_name_not_found(self, mock_mpg_get, mock_resolve, mock_session, mock_model_info): + def test_model_package_group_name_not_found(self, mock_resolve, mock_session, mock_model_info): """Test model package group name that doesn't exist.""" mock_resolve.return_value = mock_model_info - mock_mpg_get.side_effect = Exception("Model package group not found") + + # Mock the boto client to raise an exception + mock_sm_client = MagicMock() + mock_sm_client.describe_model_package_group.side_effect = Exception("Model package group not found") + mock_session.boto_session.client.return_value = mock_sm_client with pytest.raises(ValidationError, match="Failed to resolve model package group name"): BaseEvaluator( From 059ac4111475a07fdaa6dc87825e6aa74fed6b24 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 14:50:35 -0700 Subject: [PATCH 03/11] fix: handle missing pipeline version context in lineage update _update_pipeline_lineage assumed the version context always exists. When it's been deleted or never created (e.g. prior run failure), DescribeContext throws ResourceNotFound. Now catches the error and recreates the version context with proper associations. --- .../lineage/_feature_processor_lineage.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py index d706b3b441..cf86d89118 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py @@ -359,9 +359,39 @@ def _update_pipeline_lineage( # If pipeline lineage exists then determine whether to create a new version. pipeline_context: Context = self._get_pipeline_context() - current_pipeline_version_context: Context = self._get_pipeline_version_context( - last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] - ) + try: + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + except ClientError as e: + if e.response[ERROR][CODE] == RESOURCE_NOT_FOUND: + # Pipeline version context does not exist (possibly deleted or never created). + # Create a new pipeline version context and its associations. + logger.info( + "Pipeline version context not found. Creating new pipeline version lineage." + ) + pipeline_context.properties["LastUpdateTime"] = self.pipeline[ + LAST_MODIFIED_TIME + ].strftime("%s") + PipelineLineageEntityHandler.update_pipeline_context( + pipeline_context=pipeline_context + ) + new_pipeline_version_context: Context = self._create_pipeline_version_lineage() + self._add_associations_for_pipeline( + pipeline_context_arn=pipeline_context.context_arn, + pipeline_versions_context_arn=new_pipeline_version_context.context_arn, + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + pipeline_context_arn=pipeline_context.context_arn, + pipeline_version_context_arn=new_pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + return + raise e upstream_feature_group_associations: Iterator[AssociationSummary] = ( LineageAssociationHandler.list_upstream_associations( # pylint: disable=no-member From fbfc0c79cc4d6251b6b305bc2e98af27c1c07bb8 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 17:35:16 -0700 Subject: [PATCH 04/11] fix(test): add mlflow_resource_arn fixture that auto-discovers or creates app Replace hard-coded MLflow app ARN with a conftest fixture that finds an existing ready app or creates a temporary one (cleaned up after tests). Prevents failures when the hard-coded app is deleted or quota is full. X-AI-Prompt: add self-healing mlflow fixture for llm_as_judge integ tests X-AI-Tool: kiro-cli --- sagemaker-train/tests/integ/train/conftest.py | 89 +++++++++++++++++++ .../train/test_llm_as_judge_base_model_fix.py | 6 +- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/sagemaker-train/tests/integ/train/conftest.py b/sagemaker-train/tests/integ/train/conftest.py index 1857a6262d..712afee3bc 100644 --- a/sagemaker-train/tests/integ/train/conftest.py +++ b/sagemaker-train/tests/integ/train/conftest.py @@ -48,3 +48,92 @@ def sagemaker_session_us_east_1(): """Create a SageMaker session in us-east-1 for Nova model tests.""" boto_session = boto3.Session(region_name=NOVA_REGION) return Session(boto_session=boto_session) + + +import time +import logging + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def mlflow_resource_arn(): + """Discover or create an MLflow app for integ tests, clean up if created. + + Looks for an existing MLflow app in Created/Updated state. If none exists, + creates one and deletes it after the test module finishes. + """ + region = os.environ.get("AWS_DEFAULT_REGION", DEFAULT_REGION) + sm_client = boto3.client("sagemaker", region_name=region) + created_arn = None + + # Try to find an existing ready app + try: + paginator = sm_client.get_paginator("list_mlflow_apps") + for page in paginator.paginate(): + for app in page.get("MlflowApps", []): + if app.get("Status") in ("Created", "Updated"): + logger.info(f"Using existing MLflow app: {app['Arn']}") + yield app["Arn"] + return + except Exception as e: + logger.warning(f"Failed to list MLflow apps: {e}") + + # No ready app found — create one + logger.info("No ready MLflow app found. Creating one for integ tests...") + sts_client = boto3.client("sts", region_name=region) + account_id = sts_client.get_caller_identity()["Account"] + app_name = f"integ-test-mlflow-{int(time.time())}" + artifact_store_uri = f"s3://sagemaker-{region}-{account_id}/mlflow-artifacts" + + # Ensure bucket/prefix exists + s3_client = boto3.client("s3", region_name=region) + bucket_name = f"sagemaker-{region}-{account_id}" + try: + s3_client.head_bucket(Bucket=bucket_name) + except Exception: + if region == "us-east-1": + s3_client.create_bucket(Bucket=bucket_name) + else: + s3_client.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": region}, + ) + try: + s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/") + except Exception: + pass + + # Get execution role + from sagemaker.train.defaults import TrainDefaults + boto_session = boto3.Session(region_name=region) + sagemaker_session = Session(boto_session=boto_session) + role_arn = TrainDefaults.get_role(role=None, sagemaker_session=sagemaker_session) + + resp = sm_client.create_mlflow_app( + Name=app_name, + ArtifactStoreUri=artifact_store_uri, + RoleArn=role_arn, + AccountDefaultStatus="DISABLED", + ) + created_arn = resp["Arn"] + logger.info(f"Created MLflow app: {created_arn}") + + # Wait for it to become ready + for _ in range(60): + desc = sm_client.describe_mlflow_app(Arn=created_arn) + status = desc.get("Status") + if status in ("Created", "Updated"): + break + if status in ("Failed", "CreateFailed", "DeleteFailed"): + pytest.skip(f"MLflow app creation failed: {desc.get('FailureReason')}") + time.sleep(10) + + yield created_arn + + # Cleanup + logger.info(f"Cleaning up MLflow app: {created_arn}") + try: + sm_client.delete_mlflow_app(Arn=created_arn) + except Exception as e: + logger.warning(f"Failed to delete MLflow app {created_arn}: {e}") diff --git a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py index 1da31f71c6..0d3ffd0fde 100644 --- a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py +++ b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py @@ -81,7 +81,7 @@ class TestLLMAsJudgeBaseModelFix: """Integration test for base model fix in LLMAsJudgeEvaluator""" - def test_base_model_evaluation_uses_correct_weights(self): + def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn): """ Test that base model evaluation uses original base model weights. @@ -115,6 +115,7 @@ def test_base_model_evaluation_uses_correct_weights(self): custom_metrics=TEST_CONFIG["custom_metrics_json"], s3_output_path=TEST_CONFIG["s3_output_path"], evaluate_base_model=TEST_CONFIG["evaluate_base_model"], + mlflow_resource_arn=mlflow_resource_arn, ) # Verify evaluator configuration @@ -250,7 +251,7 @@ def test_base_model_evaluation_uses_correct_weights(self): # Re-raise to fail the test raise - def test_base_model_false_still_works(self): + def test_base_model_false_still_works(self, mlflow_resource_arn): """ Test that evaluate_base_model=False still works correctly (backward compatibility). @@ -271,6 +272,7 @@ def test_base_model_false_still_works(self): builtin_metrics=TEST_CONFIG["builtin_metrics"], s3_output_path=TEST_CONFIG["s3_output_path"], evaluate_base_model=False, # Only evaluate custom model + mlflow_resource_arn=mlflow_resource_arn, ) # Verify evaluator configuration From c3644a778714e68a924f9beb517ea81b2150ed09 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 17:56:19 -0700 Subject: [PATCH 05/11] fix(test): use correct response key "Summaries" for list_mlflow_apps API --- sagemaker-train/tests/integ/train/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker-train/tests/integ/train/conftest.py b/sagemaker-train/tests/integ/train/conftest.py index 712afee3bc..03ad480702 100644 --- a/sagemaker-train/tests/integ/train/conftest.py +++ b/sagemaker-train/tests/integ/train/conftest.py @@ -71,7 +71,7 @@ def mlflow_resource_arn(): try: paginator = sm_client.get_paginator("list_mlflow_apps") for page in paginator.paginate(): - for app in page.get("MlflowApps", []): + for app in page.get("Summaries", []): if app.get("Status") in ("Created", "Updated"): logger.info(f"Using existing MLflow app: {app['Arn']}") yield app["Arn"] From bd2f4062917764d4907a43f055c7d95b6abf81d3 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 21:56:19 -0700 Subject: [PATCH 06/11] mark two slow tests as not serial --- .../tests/integ/train/test_llm_as_judge_base_model_fix.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py index 0d3ffd0fde..84c2d5ca02 100644 --- a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py +++ b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py @@ -77,7 +77,6 @@ } -@pytest.mark.serial class TestLLMAsJudgeBaseModelFix: """Integration test for base model fix in LLMAsJudgeEvaluator""" From 3f676da5fc089c75bf777870050f792c3e45f26d Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 23:03:22 -0700 Subject: [PATCH 07/11] fix: use correct response key "Summaries" in _resolve_mlflow_resource_arn --- .../src/sagemaker/train/common_utils/finetune_utils.py | 2 +- .../tests/unit/train/common_utils/test_finetune_utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 56558c4b31..73af317208 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -125,7 +125,7 @@ def _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn: Optiona mlflow_apps_list = [] paginator = sm_client.get_paginator("list_mlflow_apps") for page in paginator.paginate(): - mlflow_apps_list.extend(page.get("MlflowApps", [])) + mlflow_apps_list.extend(page.get("Summaries", [])) logger.info("Found %d MLflow apps: %s", len(mlflow_apps_list), [(a.get("Name", "?"), a.get("Status", "?"), a.get("MlflowVersion", "?")) for a in mlflow_apps_list]) diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index a3c9b4a8a1..6dbdbb6db8 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -83,7 +83,7 @@ def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_client, moc mock_get_domain.return_value = "d-123456789" mock_sm_client = Mock() mock_paginator = Mock() - mock_paginator.paginate.return_value = [{"MlflowApps": []}] + mock_paginator.paginate.return_value = [{"Summaries": []}] mock_sm_client.get_paginator.return_value = mock_paginator mock_get_client.return_value = mock_sm_client expected_arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/new-app" @@ -649,7 +649,7 @@ def test_upgrades_when_below_min_version(self, mock_get_client, mock_upgrade, mo } mock_sm_client = Mock() mock_paginator = Mock() - mock_paginator.paginate.return_value = [{"MlflowApps": [old_app]}] + mock_paginator.paginate.return_value = [{"Summaries": [old_app]}] mock_sm_client.get_paginator.return_value = mock_paginator mock_get_client.return_value = mock_sm_client @@ -675,7 +675,7 @@ def test_no_upgrade_when_meets_version(self, mock_get_client, mock_domain): } mock_sm_client = Mock() mock_paginator = Mock() - mock_paginator.paginate.return_value = [{"MlflowApps": [app]}] + mock_paginator.paginate.return_value = [{"Summaries": [app]}] mock_sm_client.get_paginator.return_value = mock_paginator mock_get_client.return_value = mock_sm_client From c39b7dc50a3a462abe20e9b400fee7de4a826d8a Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Thu, 4 Jun 2026 00:37:36 -0700 Subject: [PATCH 08/11] replace not-existing mlflow app --- sagemaker-train/tests/integ/train/test_mtrl_evaluator.py | 2 +- .../tests/integ/train/test_mtrl_evaluator_3p_agent.py | 2 +- .../tests/integ/train/test_mtrl_trainer_integration.py | 2 +- .../tests/integ/train/test_multi_turn_rl_trainer_integration.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py b/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py index 512d66fa33..a746fb653a 100644 --- a/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py +++ b/sagemaker-train/tests/integ/train/test_mtrl_evaluator.py @@ -47,7 +47,7 @@ "agent_arn": f"arn:aws:bedrock-agentcore:{_REGION}:{_ACCOUNT_ID}:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS", "dataset": f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet", "s3_output_path": f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/output-artifacts/", - "mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU", + "mlflow_resource_arn": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-O4ZGQYBYHMRH", "model_package_group": f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg", "role": f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin", "region": _REGION, diff --git a/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py b/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py index 5e9a70964e..81dd24f835 100644 --- a/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py +++ b/sagemaker-train/tests/integ/train/test_mtrl_evaluator_3p_agent.py @@ -161,7 +161,7 @@ def handler(event, context): ), "mlflow_resource_arn": os.environ.get( "MTRL_3P_MLFLOW_ARN", - f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU", + f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-O4ZGQYBYHMRH", ), "role": os.environ.get( "MTRL_3P_ROLE", diff --git a/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py b/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py index 5a09b22894..28fb4e8c12 100644 --- a/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_mtrl_trainer_integration.py @@ -60,7 +60,7 @@ def _get_account_id(): "agent_core_arn": "arn:aws:bedrock-agentcore:us-west-2:729646638167:runtime/sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS", "dataset": "s3://sagemaker-rft-729646638167/prompts/gsm8k_small/prompts.parquet", "s3_output_path": "s3://sagemaker-us-west-2-729646638167/mtrl-integ/eval-output/", - "mlflow_resource_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-ZG6FYITNGMMU", + "mlflow_resource_arn": "arn:aws:sagemaker:us-west-2:729646638167:mlflow-app/app-O4ZGQYBYHMRH", "model_package_group": "arn:aws:sagemaker:us-west-2:729646638167:model-package-group/openai-reasoning-gpt-oss-20b-mtrl-mpg", "role": "arn:aws:iam::729646638167:role/Admin", }, diff --git a/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py b/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py index b7fa08d669..c91f146a98 100644 --- a/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py +++ b/sagemaker-train/tests/integ/train/test_multi_turn_rl_trainer_integration.py @@ -35,7 +35,7 @@ AGENT_RUNTIME_ID = "sagemaker_rft_prod_gsm8k_streaming-Yk6O377mUS" ROLE_ARN = f"arn:aws:iam::{_ACCOUNT_ID}:role/Admin" -MLFLOW_ARN = f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-ZG6FYITNGMMU" +MLFLOW_ARN = f"arn:aws:sagemaker:{_REGION}:{_ACCOUNT_ID}:mlflow-app/app-O4ZGQYBYHMRH" S3_INPUT_PATH = f"s3://sagemaker-rft-{_ACCOUNT_ID}/prompts/gsm8k_small/prompts.parquet" S3_OUTPUT_PATH = f"s3://sagemaker-{_REGION}-{_ACCOUNT_ID}/model-evaluation/mtrl-trainer-integ/" LAMBDA_ARN = f"arn:aws:lambda:{_REGION}:{_ACCOUNT_ID}:function:SageMaker-AgentConnector-Lambda-MTRL-integ-test" From cdd050006ae956cda90ae1b66311d17fbdb50029 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Thu, 4 Jun 2026 10:19:17 -0700 Subject: [PATCH 09/11] refactor: use session.sagemaker_client instead of boto_session.client Per SDK coding standards, avoid calling boto3 directly. Use the session's sagemaker_client attribute which already has the correct region bound at session creation time. --- .../train/common_utils/finetune_utils.py | 19 +++++++------------ .../train/common_utils/model_resolution.py | 9 ++++----- .../train/evaluate/base_evaluator.py | 11 ++++++----- .../train/common_utils/test_finetune_utils.py | 4 ++-- .../common_utils/test_model_resolution.py | 8 ++++---- .../train/evaluate/test_base_evaluator.py | 4 ++-- 6 files changed, 25 insertions(+), 30 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 73af317208..c1b6c8f422 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -338,13 +338,10 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ # It's already an ARN return model_package_group_name_or_arn else: - # It's a name, resolve to ARN using the session's client directly - # to respect the session's region (avoids SageMakerClient singleton + # It's a name, resolve to ARN using the session's sagemaker_client + # which respects the session's region (avoids SageMakerClient singleton # which may cache a different region). - sm_client = sagemaker_session.boto_session.client( - "sagemaker", - region_name=sagemaker_session.boto_session.region_name - ) + sm_client = sagemaker_session.sagemaker_client response = sm_client.describe_model_package_group( ModelPackageGroupName=model_package_group_name_or_arn ) @@ -585,13 +582,11 @@ def _resolve_model_and_name(model, sagemaker_session=None): if isinstance(model, str): # Check if it's a model package ARN if model.startswith("arn:aws:sagemaker:") and ":model-package/" in model: - # Get ModelPackage object from ARN using the session's boto client directly - # to avoid SageMakerClient singleton which may cache a different region. + # Get ModelPackage object from ARN using the session's sagemaker_client + # which respects the session's region (avoids SageMakerClient singleton + # which may cache a different region). from sagemaker.core.utils.code_injection.codec import transform - sm_client = sagemaker_session.boto_session.client( - "sagemaker", - region_name=sagemaker_session.boto_session.region_name - ) + sm_client = sagemaker_session.sagemaker_client response = sm_client.describe_model_package(ModelPackageName=model) transformed_response = transform(response, "DescribeModelPackageOutput") model_package = ModelPackage(**transformed_response) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index 596526338b..70e5cd44b6 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -296,13 +296,12 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo: import logging logger = logging.getLogger(__name__) - # Use the session's boto client directly to avoid SageMakerClient singleton - # which may cache a different region. - region = session.boto_session.region_name - sm_client = session.boto_session.client("sagemaker", region_name=region) + # Use the session's sagemaker_client which respects the session's region + # (avoids SageMakerClient singleton which may cache a different region). + sm_client = session.sagemaker_client response = sm_client.describe_model_package(ModelPackageName=model_package_arn) - logger.info(f"Retrieved ModelPackage in region: {region}") + logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}") # Deserialize and create ModelPackage object transformed_response = transform(response, "DescribeModelPackageOutput") diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index 5b79c08e68..595fd0f936 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -222,12 +222,13 @@ def _validate_and_resolve_model_package_group(cls, v, values): if hasattr(session, 'boto_region_name') else 'us-west-2') - # Resolve directly via boto3 client to avoid SageMakerClient singleton - # which may cache a different region. - import boto3 as _boto3 - if session and hasattr(session, 'boto_session'): - sm_client = session.boto_session.client("sagemaker", region_name=region) + # Resolve using the session's sagemaker_client which respects + # the session's region (avoids SageMakerClient singleton + # which may cache a different region). + if session and hasattr(session, 'sagemaker_client'): + sm_client = session.sagemaker_client else: + import boto3 as _boto3 sm_client = _boto3.client("sagemaker", region_name=region) response = sm_client.describe_model_package_group(ModelPackageGroupName=v) arn = response["ModelPackageGroupArn"] diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 6dbdbb6db8..c311b2011f 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -175,7 +175,7 @@ def test__resolve_model_package_group_arn_with_name(self): mock_sm_client.describe_model_package_group.return_value = { "ModelPackageGroupArn": "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" } - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client result = _resolve_model_package_group_arn("test-group", mock_session) @@ -371,7 +371,7 @@ def test__resolve_model_and_name_with_model_package_arn(self): mock_session.boto_region_name = "us-east-1" # Set valid region mock_session.boto_session.region_name = "us-east-1" mock_sm_client = Mock() - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client # Mock describe_model_package response mock_sm_client.describe_model_package.return_value = { diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py index ddaeb56331..ea53bdfdef 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py @@ -320,7 +320,7 @@ def test_resolve_arn_success(self, mock_validate, mock_get_session): # Mock boto client mock_sm_client = MagicMock() - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client mock_sm_client.describe_model_package.return_value = { "ModelPackageArn": arn, "InferenceSpecification": { @@ -374,7 +374,7 @@ def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_ses # Mock boto client mock_sm_client = MagicMock() - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: @@ -414,7 +414,7 @@ def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session): # Mock boto client mock_sm_client = MagicMock() - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: @@ -444,7 +444,7 @@ def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session): # Mock boto client mock_sm_client = MagicMock() - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py index 13060ebe0a..5691d11804 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py @@ -320,7 +320,7 @@ def test_model_package_group_name_resolution(self, mock_resolve, mock_session, m mock_sm_client.describe_model_package_group.return_value = { "ModelPackageGroupArn": DEFAULT_MODEL_PACKAGE_GROUP_ARN } - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client evaluator = BaseEvaluator( model=DEFAULT_MODEL, @@ -362,7 +362,7 @@ def test_model_package_group_name_not_found(self, mock_resolve, mock_session, mo # Mock the boto client to raise an exception mock_sm_client = MagicMock() mock_sm_client.describe_model_package_group.side_effect = Exception("Model package group not found") - mock_session.boto_session.client.return_value = mock_sm_client + mock_session.sagemaker_client = mock_sm_client with pytest.raises(ValidationError, match="Failed to resolve model package group name"): BaseEvaluator( From 1edeecba7b4d03e14fe72a8dc0fa6819b5445756 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Thu, 4 Jun 2026 11:41:24 -0700 Subject: [PATCH 10/11] revert: remove SageMakerClient singleton bypass from feature code --- .../train/common_utils/finetune_utils.py | 27 ++- .../train/common_utils/model_resolution.py | 16 +- .../train/evaluate/base_evaluator.py | 25 ++- .../train/common_utils/test_finetune_utils.py | 46 ++-- .../common_utils/test_model_resolution.py | 197 ++++++++---------- .../train/evaluate/test_base_evaluator.py | 27 ++- 6 files changed, 146 insertions(+), 192 deletions(-) diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index c1b6c8f422..6479e803bd 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -338,14 +338,13 @@ def _resolve_model_package_group_arn(model_package_group_name_or_arn, sagemaker_ # It's already an ARN return model_package_group_name_or_arn else: - # It's a name, resolve to ARN using the session's sagemaker_client - # which respects the session's region (avoids SageMakerClient singleton - # which may cache a different region). - sm_client = sagemaker_session.sagemaker_client - response = sm_client.describe_model_package_group( - ModelPackageGroupName=model_package_group_name_or_arn + # It's a name, resolve to ARN + model_package_group = ModelPackageGroup.get( + model_package_group_name=model_package_group_name_or_arn, + session=sagemaker_session.boto_session, + region=sagemaker_session.boto_session.region_name ) - return response["ModelPackageGroupArn"] + return model_package_group.model_package_group_arn else: # It's a ModelPackageGroup object return model_package_group_name_or_arn.model_package_group_arn @@ -582,14 +581,12 @@ def _resolve_model_and_name(model, sagemaker_session=None): if isinstance(model, str): # Check if it's a model package ARN if model.startswith("arn:aws:sagemaker:") and ":model-package/" in model: - # Get ModelPackage object from ARN using the session's sagemaker_client - # which respects the session's region (avoids SageMakerClient singleton - # which may cache a different region). - from sagemaker.core.utils.code_injection.codec import transform - sm_client = sagemaker_session.sagemaker_client - response = sm_client.describe_model_package(ModelPackageName=model) - transformed_response = transform(response, "DescribeModelPackageOutput") - model_package = ModelPackage(**transformed_response) + # Get ModelPackage object from ARN + model_package = ModelPackage.get( + model_package_name=model, + session=sagemaker_session.boto_session if sagemaker_session else None, + region=sagemaker_session.boto_session.region_name if sagemaker_session else None + ) model_name = _resolve_model_name(model_package) # Validate region availability if region_name: diff --git a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py index 70e5cd44b6..8e2bee6971 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/model_resolution.py @@ -290,23 +290,21 @@ def _resolve_model_package_arn(self, model_package_arn: str) -> _ModelInfo: # Validate ARN format self._validate_model_package_arn(model_package_arn) + # Use sagemaker.core ModelPackage.get() to retrieve model package information from sagemaker.core.resources import ModelPackage - from sagemaker.core.utils.code_injection.codec import transform import logging logger = logging.getLogger(__name__) - # Use the session's sagemaker_client which respects the session's region - # (avoids SageMakerClient singleton which may cache a different region). - sm_client = session.sagemaker_client - response = sm_client.describe_model_package(ModelPackageName=model_package_arn) + # Get the model package using sagemaker.core + model_package = ModelPackage.get( + model_package_name=model_package_arn, + session=session.boto_session, + region=session.boto_session.region_name + ) logger.info(f"Retrieved ModelPackage in region: {session.boto_session.region_name}") - # Deserialize and create ModelPackage object - transformed_response = transform(response, "DescribeModelPackageOutput") - model_package = ModelPackage(**transformed_response) - # Now use the existing _resolve_model_package_object method to extract base model info return self._resolve_model_package_object(model_package) diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index 595fd0f936..7d7fac006d 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -222,18 +222,21 @@ def _validate_and_resolve_model_package_group(cls, v, values): if hasattr(session, 'boto_region_name') else 'us-west-2') - # Resolve using the session's sagemaker_client which respects - # the session's region (avoids SageMakerClient singleton - # which may cache a different region). - if session and hasattr(session, 'sagemaker_client'): - sm_client = session.sagemaker_client + # Fetch the object + obj = ModelPackageGroup.get( + model_package_group_name=v, + region=region + ) + + # Extract ARN + if hasattr(obj, 'model_package_group_arn'): + arn = obj.model_package_group_arn + _logger.info(f"Resolved model package group name '{v}' to ARN: {arn}") + return arn else: - import boto3 as _boto3 - sm_client = _boto3.client("sagemaker", region_name=region) - response = sm_client.describe_model_package_group(ModelPackageGroupName=v) - arn = response["ModelPackageGroupArn"] - _logger.info(f"Resolved model package group name '{v}' to ARN: {arn}") - return arn + raise ValueError( + f"ModelPackageGroup object for name '{v}' does not have model_package_group_arn attribute" + ) except Exception as e: raise ValueError( diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index c311b2011f..c98dea477f 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -168,21 +168,17 @@ def test__validate_model_package_group_requirement_without_group_name(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_model_package_group_requirement("string-model", None) - def test__resolve_model_package_group_arn_with_name(self): + @patch('sagemaker.core.resources.ModelPackageGroup.get') + def test__resolve_model_package_group_arn_with_name(self, mock_get): mock_session = Mock() mock_session.boto_session.region_name = "us-east-1" - mock_sm_client = Mock() - mock_sm_client.describe_model_package_group.return_value = { - "ModelPackageGroupArn": "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" - } - mock_session.sagemaker_client = mock_sm_client + mock_group = Mock() + mock_group.model_package_group_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" + mock_get.return_value = mock_group result = _resolve_model_package_group_arn("test-group", mock_session) - assert result == "arn:aws:sagemaker:us-east-1:123456789012:model-package-group/test-group" - mock_sm_client.describe_model_package_group.assert_called_once_with( - ModelPackageGroupName="test-group" - ) + assert result == mock_group.model_package_group_arn def test__resolve_model_package_group_arn_with_arn(self): mock_session = Mock() @@ -366,35 +362,23 @@ def test__validate_and_resolve_model_package_group_missing_both(self): with pytest.raises(ValueError, match="model_package_group_name must be provided"): _validate_and_resolve_model_package_group("string-model", None) - def test__resolve_model_and_name_with_model_package_arn(self): + @patch('sagemaker.core.resources.ModelPackage.get') + def test__resolve_model_and_name_with_model_package_arn(self, mock_get): mock_session = Mock() mock_session.boto_region_name = "us-east-1" # Set valid region - mock_session.boto_session.region_name = "us-east-1" - mock_sm_client = Mock() - mock_session.sagemaker_client = mock_sm_client - - # Mock describe_model_package response - mock_sm_client.describe_model_package.return_value = { - "ModelPackageArn": "arn:aws:sagemaker:us-east-1:123456789012:model-package/test", - } - - # Mock transform and ModelPackage constructor + mock_model_package = Mock(spec=ModelPackage) mock_container = Mock() mock_base_model = Mock() mock_base_model.hub_content_name = "test-model" mock_container.base_model = mock_base_model - mock_inference_spec = Mock() - mock_inference_spec.containers = [mock_container] - - mock_model_package = Mock(spec=ModelPackage) - mock_model_package.inference_specification = mock_inference_spec + mock_model_package.inference_specification = Mock() + mock_model_package.inference_specification.containers = [mock_container] + mock_get.return_value = mock_model_package - with patch('sagemaker.core.utils.code_injection.codec.transform', return_value={}): - with patch('sagemaker.train.common_utils.finetune_utils.ModelPackage', return_value=mock_model_package): - model, name = _resolve_model_and_name("arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session) + model, name = _resolve_model_and_name("arn:aws:sagemaker:us-east-1:123456789012:model-package/test", mock_session) - assert model == mock_model_package - assert name == "test-model" + assert model == mock_model_package + assert name == "test-model" def test__resolve_model_and_name_with_string(self): model, name = _resolve_model_and_name("test-model") diff --git a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py index ea53bdfdef..8a57dc2d28 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_model_resolution.py @@ -307,10 +307,11 @@ def test_resolve_package_fallback_name(self): class TestResolveModelPackageArn: """Tests for _resolve_model_package_arn method.""" + @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_success(self, mock_validate, mock_get_session): - """Test successful ARN resolution.""" + def test_resolve_arn_success(self, mock_validate, mock_get_session, mock_model_package_class): + """Test successful ARN resolution using ModelPackage.get().""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" # Mock session @@ -318,52 +319,40 @@ def test_resolve_arn_success(self, mock_validate, mock_get_session): mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock boto client - mock_sm_client = MagicMock() - mock_session.sagemaker_client = mock_sm_client - mock_sm_client.describe_model_package.return_value = { - "ModelPackageArn": arn, - "InferenceSpecification": { - "Containers": [{ - "BaseModel": { - "HubContentName": "base-model", - "HubContentVersion": "1.0", - "HubContentArn": "arn:aws:sagemaker:us-west-2:aws:hub-content/base" - } - }] - } - } - - # Mock transform to return a proper ModelPackage-like object - with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: - mock_transformed = {} - mock_transform.return_value = mock_transformed - - with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: - mock_package = MagicMock() - mock_package.model_package_arn = arn - mock_container = MagicMock() - mock_base_model = MagicMock() - mock_base_model.hub_content_name = 'base-model' - mock_base_model.hub_content_version = '1.0' - mock_base_model.hub_content_arn = 'arn:aws:sagemaker:us-west-2:aws:hub-content/base' - mock_container.base_model = mock_base_model - mock_package.inference_specification = MagicMock() - mock_package.inference_specification.containers = [mock_container] - mock_mp_class.return_value = mock_package - - resolver = _ModelResolver() - result = resolver._resolve_model_package_arn(arn) - - assert result.base_model_name == "base-model" - assert result.hub_content_name == "base-model" - assert result.source_model_package_arn == arn - assert result.model_type == _ModelType.FINE_TUNED - mock_sm_client.describe_model_package.assert_called_once_with(ModelPackageName=arn) + # Mock ModelPackage.get() return value + mock_package = MagicMock() + mock_package.model_package_arn = arn + + # Mock inference specification with hub_content_arn + mock_container = MagicMock() + mock_base_model = MagicMock() + mock_base_model.hub_content_name = 'base-model' + mock_base_model.hub_content_version = '1.0' + mock_base_model.hub_content_arn = 'arn:aws:sagemaker:us-west-2:aws:hub-content/base' + mock_container.base_model = mock_base_model + + mock_package.inference_specification = MagicMock() + mock_package.inference_specification.containers = [mock_container] + + mock_model_package_class.get.return_value = mock_package + + resolver = _ModelResolver() + result = resolver._resolve_model_package_arn(arn) + + assert result.base_model_name == "base-model" + assert result.hub_content_name == "base-model" + assert result.source_model_package_arn == arn + assert result.model_type == _ModelType.FINE_TUNED + mock_model_package_class.get.assert_called_once_with( + model_package_name=arn, + session=mock_session.boto_session, + region='us-west-2' + ) + @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_session): + def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_session, mock_model_package_class): """Test ARN resolution when HubContentArn needs to be constructed.""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" @@ -372,38 +361,35 @@ def test_resolve_arn_construct_hub_content_arn(self, mock_validate, mock_get_ses mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock boto client - mock_sm_client = MagicMock() - mock_session.sagemaker_client = mock_sm_client - mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} + # Mock ModelPackage without hub_content_arn (needs to be constructed) + mock_package = MagicMock() + mock_package.model_package_arn = arn - with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: - mock_transform.return_value = {} - - with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: - mock_package = MagicMock() - mock_package.model_package_arn = arn - mock_container = MagicMock() - mock_base_model = MagicMock() - mock_base_model.hub_content_name = 'base-model' - mock_base_model.hub_content_version = '1.0' - mock_base_model.hub_content_arn = None # Not provided, needs construction - mock_container.base_model = mock_base_model - mock_package.inference_specification = MagicMock() - mock_package.inference_specification.containers = [mock_container] - mock_mp_class.return_value = mock_package - - resolver = _ModelResolver() - result = resolver._resolve_model_package_arn(arn) - - expected_arn = "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/base-model/1.0" - assert result.base_model_arn == expected_arn - assert result.base_model_name == "base-model" - assert result.hub_content_name == "base-model" + mock_container = MagicMock() + mock_base_model = MagicMock() + mock_base_model.hub_content_name = 'base-model' + mock_base_model.hub_content_version = '1.0' + mock_base_model.hub_content_arn = None # Not provided, needs construction + mock_container.base_model = mock_base_model + + mock_package.inference_specification = MagicMock() + mock_package.inference_specification.containers = [mock_container] + + mock_model_package_class.get.return_value = mock_package + + resolver = _ModelResolver() + result = resolver._resolve_model_package_arn(arn) + + # Should construct ARN from region and hub content name/version + expected_arn = "arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/base-model/1.0" + assert result.base_model_arn == expected_arn + assert result.base_model_name == "base-model" + assert result.hub_content_name == "base-model" + @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session): + def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session, mock_model_package_class): """Test error when InferenceSpecification is missing.""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" @@ -412,28 +398,22 @@ def test_resolve_arn_no_inference_spec(self, mock_validate, mock_get_session): mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock boto client - mock_sm_client = MagicMock() - mock_session.sagemaker_client = mock_sm_client - mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} + # Mock ModelPackage without inference_specification + mock_package = MagicMock() + mock_package.model_package_arn = arn + mock_package.inference_specification = None - with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: - mock_transform.return_value = {} - - with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: - mock_package = MagicMock() - mock_package.model_package_arn = arn - mock_package.inference_specification = None - mock_mp_class.return_value = mock_package - - resolver = _ModelResolver() - - with pytest.raises(ValueError, match="NotSupported.*does not have an inference_specification"): - resolver._resolve_model_package_arn(arn) + mock_model_package_class.get.return_value = mock_package + + resolver = _ModelResolver() + + with pytest.raises(ValueError, match="NotSupported.*does not have an inference_specification"): + resolver._resolve_model_package_arn(arn) + @patch('sagemaker.core.resources.ModelPackage') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._get_session') @patch('sagemaker.train.common_utils.model_resolution._ModelResolver._validate_model_package_arn') - def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session): + def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session, mock_model_package_class): """Test error when BaseModel is missing.""" arn = "arn:aws:sagemaker:us-west-2:123456789012:model-package/my-model/1" @@ -442,27 +422,22 @@ def test_resolve_arn_no_base_model(self, mock_validate, mock_get_session): mock_session.boto_session.region_name = 'us-west-2' mock_get_session.return_value = mock_session - # Mock boto client - mock_sm_client = MagicMock() - mock_session.sagemaker_client = mock_sm_client - mock_sm_client.describe_model_package.return_value = {"ModelPackageArn": arn} + # Mock ModelPackage with container but no base_model + mock_package = MagicMock() + mock_package.model_package_arn = arn - with patch('sagemaker.core.utils.code_injection.codec.transform') as mock_transform: - mock_transform.return_value = {} - - with patch('sagemaker.core.resources.ModelPackage') as mock_mp_class: - mock_package = MagicMock() - mock_package.model_package_arn = arn - mock_container = MagicMock() - mock_container.base_model = None - mock_package.inference_specification = MagicMock() - mock_package.inference_specification.containers = [mock_container] - mock_mp_class.return_value = mock_package - - resolver = _ModelResolver() - - with pytest.raises(ValueError, match="NotSupported.*does not have base_model metadata"): - resolver._resolve_model_package_arn(arn) + mock_container = MagicMock() + mock_container.base_model = None + + mock_package.inference_specification = MagicMock() + mock_package.inference_specification.containers = [mock_container] + + mock_model_package_class.get.return_value = mock_package + + resolver = _ModelResolver() + + with pytest.raises(ValueError, match="NotSupported.*does not have base_model metadata"): + resolver._resolve_model_package_arn(arn) class TestValidateModelPackageArn: diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py index 5691d11804..1f690be8a9 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py @@ -311,16 +311,15 @@ def test_model_package_group_arn_valid(self, mock_resolve, mock_session, mock_mo assert evaluator.model_package_group == DEFAULT_MODEL_PACKAGE_GROUP_ARN @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - def test_model_package_group_name_resolution(self, mock_resolve, mock_session, mock_model_info): + @patch("sagemaker.core.resources.ModelPackageGroup.get") + def test_model_package_group_name_resolution(self, mock_mpg_get, mock_resolve, mock_session, mock_model_info): """Test model package group name resolution to ARN.""" mock_resolve.return_value = mock_model_info - # Mock the boto client's describe_model_package_group call - mock_sm_client = MagicMock() - mock_sm_client.describe_model_package_group.return_value = { - "ModelPackageGroupArn": DEFAULT_MODEL_PACKAGE_GROUP_ARN - } - mock_session.sagemaker_client = mock_sm_client + # Mock ModelPackageGroup.get to return an object with ARN + mock_mpg = MagicMock() + mock_mpg.model_package_group_arn = DEFAULT_MODEL_PACKAGE_GROUP_ARN + mock_mpg_get.return_value = mock_mpg evaluator = BaseEvaluator( model=DEFAULT_MODEL, @@ -332,8 +331,9 @@ def test_model_package_group_name_resolution(self, mock_resolve, mock_session, m ) assert evaluator.model_package_group == DEFAULT_MODEL_PACKAGE_GROUP_ARN - mock_sm_client.describe_model_package_group.assert_called_once_with( - ModelPackageGroupName="my-package" + mock_mpg_get.assert_called_once_with( + model_package_group_name="my-package", + region=DEFAULT_REGION, ) @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") @@ -355,14 +355,11 @@ def test_model_package_group_object_resolution(self, mock_resolve, mock_session, assert evaluator.model_package_group == DEFAULT_MODEL_PACKAGE_GROUP_ARN @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - def test_model_package_group_name_not_found(self, mock_resolve, mock_session, mock_model_info): + @patch("sagemaker.core.resources.ModelPackageGroup.get") + def test_model_package_group_name_not_found(self, mock_mpg_get, mock_resolve, mock_session, mock_model_info): """Test model package group name that doesn't exist.""" mock_resolve.return_value = mock_model_info - - # Mock the boto client to raise an exception - mock_sm_client = MagicMock() - mock_sm_client.describe_model_package_group.side_effect = Exception("Model package group not found") - mock_session.sagemaker_client = mock_sm_client + mock_mpg_get.side_effect = Exception("Model package group not found") with pytest.raises(ValidationError, match="Failed to resolve model package group name"): BaseEvaluator( From 55e7a773d99f105a3d2044e22795e5564d722d54 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Thu, 4 Jun 2026 14:52:19 -0700 Subject: [PATCH 11/11] test: mark TestLLMAsJudgeBaseModelFix as serial Tests share the same pipeline definition and conflict when run in parallel (Pipeline has been modified since your last read). X-AI-Prompt: mark llm_as_judge_base_model_fix as serial X-AI-Tool: kiro-cli --- .../tests/integ/train/test_llm_as_judge_base_model_fix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py index f6cb0c5183..c7f2445650 100644 --- a/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py +++ b/sagemaker-train/tests/integ/train/test_llm_as_judge_base_model_fix.py @@ -77,6 +77,7 @@ } +@pytest.mark.serial class TestLLMAsJudgeBaseModelFix: """Integration test for base model fix in LLMAsJudgeEvaluator"""