diff --git a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py index 0ea74ee207..882eb197f4 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/finetune_utils.py @@ -5,8 +5,8 @@ import time import logging import json -from typing import Optional -import time +from typing import Optional, List + import boto3 from sagemaker.core.resources import MlflowApp, ModelPackage, ModelPackageGroup from sagemaker.core.helper.session_helper import Session @@ -69,7 +69,7 @@ def _read_domain_id_from_metadata() -> Optional[str]: metadata = json.load(f) return metadata.get('DomainId') except Exception as e: - logger.debug(f"Could not read Studio metadata file: {e}") + logger.debug("Could not read Studio metadata file: %s", e) return None @@ -94,7 +94,7 @@ def _get_current_domain_id(sagemaker_session) -> Optional[str]: # ARN format: arn:aws:sagemaker:region:account:user-profile/domain-id/profile-name return user_profile_arn.split('/')[1] except Exception as e: - logger.debug(f"Could not extract domain ID from user profile ARN: {e}") + logger.debug("Could not extract domain ID from user profile ARN: %s", e) return None @@ -408,7 +408,7 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni v_copy['default'] = None # No default — won't appear in to_dict() unless set options_dict[k] = v_copy except Exception as e: - logger.debug(f"Could not fetch subscription recipe override_params: {type(e).__name__}: {e}") + logger.debug("Could not fetch subscription recipe override_params: %s: %s", type(e).__name__, e) if options_dict: return FineTuningOptions(options_dict), model_arn, is_gated_model @@ -423,11 +423,15 @@ def _get_fine_tuning_options_and_model_arn(model_name: str, customization_techni def _create_input_channels(dataset: str, content_type: Optional[str] = None, input_compression_type: Optional[str] = None, record_wrapper_type: Optional[str] = None, - input_mode: Optional[str] = None): + input_mode: Optional[str] = None) -> List[Channel]: """Create input channels from dataset (S3 URI or dataset ARN). Args: dataset: S3 URI (s3://bucket/key) or dataset ARN (arn:aws:sagemaker:...) + content_type: MIME type of the input data + input_compression_type: Compression type of the input data + record_wrapper_type: Record wrapper type + input_mode: Input mode for the channel Returns: list: List of Channel objects @@ -486,7 +490,7 @@ def _resolve_model_and_name(model, sagemaker_session=None): try: import boto3 region_name = boto3.Session().region_name or os.environ.get('AWS_DEFAULT_REGION') - except: + except Exception: pass if isinstance(model, str): @@ -518,8 +522,9 @@ def _resolve_model_and_name(model, sagemaker_session=None): return model, model_name -def _create_serverless_config(model_arn, customization_technique, - training_type, accept_eula, evaluator_arn=None, job_type=JOB_TYPE) -> Optional['ServerlessJobConfig']: +def _create_serverless_config(model_arn: str, customization_technique: str, + training_type, accept_eula: bool, evaluator_arn: Optional[str] = None, + job_type: str = JOB_TYPE) -> Optional[ServerlessJobConfig]: """Create serverless job configuration for fine-tuning. Args: @@ -606,8 +611,9 @@ def _create_model_package_config(model_package_group_name, model, sagemaker_sess -def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None, - mlflow_experiment_name=None, mlflow_run_name=None): +def _create_mlflow_config(sagemaker_session, mlflow_resource_arn: Optional[str] = None, + mlflow_experiment_name: Optional[str] = None, + mlflow_run_name: Optional[str] = None) -> Optional[MlflowConfig]: """Create MLflow configuration with resolved resource ARN. Args: @@ -623,7 +629,7 @@ def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None, # Derive mlflow_resource_arn with default experience resolved_mlflow_arn = _resolve_mlflow_resource_arn(sagemaker_session, mlflow_resource_arn) - logger.info(f"MLflow resource ARN: {resolved_mlflow_arn}") + logger.info("MLflow resource ARN: %s", resolved_mlflow_arn) # Create MlflowConfig using shapes mlflow_config = None @@ -637,12 +643,13 @@ def _create_mlflow_config(sagemaker_session, mlflow_resource_arn=None, return mlflow_config -def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None): +def _create_output_config(sagemaker_session, s3_output_path: Optional[str] = None, + kms_key_id: Optional[str] = None) -> OutputDataConfig: """Create output data configuration with default S3 path if needed. Args: - s3_output_path: S3 output path (if None, generates default path) sagemaker_session: SageMaker session for generating default path + s3_output_path: S3 output path (if None, generates default path) kms_key_id: Optional KMS key ID for encryption Returns: @@ -662,7 +669,7 @@ def _create_output_config(sagemaker_session,s3_output_path=None, kms_key_id=None ) -def _convert_input_data_to_channels(input_data_config ): +def _convert_input_data_to_channels(input_data_config: List[InputData]) -> List[Channel]: """Convert InputData objects to Channel objects with S3 and dataset ARN support. Args: @@ -689,9 +696,15 @@ def _convert_input_data_to_channels(input_data_config ): dataset_source={"dataset_arn": input_data.data_source} ) + # Safely get content_type and compression_type from InputData + content_type = getattr(input_data, 'content_type', None) + compression_type = getattr(input_data, 'compression_type', None) + channel = Channel( channel_name=input_data.channel_name, data_source=data_source, + content_type=content_type, + compression_type=compression_type, ) channels.append(channel) @@ -713,7 +726,7 @@ def _validate_and_resolve_model_package_group(model, model_package_group_name): "not a ModelPackage artifact/not continued finetuning") -def _validate_eula_for_gated_model(model, accept_eula, is_gated_model): +def _validate_eula_for_gated_model(model, accept_eula: bool, is_gated_model: bool) -> bool: """Validate EULA acceptance for gated models. Args: @@ -786,7 +799,6 @@ def _validate_s3_path_exists(s3_path: str, sagemaker_session): def _validate_hyperparameter_values(hyperparameters: dict): """Validate hyperparameter values for allowed characters.""" - import re allowed_chars = r"^[a-zA-Z0-9/_.:,\-\s'\"\[\]]*$" for key, value in hyperparameters.items(): if isinstance(value, str) and not re.match(allowed_chars, value): diff --git a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py index 7a63e36234..9bfd765973 100644 --- a/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py +++ b/sagemaker-train/tests/unit/train/common_utils/test_finetune_utils.py @@ -450,12 +450,45 @@ def test__create_output_config(self, mock_validate_s3): mock_validate_s3.assert_called_once_with("s3://bucket/output", mock_session) def test__convert_input_data_to_channels(self): + """Test basic conversion of InputData to Channel, including content_type.""" + input_data = [InputData(channel_name="train", data_source="s3://bucket/data", content_type="application/json")] + channels = _convert_input_data_to_channels(input_data) + + assert len(channels) == 1 + assert channels[0].channel_name == "train" + assert channels[0].content_type == "application/json" + + def test__convert_input_data_to_channels_with_content_type_preserved(self): + """Test that content_type is preserved when converting InputData to Channel.""" + input_data = [ + InputData(channel_name="train", data_source="s3://bucket/data", content_type="application/json"), + InputData(channel_name="validation", data_source="s3://bucket/val", content_type="text/csv"), + ] + channels = _convert_input_data_to_channels(input_data) + + assert len(channels) == 2 + assert channels[0].channel_name == "train" + assert channels[0].content_type == "application/json" + assert channels[1].channel_name == "validation" + assert channels[1].content_type == "text/csv" + def test__convert_input_data_to_channels_without_content_type(self): + """Test that omitting content_type doesn't cause issues.""" input_data = [InputData(channel_name="train", data_source="s3://bucket/data")] channels = _convert_input_data_to_channels(input_data) assert len(channels) == 1 assert channels[0].channel_name == "train" + assert channels[0].content_type is None + + def test__convert_input_data_to_channels_without_compression_type(self): + """Test that InputData without compression_type doesn't cause issues in conversion.""" + input_data = [InputData(channel_name="train", data_source="s3://bucket/data", content_type="application/json")] + channels = _convert_input_data_to_channels(input_data) + + assert len(channels) == 1 + assert channels[0].channel_name == "train" + assert channels[0].compression_type is None def test__validate_eula_for_gated_model_with_model_package(self): """Test EULA validation returns True for ModelPackage input"""