Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mlpstorage_py/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
add_dlio_arguments,
)

from mlpstorage_py.cli.training_args import add_training_arguments
from mlpstorage_py.cli.checkpointing_args import add_checkpointing_arguments
from mlpstorage_py.cli.vectordb_args import add_vectordb_arguments
from mlpstorage_py.cli.kvcache_args import add_kvcache_arguments
from mlpstorage_py.cli.training_args import add_training_arguments, validate_training_arguments
from mlpstorage_py.cli.checkpointing_args import add_checkpointing_arguments, validate_checkpointing_arguments
from mlpstorage_py.cli.vectordb_args import add_vectordb_arguments, validate_vectordb_arguments
from mlpstorage_py.cli.kvcache_args import add_kvcache_arguments, validate_kvcache_arguments
from mlpstorage_py.cli.utility_args import add_reports_arguments, add_history_arguments
from mlpstorage_py.cli.lockfile_args import add_lockfile_arguments

Expand All @@ -47,10 +47,10 @@
'add_host_arguments',
'add_dlio_arguments',
# Benchmark argument builders
'add_training_arguments',
'add_checkpointing_arguments',
'add_vectordb_arguments',
'add_kvcache_arguments',
'add_training_arguments', 'validate_training_arguments',
'add_checkpointing_arguments', 'validate_checkpointing_arguments',
'add_vectordb_arguments', 'validate_vectordb_arguments',
'add_kvcache_arguments', 'validate_kvcache_arguments',
# Utility argument builders
'add_reports_arguments',
'add_history_arguments',
Expand Down
130 changes: 84 additions & 46 deletions mlpstorage_py/cli/checkpointing_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
including datasize and run commands.
"""

from mlpstorage_py.config import DEFAULT_HOSTS, EXEC_TYPE
import sys

from mlpstorage_py.config import DEFAULT_HOSTS, EXEC_TYPE, LLM_MODELS, LLM_MODELS_CLOSED, EXIT_CODE
from mlpstorage_py.cli.common_args import (
HELP_MESSAGES,
add_universal_arguments,
Expand All @@ -16,7 +18,7 @@
)


def add_checkpointing_arguments(parser):
def add_checkpointing_arguments(parser, is_closed):
"""Add checkpointing benchmark arguments to the parser.

Args:
Expand All @@ -37,8 +39,7 @@ def add_checkpointing_arguments(parser):

# Common arguments for both datasize and run
for _parser in [datasize, run_benchmark]:
add_host_arguments(_parser)

add_host_arguments(_parser, is_closed)
_parser.add_argument(
'--client-host-memory-in-gb', '-cm',
type=int,
Expand All @@ -48,53 +49,90 @@ def add_checkpointing_arguments(parser):

# Model argument - restricted through the `choices` param: closed runs accept
# only LLM_MODELS_CLOSED, open runs accept any entry of LLM_MODELS
_parser.add_argument(
'--model', '-m',
required=True,
help=HELP_MESSAGES['llm_model']
)
if is_closed:
_parser.add_argument(
'--model', '-m',
choices=LLM_MODELS_CLOSED,
required=True,
help=HELP_MESSAGES['llm_model']
)

else:
_parser.add_argument(
'--model', '-m',
choices=LLM_MODELS,
required=True,
help=HELP_MESSAGES['llm_model']
)

if is_closed:
_parser.set_defaults(
num_checkpoints_read=10,
num_checkpoints_write=10
)
else:
_parser.add_argument(
'--num-checkpoints-read', '-ncr',
type=int,
default=10,
help=HELP_MESSAGES['num_checkpoints']
)

_parser.add_argument(
'--num-checkpoints-write', '-ncw',
type=int,
default=10,
help=HELP_MESSAGES['num_checkpoints']
)

add_dlio_arguments(_parser, is_closed)

# Add exec-type and MPI arguments to both datasize and run
_parser.add_argument(
'--exec-type', '-et',
type=EXEC_TYPE,
choices=list(EXEC_TYPE),
default=EXEC_TYPE.MPI,
help=HELP_MESSAGES['exec_type']
)
add_mpi_arguments(_parser, is_closed)

run_benchmark.add_argument(
'--num-processes', '-np',
type=int,
required=True,
help=HELP_MESSAGES['num_checkpoint_accelerators']
)

_parser.add_argument(
'--num-checkpoints-read', '-ncr',
type=int,
default=10,
help=HELP_MESSAGES['num_checkpoints']
)
run_benchmark.add_argument(
"--checkpoint-folder", '-cf',
type=str,
required=True,
help=HELP_MESSAGES['checkpoint_folder']
)

_parser.add_argument(
'--num-checkpoints-write', '-ncw',
type=int,
default=10,
help=HELP_MESSAGES['num_checkpoints']
)
add_universal_arguments(run_benchmark, True, True, True, is_closed)
add_universal_arguments(datasize, False, False, True, is_closed)

_parser.add_argument(
'--num-processes', '-np',
type=int,
required=True,
help=HELP_MESSAGES['num_checkpoint_accelerators']
)
# Add time-series arguments to run command only
add_timeseries_arguments(run_benchmark, is_closed)

_parser.add_argument(
"--checkpoint-folder", '-cf',
type=str,
required=True,
help=HELP_MESSAGES['checkpoint_folder']
)

add_dlio_arguments(_parser)
def validate_checkpointing_arguments(args):
    """Validate the parsed arguments for a checkpointing benchmark.

    Every check runs before anything is reported, so all problems are shown
    in a single pass rather than one per invocation.

    Args:
        args (argparse.Namespace): The parsed command-line arguments. The
            attributes ``model``, ``num_checkpoints_read`` and
            ``num_checkpoints_write`` are read.

    Raises:
        SystemExit: With ``EXIT_CODE.INVALID_ARGUMENTS`` when at least one
            check fails. Returns ``None`` normally otherwise.
    """
    error_messages = []

    if args.model not in LLM_MODELS:
        error_messages.append(
            "Invalid LLM model. Supported models are: {}".format(", ".join(LLM_MODELS))
        )
    if args.num_checkpoints_read < 0 or args.num_checkpoints_write < 0:
        error_messages.append("Number of checkpoints read and write must be non-negative")

    if not error_messages:
        return

    # Diagnostics go to stderr so they never mix with regular program output
    # that a caller may be capturing or piping.
    for msg in error_messages:
        print(msg, file=sys.stderr)

    sys.exit(EXIT_CODE.INVALID_ARGUMENTS)
Loading
Loading