diff --git a/sagemaker-serve/src/sagemaker/serve/validations/check_image_and_hardware_type.py b/sagemaker-serve/src/sagemaker/serve/validations/check_image_and_hardware_type.py index 0046e47a80..f72b678178 100644 --- a/sagemaker-serve/src/sagemaker/serve/validations/check_image_and_hardware_type.py +++ b/sagemaker-serve/src/sagemaker/serve/validations/check_image_and_hardware_type.py @@ -3,6 +3,7 @@ from __future__ import absolute_import import logging +from sagemaker.core.helper.pipeline_variable import PipelineVariable from sagemaker.serve.utils.types import ModelServer, HardwareType logger = logging.getLogger(__name__) @@ -35,6 +36,10 @@ def validate_image_uri_and_hardware(image_uri: str, instance_type: str, model_server: ModelServer): """Placeholder docstring""" + if isinstance(image_uri, PipelineVariable): + # Skip validation since the value is not known at build time + return + if "xgboost" in image_uri: # xgboost container does not care about hardware type # hence skipping validation diff --git a/sagemaker-serve/src/sagemaker/serve/validations/check_image_uri.py b/sagemaker-serve/src/sagemaker/serve/validations/check_image_uri.py index 2f50faaeed..f3ffa7550d 100644 --- a/sagemaker-serve/src/sagemaker/serve/validations/check_image_uri.py +++ b/sagemaker-serve/src/sagemaker/serve/validations/check_image_uri.py @@ -1,6 +1,9 @@ """Validates that a given image_uri is not a 1p image.""" from __future__ import absolute_import +from typing import Union + +from sagemaker.core.helper.pipeline_variable import PipelineVariable # Generated by running the parse_registry_accounts.py script all_accounts = { @@ -296,7 +299,9 @@ } -def is_1p_image_uri(image_uri: str) -> bool: +def is_1p_image_uri(image_uri: Union[str, PipelineVariable]) -> bool: """Shows if the given image_uri is owned by a 1st party account""" + if isinstance(image_uri, PipelineVariable): + return False image_uri_account = image_uri[0:12] return image_uri_account in all_accounts diff --git a/sagemaker-serve/tests/unit/validations/test_check_image_uri.py b/sagemaker-serve/tests/unit/validations/test_check_image_uri.py index c978d74a5f..bd1b1aa9ee 100644 --- a/sagemaker-serve/tests/unit/validations/test_check_image_uri.py +++ b/sagemaker-serve/tests/unit/validations/test_check_image_uri.py @@ -1,4 +1,8 @@ import unittest +from unittest.mock import Mock + +from sagemaker.core.helper.pipeline_variable import PipelineVariable +from sagemaker.core.workflow.parameters import ParameterString from sagemaker.serve.validations.check_image_uri import is_1p_image_uri, all_accounts @@ -27,6 +31,21 @@ def test_all_accounts_contains_known_accounts(self): self.assertIn("763104351884", all_accounts) self.assertIn("246618743249", all_accounts) + def test_is_1p_image_uri_with_parameter_string_returns_false(self): + # ParameterString should not raise TypeError and should return False + param = ParameterString( + "image_uri", + default_value="763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch:latest", + ) + result = is_1p_image_uri(param) + self.assertFalse(result) + + def test_is_1p_image_uri_with_pipeline_variable_returns_false(self): + # A mock PipelineVariable instance should return False without raising + mock_pipeline_var = Mock(spec=PipelineVariable) + result = is_1p_image_uri(mock_pipeline_var) + self.assertFalse(result) + if __name__ == "__main__": unittest.main()