Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,39 @@ def _update_pipeline_lineage(

# If pipeline lineage exists then determine whether to create a new version.
pipeline_context: Context = self._get_pipeline_context()
current_pipeline_version_context: Context = self._get_pipeline_version_context(
last_update_time=pipeline_context.properties[LAST_UPDATE_TIME]
)
try:
current_pipeline_version_context: Context = self._get_pipeline_version_context(
last_update_time=pipeline_context.properties[LAST_UPDATE_TIME]
)
except ClientError as e:
if e.response[ERROR][CODE] == RESOURCE_NOT_FOUND:
# Pipeline version context does not exist (possibly deleted or never created).
# Create a new pipeline version context and its associations.
logger.info(
"Pipeline version context not found. Creating new pipeline version lineage."
)
pipeline_context.properties["LastUpdateTime"] = self.pipeline[
LAST_MODIFIED_TIME
].strftime("%s")
PipelineLineageEntityHandler.update_pipeline_context(
pipeline_context=pipeline_context
)
new_pipeline_version_context: Context = self._create_pipeline_version_lineage()
self._add_associations_for_pipeline(
pipeline_context_arn=pipeline_context.context_arn,
pipeline_versions_context_arn=new_pipeline_version_context.context_arn,
input_feature_group_contexts=input_feature_group_contexts,
input_raw_data_artifacts=input_raw_data_artifacts,
output_feature_group_contexts=output_feature_group_contexts,
transformation_code_artifact=transformation_code_artifact,
)
LineageAssociationHandler.add_pipeline_and_pipeline_version_association(
pipeline_context_arn=pipeline_context.context_arn,
pipeline_version_context_arn=new_pipeline_version_context.context_arn,
sagemaker_session=self.sagemaker_session,
)
return
raise e
upstream_feature_group_associations: Iterator[AssociationSummary] = (
LineageAssociationHandler.list_upstream_associations(
# pylint: disable=no-member
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn: Optiona
mlflow_apps_list = []
paginator = sm_client.get_paginator("list_mlflow_apps")
for page in paginator.paginate():
mlflow_apps_list.extend(page.get("MlflowApps", []))
mlflow_apps_list.extend(page.get("Summaries", []))

logger.info("Found %d MLflow apps: %s", len(mlflow_apps_list),
[(a.get("Name", "?"), a.get("Status", "?"), a.get("MlflowVersion", "?")) for a in mlflow_apps_list])
Expand Down
89 changes: 89 additions & 0 deletions sagemaker-train/tests/integ/train/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,92 @@ def sagemaker_session_us_east_1():
"""Create a SageMaker session in us-east-1 for Nova model tests."""
boto_session = boto3.Session(region_name=NOVA_REGION)
return Session(boto_session=boto_session)


import time
import logging

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def mlflow_resource_arn():
"""Discover or create an MLflow app for integ tests, clean up if created.

Looks for an existing MLflow app in Created/Updated state. If none exists,
creates one and deletes it after the test module finishes.
"""
region = os.environ.get("AWS_DEFAULT_REGION", DEFAULT_REGION)
sm_client = boto3.client("sagemaker", region_name=region)
created_arn = None

# Try to find an existing ready app
try:
paginator = sm_client.get_paginator("list_mlflow_apps")
for page in paginator.paginate():
for app in page.get("Summaries", []):
if app.get("Status") in ("Created", "Updated"):
logger.info(f"Using existing MLflow app: {app['Arn']}")
yield app["Arn"]
return
except Exception as e:
logger.warning(f"Failed to list MLflow apps: {e}")

# No ready app found — create one
logger.info("No ready MLflow app found. Creating one for integ tests...")
sts_client = boto3.client("sts", region_name=region)
account_id = sts_client.get_caller_identity()["Account"]
app_name = f"integ-test-mlflow-{int(time.time())}"
artifact_store_uri = f"s3://sagemaker-{region}-{account_id}/mlflow-artifacts"

# Ensure bucket/prefix exists
s3_client = boto3.client("s3", region_name=region)
bucket_name = f"sagemaker-{region}-{account_id}"
try:
s3_client.head_bucket(Bucket=bucket_name)
except Exception:
if region == "us-east-1":
s3_client.create_bucket(Bucket=bucket_name)
else:
s3_client.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": region},
)
try:
s3_client.put_object(Bucket=bucket_name, Key="mlflow-artifacts/")
except Exception:
pass

# Get execution role
from sagemaker.train.defaults import TrainDefaults
boto_session = boto3.Session(region_name=region)
sagemaker_session = Session(boto_session=boto_session)
role_arn = TrainDefaults.get_role(role=None, sagemaker_session=sagemaker_session)

resp = sm_client.create_mlflow_app(
Name=app_name,
ArtifactStoreUri=artifact_store_uri,
RoleArn=role_arn,
AccountDefaultStatus="DISABLED",
)
created_arn = resp["Arn"]
logger.info(f"Created MLflow app: {created_arn}")

# Wait for it to become ready
for _ in range(60):
desc = sm_client.describe_mlflow_app(Arn=created_arn)
status = desc.get("Status")
if status in ("Created", "Updated"):
break
if status in ("Failed", "CreateFailed", "DeleteFailed"):
pytest.skip(f"MLflow app creation failed: {desc.get('FailureReason')}")
time.sleep(10)

yield created_arn

# Cleanup
logger.info(f"Cleaning up MLflow app: {created_arn}")
try:
sm_client.delete_mlflow_app(Arn=created_arn)
except Exception as e:
logger.warning(f"Failed to delete MLflow app {created_arn}: {e}")
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
class TestLLMAsJudgeBaseModelFix:
"""Integration test for base model fix in LLMAsJudgeEvaluator"""

def test_base_model_evaluation_uses_correct_weights(self):
def test_base_model_evaluation_uses_correct_weights(self, mlflow_resource_arn):
"""
Test that base model evaluation uses original base model weights.

Expand Down Expand Up @@ -115,7 +115,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
custom_metrics=TEST_CONFIG["custom_metrics_json"],
s3_output_path=TEST_CONFIG["s3_output_path"],
evaluate_base_model=TEST_CONFIG["evaluate_base_model"],
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=mlflow_resource_arn,
)

# Verify evaluator configuration
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_base_model_evaluation_uses_correct_weights(self):
# Re-raise to fail the test
raise

def test_base_model_false_still_works(self):
def test_base_model_false_still_works(self, mlflow_resource_arn):
"""
Test that evaluate_base_model=False still works correctly (backward compatibility).

Expand All @@ -272,7 +272,7 @@ def test_base_model_false_still_works(self):
builtin_metrics=TEST_CONFIG["builtin_metrics"],
s3_output_path=TEST_CONFIG["s3_output_path"],
evaluate_base_model=False, # Only evaluate custom model
mlflow_resource_arn=TEST_CONFIG["mlflow_tracking_server_arn"],
mlflow_resource_arn=mlflow_resource_arn,
)

# Verify evaluator configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test__resolve_mlflow_resource_arn_creates_new_app(self, mock_get_client, moc
mock_get_domain.return_value = "d-123456789"
mock_sm_client = Mock()
mock_paginator = Mock()
mock_paginator.paginate.return_value = [{"MlflowApps": []}]
mock_paginator.paginate.return_value = [{"Summaries": []}]
mock_sm_client.get_paginator.return_value = mock_paginator
mock_get_client.return_value = mock_sm_client
expected_arn = "arn:aws:mlflow:us-east-1:123456789012:tracking-server/new-app"
Expand Down Expand Up @@ -633,7 +633,7 @@ def test_upgrades_when_below_min_version(self, mock_get_client, mock_upgrade, mo
}
mock_sm_client = Mock()
mock_paginator = Mock()
mock_paginator.paginate.return_value = [{"MlflowApps": [old_app]}]
mock_paginator.paginate.return_value = [{"Summaries": [old_app]}]
mock_sm_client.get_paginator.return_value = mock_paginator
mock_get_client.return_value = mock_sm_client

Expand All @@ -659,7 +659,7 @@ def test_no_upgrade_when_meets_version(self, mock_get_client, mock_domain):
}
mock_sm_client = Mock()
mock_paginator = Mock()
mock_paginator.paginate.return_value = [{"MlflowApps": [app]}]
mock_paginator.paginate.return_value = [{"Summaries": [app]}]
mock_sm_client.get_paginator.return_value = mock_paginator
mock_get_client.return_value = mock_sm_client

Expand Down
Loading