Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import time
import logging
import json
from typing import Optional
import time
from typing import Optional, List

import boto3
from sagemaker.core.resources import MlflowApp, ModelPackage, ModelPackageGroup
from sagemaker.core.helper.session_helper import Session
Expand Down Expand Up @@ -69,7 +69,7 @@ def _read_domain_id_from_metadata() -> Optional[str]:
metadata = json.load(f)
return metadata.get('DomainId')
except Exception as e:
logger.debug(f"Could not read Studio metadata file: {e}")
logger.debug("Could not read Studio metadata file: %s", e)
return None


Expand All @@ -94,7 +94,7 @@ def _get_current_domain_id(sagemaker_session) -> Optional[str]:
# ARN format: arn:aws:sagemaker:region:account:user-profile/domain-id/profile-name
return user_profile_arn.split('/')[1]
except Exception as e:
logger.debug(f"Could not extract domain ID from user profile ARN: {e}")
logger.debug("Could not extract domain ID from user profile ARN: %s", e)

return None

Expand Down Expand Up @@ -408,7 +408,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
v_copy['default'] = None # No default — won't appear in to_dict() unless set
options_dict[k] = v_copy
except Exception as e:
logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}")
logger.debug("Could not fetch subscription recipe override_params: %s: %s", type(e).__name__, e)

if options_dict:
return FineTuningOptions(options_dict), model_arn, is_gated_model
Expand All @@ -423,11 +423,15 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni
def _create_input_channels(dataset: str, content_type: Optional[str] = None,
input_compression_type: Optional[str] = None,
record_wrapper_type: Optional[str] = None,
input_mode: Optional[str] = None):
input_mode: Optional[str] = None) -> List[Channel]:
"""Create input channels from dataset (S3 URI or dataset ARN).

Args:
dataset: S3 URI (s3://bucket/key) or dataset ARN (arn:aws:sagemaker:...)
content_type: MIME type of the input data
input_compression_type: Compression type of the input data
record_wrapper_type: Record wrapper type
input_mode: Input mode for the channel

Returns:
list: List of Channel objects
Expand Down Expand Up @@ -486,7 +490,7 @@ def _resolve_model_and_name(model, sagemaker_session=None):
try:
import boto3
region_name = boto3.Session().region_name or os.environ.get('AWS_DEFAULT_REGION')
except:
except Exception:
pass

if isinstance(model, str):
Expand Down Expand Up @@ -518,8 +522,9 @@ def _resolve_model_and_name(model, sagemaker_session=None):
return model, model_name


def _create_serverless_config(model_arn, customization_technique,
training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']:
def _create_serverless_config(model_arn: str, customization_technique: str,
training_type, accept_eula: bool, evaluator_arn: Optional[str] = None,
job_type: str = JOB_TYPE) -> Optional[ServerlessJobConfig]:
"""Create serverless job configuration for fine-tuning.

Args:
Expand Down Expand Up @@ -606,8 +611,9 @@ def _create_model_package_config(model_package_group_name, model, sagemaker_sess



def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None,
mlflow_experiment_name=None, mlflow_run_name=None):
def _create_mlflow_config(sagemaker_session, mlflow_resource_arn: Optional[str] = None,
mlflow_experiment_name: Optional[str] = None,
mlflow_run_name: Optional[str] = None) -> Optional[MlflowConfig]:
"""Create MLflow configuration with resolved resource ARN.

Args:
Expand All @@ -623,7 +629,7 @@ def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None,

# Derive mlflow_resource_arn with default experience
resolved_mlflow_arn = _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn)
logger.info(f"MLflow resource ARN: {resolved_mlflow_arn}")
logger.info("MLflow resource ARN: %s", resolved_mlflow_arn)

# Create MlflowConfig using shapes
mlflow_config = None
Expand All @@ -637,12 +643,13 @@ def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None,
return mlflow_config


def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None):
def _create_output_config(sagemaker_session, s3_output_path: Optional[str] = None,
kms_key_id: Optional[str] = None) -> OutputDataConfig:
"""Create output data configuration with default S3 path if needed.

Args:
s3_output_path: S3 output path (if None, generates default path)
sagemaker_session: SageMaker session for generating default path
s3_output_path: S3 output path (if None, generates default path)
kms_key_id: Optional KMS key ID for encryption

Returns:
Expand All @@ -662,7 +669,7 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None
)


def _convert_input_data_to_channels(input_data_config ):
def _convert_input_data_to_channels(input_data_config: List[InputData]) -> List[Channel]:
"""Convert InputData objects to Channel objects with S3 and dataset ARN support.

Args:
Expand All @@ -689,9 +696,15 @@ def _convert_input_data_to_channels(input_data_config ):
dataset_source={"dataset_arn": input_data.data_source}
)

# Safely get content_type and compression_type from InputData
content_type = getattr(input_data, 'content_type', None)
compression_type = getattr(input_data, 'compression_type', None)

channel = Channel(
channel_name=input_data.channel_name,
data_source=data_source,
content_type=content_type,
compression_type=compression_type,
)
channels.append(channel)

Expand All @@ -713,7 +726,7 @@ def _validate_and_resolve_model_package_group(model, model_package_group_name):
"not a ModelPackage artifact/not continued finetuning")


def _validate_eula_for_gated_model(model, accept_eula, is_gated_model):
def _validate_eula_for_gated_model(model, accept_eula: bool, is_gated_model: bool) -> bool:
"""Validate EULA acceptance for gated models.

Args:
Expand Down Expand Up @@ -786,7 +799,6 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session):

def _validate_hyperparameter_values(hyperparameters: dict):
"""Validate hyperparameter values for allowed characters."""
import re
allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$"
for key, value in hyperparameters.items():
if isinstance(value, str) and not re.match(allowed_chars, value):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,12 +450,45 @@ def test__create_output_config(self, mock_validate_s3):
mock_validate_s3.assert_called_once_with("s3://bucket/output", mock_session)

def test__convert_input_data_to_channels(self):
"""Test basic conversion of InputData to Channel, including content_type."""
input_data = [InputData(channel_name="train", data_source="s3://bucket/data", content_type="application/json")]
channels = _convert_input_data_to_channels(input_data)

assert len(channels) == 1
assert channels[0].channel_name == "train"
assert channels[0].content_type == "application/json"

def test__convert_input_data_to_channels_with_content_type_preserved(self):
"""Test that content_type is preserved when converting InputData to Channel."""
input_data = [
InputData(channel_name="train", data_source="s3://bucket/data", content_type="application/json"),
InputData(channel_name="validation", data_source="s3://bucket/val", content_type="text/csv"),
]
channels = _convert_input_data_to_channels(input_data)

assert len(channels) == 2
assert channels[0].channel_name == "train"
assert channels[0].content_type == "application/json"
assert channels[1].channel_name == "validation"
assert channels[1].content_type == "text/csv"

def test__convert_input_data_to_channels_without_content_type(self):
"""Test that omitting content_type doesn't cause issues."""
input_data = [InputData(channel_name="train", data_source="s3://bucket/data")]
channels = _convert_input_data_to_channels(input_data)

assert len(channels) == 1
assert channels[0].channel_name == "train"
assert channels[0].content_type is None

def test__convert_input_data_to_channels_without_compression_type(self):
"""Test that InputData without compression_type doesn't cause issues in conversion."""
input_data = [InputData(channel_name="train", data_source="s3://bucket/data", content_type="application/json")]
channels = _convert_input_data_to_channels(input_data)

assert len(channels) == 1
assert channels[0].channel_name == "train"
assert channels[0].compression_type is None

def test__validate_eula_for_gated_model_with_model_package(self):
"""Test EULA validation returns True for ModelPackage input"""
Expand Down
Loading