diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py
index 48c42c9093..15fce697ea 100644
--- a/sagemaker-train/src/sagemaker/train/model_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/model_trainer.py
@@ -1061,7 +1061,7 @@ def _prepare_train_script(
             execute_driver=execute_driver,
         )
 
-        with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f:
+        with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w", newline="\n") as f:
             f.write(train_script)
 
     @classmethod
diff --git a/sagemaker-train/tests/unit/train/test_model_trainer.py b/sagemaker-train/tests/unit/train/test_model_trainer.py
index 220e0fb40f..4143f01151 100644
--- a/sagemaker-train/tests/unit/train/test_model_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_model_trainer.py
@@ -1287,6 +1287,52 @@ def mock_upload_data(path, bucket, key_prefix):
     assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard"
 
 
+@patch("sagemaker.train.model_trainer.TrainingJob")
+@patch("sagemaker.train.model_trainer.TemporaryDirectory")
+def test_prepare_train_script_uses_lf_line_endings(
+    mock_tmp_dir, mock_training_job, modules_session
+):
+    """Test that _prepare_train_script generates sm_train.sh with LF line endings only."""
+    modules_session.upload_data.return_value = (
+        f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test"
+    )
+
+    # Hand ModelTrainer a real directory but neutralise cleanup() so the generated
+    # script is still on disk when it is inspected below.
+    tmp_dir = tempfile.TemporaryDirectory()
+    real_cleanup = tmp_dir.cleanup
+    tmp_dir.cleanup = lambda: None
+    mock_tmp_dir.return_value = tmp_dir
+
+    try:
+        model_trainer = ModelTrainer(
+            sagemaker_session=modules_session,
+            training_image=DEFAULT_IMAGE,
+            source_code=DEFAULT_SOURCE_CODE,
+            role=DEFAULT_ROLE,
+        )
+
+        model_trainer.train()
+
+        train_script_path = os.path.join(tmp_dir.name, TRAIN_SCRIPT)
+        assert os.path.exists(train_script_path)
+
+        with open(train_script_path, "rb") as f:
+            content = f.read()
+
+        # Reject any carriage return: this covers CRLF as well as a bare CR.
+        assert b"\r" not in content, (
+            "sm_train.sh contains carriage returns; expected LF line endings only"
+        )
+        # Verify LF line endings are present
+        assert b"\n" in content, (
+            "sm_train.sh does not contain any LF line endings"
+        )
+    finally:
+        # Explicit cleanup via the real method avoids the implicit-cleanup ResourceWarning.
+        real_cleanup()
+
+
 @patch("sagemaker.train.model_trainer.TrainingJob")
 def test_input_merge(mock_training_job, modules_session):
     model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz")