Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions sagemaker-mlops/src/sagemaker/mlops/local/local_pipeline_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def __init__(self, *args, **kwargs):
Accepts the same arguments as LocalSession.
"""
super().__init__(*args, **kwargs)
# Add pipeline storage to the sagemaker_client
if not hasattr(self.sagemaker_client, '_pipelines'):
self.sagemaker_client._pipelines = {}
# Store pipeline registry on the session instance itself
# (not on sagemaker_client, to avoid mutating a shared client instance)
self._pipelines = {}

@_telemetry_emitter(Feature.LOCAL_MODE, "local_pipeline_session.create_pipeline")
def create_pipeline(
Expand All @@ -68,7 +68,7 @@ def create_pipeline(
pipeline_description=pipeline_description,
local_session=self,
)
self.sagemaker_client._pipelines[pipeline.name] = local_pipeline
self._pipelines[pipeline.name] = local_pipeline
return {"PipelineArn": pipeline.name}

def update_pipeline(
Expand All @@ -83,17 +83,17 @@ def update_pipeline(
Returns:
Pipeline metadata (PipelineArn)
"""
if pipeline.name not in self.sagemaker_client._pipelines:
if pipeline.name not in self._pipelines:
error_response = {
"Error": {
"Code": "ResourceNotFound",
"Message": "Pipeline {} does not exist".format(pipeline.name),
}
}
raise ClientError(error_response, "update_pipeline")
self.sagemaker_client._pipelines[pipeline.name].pipeline_description = pipeline_description
self.sagemaker_client._pipelines[pipeline.name].pipeline = pipeline
self.sagemaker_client._pipelines[pipeline.name].last_modified_time = (
self._pipelines[pipeline.name].pipeline_description = pipeline_description
self._pipelines[pipeline.name].pipeline = pipeline
self._pipelines[pipeline.name].last_modified_time = (
datetime.now().timestamp()
)
return {"PipelineArn": pipeline.name}
Expand All @@ -107,15 +107,15 @@ def describe_pipeline(self, PipelineName):
Returns:
Pipeline metadata (PipelineArn, PipelineDefinition, LastModifiedTime, etc)
"""
if PipelineName not in self.sagemaker_client._pipelines:
if PipelineName not in self._pipelines:
error_response = {
"Error": {
"Code": "ResourceNotFound",
"Message": "Pipeline {} does not exist".format(PipelineName),
}
}
raise ClientError(error_response, "describe_pipeline")
return self.sagemaker_client._pipelines[PipelineName].describe()
return self._pipelines[PipelineName].describe()

def delete_pipeline(self, PipelineName):
"""Delete the local pipeline.
Expand All @@ -126,8 +126,8 @@ def delete_pipeline(self, PipelineName):
Returns:
Pipeline metadata (PipelineArn)
"""
if PipelineName in self.sagemaker_client._pipelines:
del self.sagemaker_client._pipelines[PipelineName]
if PipelineName in self._pipelines:
del self._pipelines[PipelineName]
return {"PipelineArn": PipelineName}

def start_pipeline_execution(self, PipelineName, **kwargs):
Expand All @@ -143,12 +143,12 @@ def start_pipeline_execution(self, PipelineName, **kwargs):
logger.warning("Parallelism configuration is not supported in local mode.")
if "SelectiveExecutionConfig" in kwargs:
raise ValueError("SelectiveExecutionConfig is not supported in local mode.")
if PipelineName not in self.sagemaker_client._pipelines:
if PipelineName not in self._pipelines:
error_response = {
"Error": {
"Code": "ResourceNotFound",
"Message": "Pipeline {} does not exist".format(PipelineName),
}
}
raise ClientError(error_response, "start_pipeline_execution")
return self.sagemaker_client._pipelines[PipelineName].start(**kwargs)
return self._pipelines[PipelineName].start(**kwargs)
48 changes: 26 additions & 22 deletions sagemaker-mlops/tests/unit/local/test_local_pipeline_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def mock_pipeline():
def local_session():
def mock_init(self, *args, **kwargs):
self.sagemaker_client = Mock()
self.sagemaker_client._pipelines = {}
self._pipelines = {}

with patch.object(LocalPipelineSession, '__init__', mock_init):
session = LocalPipelineSession()
Expand All @@ -47,22 +47,26 @@ def mock_parent_init(self, *args, **kwargs):
with patch('sagemaker.core.local.LocalSession.__init__', mock_parent_init):
session = LocalPipelineSession()

# Verify _pipelines attribute is created as a dict
assert hasattr(session.sagemaker_client, '_pipelines')
assert session.sagemaker_client._pipelines == {}
# Verify _pipelines attribute is created on the session instance as a dict
assert hasattr(session, '_pipelines')
assert session._pipelines == {}


def test_local_pipeline_session_init_with_existing_pipelines():
"""Test LocalPipelineSession initialization when _pipelines already exists."""
def test_sessions_do_not_share_pipelines_registry():
"""Test that two LocalPipelineSession instances have independent _pipelines dicts."""
def mock_parent_init(self, *args, **kwargs):
self.sagemaker_client = Mock()
self.sagemaker_client._pipelines = {"existing": "pipeline"}
self.sagemaker_client = Mock() # Shared client mock

with patch('sagemaker.core.local.LocalSession.__init__', mock_parent_init):
session = LocalPipelineSession()
session1 = LocalPipelineSession()
session2 = LocalPipelineSession()

# Each session should have its own _pipelines dict
session1._pipelines["pipeline-a"] = "value-a"

# Should not overwrite existing _pipelines
assert session.sagemaker_client._pipelines == {"existing": "pipeline"}
assert "pipeline-a" in session1._pipelines
assert "pipeline-a" not in session2._pipelines
assert session1._pipelines is not session2._pipelines


def test_create_pipeline(local_session, mock_pipeline):
Expand All @@ -75,8 +79,8 @@ def test_create_pipeline(local_session, mock_pipeline):
result = LocalPipelineSession.create_pipeline(local_session, mock_pipeline, "Test pipeline description")

assert result == {"PipelineArn": "test-pipeline"}
assert "test-pipeline" in local_session.sagemaker_client._pipelines
assert local_session.sagemaker_client._pipelines["test-pipeline"] == mock_local_pipeline_instance
assert "test-pipeline" in local_session._pipelines
assert local_session._pipelines["test-pipeline"] == mock_local_pipeline_instance

mock_local_pipeline.assert_called_once_with(
pipeline=mock_pipeline,
Expand Down Expand Up @@ -109,7 +113,7 @@ def test_update_pipeline(local_session, mock_pipeline):
mock_local_pipeline.pipeline = Mock()
mock_local_pipeline.last_modified_time = 1000.0

local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

new_pipeline = Mock()
new_pipeline.name = "test-pipeline"
Expand All @@ -135,7 +139,7 @@ def test_update_pipeline_not_found(local_session, mock_pipeline):
def test_update_pipeline_with_kwargs(local_session, mock_pipeline):
"""Test update_pipeline ignores extra kwargs."""
mock_local_pipeline = Mock()
local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

result = LocalPipelineSession.update_pipeline(
local_session,
Expand All @@ -156,7 +160,7 @@ def test_describe_pipeline(local_session):
"LastModifiedTime": 1234567890
})

local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

result = LocalPipelineSession.describe_pipeline(local_session, "test-pipeline")

Expand All @@ -178,12 +182,12 @@ def test_describe_pipeline_not_found(local_session):
def test_delete_pipeline(local_session):
"""Test delete_pipeline removes pipeline."""
mock_local_pipeline = Mock()
local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

result = LocalPipelineSession.delete_pipeline(local_session, "test-pipeline")

assert result == {"PipelineArn": "test-pipeline"}
assert "test-pipeline" not in local_session.sagemaker_client._pipelines
assert "test-pipeline" not in local_session._pipelines


def test_delete_pipeline_not_found(local_session):
Expand All @@ -199,7 +203,7 @@ def test_start_pipeline_execution(local_session):
mock_execution = Mock()
mock_local_pipeline.start = Mock(return_value=mock_execution)

local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

result = LocalPipelineSession.start_pipeline_execution(local_session, "test-pipeline")

Expand All @@ -213,7 +217,7 @@ def test_start_pipeline_execution_with_kwargs(local_session):
mock_execution = Mock()
mock_local_pipeline.start = Mock(return_value=mock_execution)

local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

result = LocalPipelineSession.start_pipeline_execution(
local_session,
Expand All @@ -235,7 +239,7 @@ def test_start_pipeline_execution_with_parallelism_config(local_session, caplog)
mock_execution = Mock()
mock_local_pipeline.start = Mock(return_value=mock_execution)

local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

result = LocalPipelineSession.start_pipeline_execution(
local_session,
Expand All @@ -250,7 +254,7 @@ def test_start_pipeline_execution_with_parallelism_config(local_session, caplog)
def test_start_pipeline_execution_with_selective_execution_config(local_session):
"""Test start_pipeline_execution raises error for selective execution config."""
mock_local_pipeline = Mock()
local_session.sagemaker_client._pipelines["test-pipeline"] = mock_local_pipeline
local_session._pipelines["test-pipeline"] = mock_local_pipeline

with pytest.raises(ValueError) as exc_info:
LocalPipelineSession.start_pipeline_execution(
Expand Down
Loading