Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-train/src/sagemaker/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,7 @@ def _prepare_train_script(
execute_driver=execute_driver,
)

with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f:
with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w", newline="\n") as f:
f.write(train_script)

@classmethod
Expand Down
43 changes: 43 additions & 0 deletions sagemaker-train/tests/unit/train/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,49 @@ def mock_upload_data(path, bucket, key_prefix):
assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard"


@patch("sagemaker.train.model_trainer.TrainingJob")
@patch("sagemaker.train.model_trainer.TemporaryDirectory")
def test_prepare_train_script_uses_lf_line_endings(
mock_tmp_dir, mock_training_job, modules_session
):
"""Test that _prepare_train_script generates sm_train.sh with LF line endings only."""
modules_session.upload_data.return_value = (
f"s3://{DEFAULT_BUCKET}/{DEFAULT_BASE_NAME}-job/input/test"
)

tmp_dir = tempfile.TemporaryDirectory()
tmp_dir._cleanup = False
tmp_dir.cleanup = lambda: None
mock_tmp_dir.return_value = tmp_dir

try:
model_trainer = ModelTrainer(
sagemaker_session=modules_session,
training_image=DEFAULT_IMAGE,
source_code=DEFAULT_SOURCE_CODE,
role=DEFAULT_ROLE,
)

model_trainer.train()

train_script_path = os.path.join(tmp_dir.name, TRAIN_SCRIPT)
assert os.path.exists(train_script_path)

with open(train_script_path, "rb") as f:
content = f.read()

# Verify no CRLF line endings exist
assert b"\r\n" not in content, (
"sm_train.sh contains CRLF line endings; expected LF only"
)
# Verify LF line endings are present
assert b"\n" in content, (
"sm_train.sh does not contain any LF line endings"
)
finally:
shutil.rmtree(tmp_dir.name, ignore_errors=True)


@patch("sagemaker.train.model_trainer.TrainingJob")
def test_input_merge(mock_training_job, modules_session):
model_input = InputData(channel_name="model", data_source="s3://bucket/model/model.tar.gz")
Expand Down
Loading