Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import
import logging

from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.serve.utils.types import ModelServer, HardwareType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -35,6 +36,10 @@

def validate_image_uri_and_hardware(image_uri: str, instance_type: str, model_server: ModelServer):
"""Placeholder docstring"""
if isinstance(image_uri, PipelineVariable):
# Skip validation since the value is not known at build time
return

if "xgboost" in image_uri:
# xgboost container does not care about hardware type
# hence skipping validation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Validates that a given image_uri is not a 1p image."""

from __future__ import absolute_import
from typing import Union

from sagemaker.core.helper.pipeline_variable import PipelineVariable

# Generated by running the parse_registry_accounts.py script
all_accounts = {
Expand Down Expand Up @@ -296,7 +299,9 @@
}


def is_1p_image_uri(image_uri: str) -> bool:
def is_1p_image_uri(image_uri: Union[str, PipelineVariable]) -> bool:
"""Shows if the given image_uri is owned by a 1st party account"""
if isinstance(image_uri, PipelineVariable):
return False
image_uri_account = image_uri[0:12]
return image_uri_account in all_accounts
19 changes: 19 additions & 0 deletions sagemaker-serve/tests/unit/validations/test_check_image_uri.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import unittest
from unittest.mock import Mock

from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.parameters import ParameterString
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri, all_accounts


Expand Down Expand Up @@ -27,6 +31,21 @@ def test_all_accounts_contains_known_accounts(self):
self.assertIn("763104351884", all_accounts)
self.assertIn("246618743249", all_accounts)

def test_is_1p_image_uri_with_parameter_string_returns_false(self):
# ParameterString should not raise TypeError and should return False
param = ParameterString(
"image_uri",
default_value="763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch:latest",
)
result = is_1p_image_uri(param)
self.assertFalse(result)

def test_is_1p_image_uri_with_pipeline_variable_returns_false(self):
# A mock PipelineVariable instance should return False without raising
mock_pipeline_var = Mock(spec=PipelineVariable)
result = is_1p_image_uri(mock_pipeline_var)
self.assertFalse(result)


if __name__ == "__main__":
unittest.main()
Loading