diff --git a/mlpstorage b/mlpstorage index 880a5c02..9c160d50 100755 --- a/mlpstorage +++ b/mlpstorage @@ -1,2 +1,12 @@ -#! /bin/bash -uv run python3 -m mlpstorage_py.main $* +#!/usr/bin/env bash +# Wrapper for running mlpstorage directly from a checkout via `uv run`. +# Preferred install path is `uv pip install -e .` (uses [project.scripts]). +# +# IMPORTANT fixes vs prior version (see issue #322): +# - "$@" (not $*) preserves argument quoting; args with spaces stay intact. +# - `--` separator prevents `uv run` from intercepting flags meant for mlpstorage +# (e.g. --project, --directory, --no-sync, --python). +# - --project pins the uv project to this checkout regardless of $PWD. +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +exec uv run --project "$SCRIPT_DIR" -- python3 -m mlpstorage_py.main "$@" diff --git a/mlpstorage_py/cli/common_args.py b/mlpstorage_py/cli/common_args.py index 962b7185..1eaad497 100755 --- a/mlpstorage_py/cli/common_args.py +++ b/mlpstorage_py/cli/common_args.py @@ -51,11 +51,15 @@ ), 'client_hosts': ( "Space-separated list of IP addresses or hostnames of the participating hosts. " - "\nExample: '--hosts 192.168.1.1 192.168.1.2 192.168.1.3' or '--hosts host1 host2 host3'. Slots can " + "\nExample: '--hosts 192.168.1.1 192.168.1.2 192.168.1.3' or '--hosts host1 host2 host3'. " + "Comma-separated values are also accepted: '--hosts host1,host2,host3'. " + "Slots can " "be specified by appending ':' to a hostname like so: '--hosts host1:2 host2:2'. This " "example will run 2 accelerators on each host. If slots are not specified the number of processes " "will be equally distributed across the hosts with any remainder being distributed evenly on the " - "remaining hosts in the order they are listed." + "remaining hosts in the order they are listed. " + "\nDo NOT use '--hosts=h1 h2' (with '=' and space); argparse will only bind 'h1' to --hosts " + "and treat 'h2' as a stray positional argument. 
Use '--hosts h1 h2' or '--hosts=h1,h2' instead." ), 'category': "Benchmark category to be submitted.", 'results_dir': "Directory where the benchmark results will be saved.", diff --git a/mlpstorage_py/cli_parser.py b/mlpstorage_py/cli_parser.py index 9805ad41..af32669f 100755 --- a/mlpstorage_py/cli_parser.py +++ b/mlpstorage_py/cli_parser.py @@ -6,6 +6,7 @@ """ import argparse +import re import sys from mlpstorage_py import VERSION @@ -243,14 +244,29 @@ def update_args(args): flattened_mpi_params = [item for sublist in args.mpi_params for item in sublist] setattr(args,'mpi_params', flattened_mpi_params) - if hasattr(args, 'hosts'): - print(f'Hosts is: {args.hosts}') - # hosts can be comma separated string or a list of strings. If it's a string, it is still a list of length 1 - if len(args.hosts) == 1 and isinstance(args.hosts[0], str): - setattr(args, 'hosts', args.hosts[0].split(',')) - print(f'Hosts is: {args.hosts}') - - if not hasattr(args, "num_client_hosts") and hasattr(args, "hosts"): + if hasattr(args, 'hosts') and args.hosts is not None: + # Accept any of the following equivalent forms and normalize to a clean list: + # --hosts h1 h2 h3 -> ['h1', 'h2', 'h3'] + # --hosts h1,h2,h3 -> ['h1', 'h2', 'h3'] + # --hosts 'h1 h2 h3' -> ['h1', 'h2', 'h3'] (quoted, e.g. from YAML) + # --hosts='h1,h2,h3' -> ['h1', 'h2', 'h3'] (DLIO subprocess form) + # --hosts='h1 h2 h3' -> ['h1', 'h2', 'h3'] (quoted after '=') + # This defends against the argparse + nargs='+' + '=' interaction documented in + # https://github.com/mlcommons/storage/issues/322. 
+ raw = args.hosts if isinstance(args.hosts, list) else [args.hosts] + normalized = [] + for item in raw: + if not isinstance(item, str): + continue + for tok in re.split(r'[,\s]+', item.strip()): + if tok: + normalized.append(tok) + if not normalized: + print("ERROR: --hosts is empty after parsing", file=sys.stderr) + sys.exit(EXIT_CODE.INVALID_ARGUMENTS) + args.hosts = normalized + + if hasattr(args, 'hosts') and getattr(args, 'num_client_hosts', None) is None: setattr(args, "num_client_hosts", len(args.hosts)) diff --git a/mlpstorage_py/cluster_collector.py b/mlpstorage_py/cluster_collector.py old mode 100755 new mode 100644 index 2b94b0bf..a4a1ea4e --- a/mlpstorage_py/cluster_collector.py +++ b/mlpstorage_py/cluster_collector.py @@ -11,9 +11,9 @@ import shutil import socket import subprocess -import tempfile import threading import time +import warnings from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field, asdict from typing import Any, Dict, List, Optional, Tuple @@ -1178,8 +1178,12 @@ def __init__( hosts: List[str], mpi_bin: str, logger, + results_dir: str, allow_run_as_root: bool = False, - timeout_seconds: int = 60 + timeout_seconds: int = 60, + shared_staging_dir: Optional[str] = None, + shared_tmp_dir: Optional[str] = None, # deprecated, see note below + ssh_username: Optional[str] = None, ): """ Initialize the MPI cluster collector. @@ -1188,14 +1192,57 @@ def __init__( hosts: List of hostnames/IPs, optionally with slot counts (e.g., "host1:4"). mpi_bin: MPI binary to use (MPIRUN or MPIEXEC constant). logger: Logger instance for messages. + results_dir: Absolute or relative path to the benchmark results + directory. The collector stages its helper script under + ``/collector-staging/``; the staged script + persists after the run as a debuggable artifact. 
This + replaces the previous per-invocation ``tempfile`` staging + directory so no programmatic ``rm -rf`` is ever issued over + SSH (see PR #347 review). allow_run_as_root: If True, adds --allow-run-as-root flag. timeout_seconds: Maximum time to wait for collection. + shared_staging_dir: Optional path that is visible on every node. + When set, the collector writes the helper script under this + path and skips SSH-based staging. Typically used on clusters + with a shared scratch filesystem (NFS/Lustre/GPFS). + shared_tmp_dir: Deprecated alias for ``shared_staging_dir``. + Kept for one release for backward compatibility; emits a + DeprecationWarning. + ssh_username: Optional SSH username used when staging the script on + remote hosts. Defaults to the current user. Ignored when + ``shared_staging_dir`` is set or when all hosts are localhost. + + Raises: + ValueError: if ``results_dir`` is empty or None. Multi-host + collection without a results directory has no defensible + staging location now that tempdir-based staging is gone. """ + if not results_dir: + raise ValueError( + "MPIClusterCollector requires results_dir for script staging" + ) + + # Backward compatibility for the old kwarg name. Drop in a future + # release. 
+ if shared_tmp_dir is not None: + warnings.warn( + "shared_tmp_dir is deprecated; use shared_staging_dir instead.", + DeprecationWarning, + stacklevel=2, + ) + if shared_staging_dir is None: + shared_staging_dir = shared_tmp_dir + self.hosts = hosts self.mpi_bin = mpi_bin self.logger = logger + self.results_dir = os.path.abspath(results_dir) self.allow_run_as_root = allow_run_as_root self.timeout = timeout_seconds + self.shared_staging_dir = ( + os.path.abspath(shared_staging_dir) if shared_staging_dir else None + ) + self.ssh_username = ssh_username def _get_unique_hosts(self) -> List[str]: """Extract unique hostnames from the hosts list (removing slot counts).""" @@ -1252,10 +1299,135 @@ def _write_collector_script(self, script_path: str) -> None: f.write(MPI_COLLECTOR_SCRIPT) os.chmod(script_path, 0o755) + def _ssh_target(self, host: str) -> str: + """Return '[user@]host' for SSH/SCP invocations.""" + return f"{self.ssh_username}@{host}" if self.ssh_username else host + + def _ssh_common_opts(self) -> List[str]: + """SSH/SCP options used for all staging operations. + + * ``BatchMode=yes`` — never prompt for a password; fail fast if + passwordless SSH is not configured. + * ``StrictHostKeyChecking=accept-new`` — accept new host keys on first + contact but reject changed keys; matches the behavior users already + have configured for ``mpirun``. + * ``ForwardX11=no`` — suppress the ``Authorization required, but no + authorization protocol specified`` noise seen in issue #303. + * ``ConnectTimeout`` — bound per-host handshake time so a single bad + host cannot consume the whole collection timeout budget. 
+ """ + return [ + "-o", "BatchMode=yes", + "-o", "StrictHostKeyChecking=accept-new", + "-o", "ForwardX11=no", + "-o", f"ConnectTimeout={max(5, self.timeout // 6)}", + ] + + def _remote_hosts_needing_staging(self) -> List[str]: + """Return remote (non-localhost) unique hosts that need the script.""" + return [h for h in self._get_unique_hosts() if not _is_localhost(h)] + + def _stage_script_on_remote_hosts( + self, + script_local_path: str, + remote_dir: str, + hosts: List[str], + ) -> Dict[str, Optional[str]]: + """SCP the collector script to ``remote_dir`` on each remote host. + + The per-host work is parallelised with a thread pool; each call is + independent and almost entirely I/O-bound. + + Args: + script_local_path: Path to the collector script on the launch host. + remote_dir: Absolute directory to create on each remote host; the + script will be placed at ``remote_dir/mlps_collector.py``. The + same absolute path is used on every node so the ``mpirun`` + command line is identical everywhere. + hosts: Remote hostnames to stage to. Callers should pass the result + of :meth:`_remote_hosts_needing_staging` to avoid SSHing to the + launch host. + + Returns: + Mapping ``{host: None on success, error_message_str on failure}``. 
+ """ + per_host_timeout = max(10, self.timeout // 3) + ssh_common = self._ssh_common_opts() + + def stage_one(host: str) -> Tuple[str, Optional[str]]: + target = self._ssh_target(host) + try: + mkdir_cmd = [ + "ssh", *ssh_common, target, f"mkdir -p '{remote_dir}'" + ] + r = subprocess.run( + mkdir_cmd, capture_output=True, text=True, + timeout=per_host_timeout, + ) + if r.returncode != 0: + return host, f"ssh mkdir failed: {r.stderr.strip() or r.stdout.strip()}" + + scp_cmd = [ + "scp", *ssh_common, script_local_path, + f"{target}:{remote_dir}/mlps_collector.py", + ] + r = subprocess.run( + scp_cmd, capture_output=True, text=True, + timeout=per_host_timeout, + ) + if r.returncode != 0: + return host, f"scp failed: {r.stderr.strip() or r.stdout.strip()}" + return host, None + except subprocess.TimeoutExpired: + return host, f"timed out after {per_host_timeout}s" + except FileNotFoundError as e: + return host, f"ssh/scp binary not found: {e}" + except Exception as e: # pragma: no cover — defensive + return host, f"unexpected error: {e}" + + results: Dict[str, Optional[str]] = {} + max_workers = min(16, max(1, len(hosts))) + with ThreadPoolExecutor(max_workers=max_workers) as ex: + futures = {ex.submit(stage_one, h): h for h in hosts} + for f in as_completed(futures): + host, err = f.result() + results[host] = err + if err: + self.logger.warning( + f"Script staging on {host} failed: {err}" + ) + else: + self.logger.info( + f"Collector script staged on {host}:{remote_dir}" + ) + return results + def collect(self) -> Dict[str, Any]: """ Execute MPI collection across all nodes. + The collector script is written to ``/collector-staging/`` + on the launch host and the same absolute path is created on each + remote host via SSH before the script is copied there with SCP. + Because ``results_dir`` is resolved to an absolute path at + construction time, the path is identical on every participating + node, which is what ``mpirun`` requires. 
+ + When ``shared_staging_dir`` is set the script is written under that + path and no SSH staging is performed (suitable for clusters with a + shared NFS/Lustre/GPFS scratch FS). + + The staged script is **not removed** at the end of the run — it is + kept as a persistent run artifact so users can inspect it after a + failure. This is a deliberate design choice (see PR #347 review) + : programmatic ``rm -rf`` over SSH is + unacceptable. Consecutive runs against the same ``results_dir`` + simply overwrite the script, which is safe and idempotent. + + This fixes issue #303, where the previous implementation assumed + ``tempfile.TemporaryDirectory()`` on the launch host was visible to + every rank. + Returns: Dictionary mapping hostname -> system_info dict. @@ -1263,66 +1435,133 @@ def collect(self) -> Dict[str, Any]: RuntimeError: If MPI collection fails completely. """ unique_hosts = self._get_unique_hosts() - self.logger.debug(f"Starting MPI cluster collection on {len(unique_hosts)} hosts") + self.logger.debug( + f"Starting MPI cluster collection on {len(unique_hosts)} hosts" + ) - # Create temporary files for script and output - with tempfile.TemporaryDirectory() as tmpdir: - script_path = os.path.join(tmpdir, 'mlps_collector.py') - output_path = os.path.join(tmpdir, 'cluster_info.json') + # --- Decide where to place the helper script --------------------- + if self.shared_staging_dir: + staging_dir = self.shared_staging_dir + use_staging = False + self.logger.debug( + f"Using shared staging dir (no SSH staging): {staging_dir}" + ) + else: + staging_dir = os.path.join(self.results_dir, "collector-staging") + use_staging = True - # Write the collector script - self._write_collector_script(script_path) + script_path = os.path.join(staging_dir, "mlps_collector.py") + output_path = os.path.join(staging_dir, "cluster_info.json") - # Generate and run the MPI command - cmd = self._generate_mpi_command(script_path, output_path) - self.logger.debug(f"Running MPI collection 
command: {cmd}") + remote_hosts_to_stage: List[str] = [] - try: - result = subprocess.run( - cmd, - shell=True, - capture_output=True, - text=True, - timeout=self.timeout + os.makedirs(staging_dir, exist_ok=True) + self._write_collector_script(script_path) + self.logger.info( + f"Collector script staged at {script_path} " + f"(persisted as run artifact)" + ) + + # --- Stage the script on remote hosts if needed ------------------ + if use_staging: + remote_hosts_to_stage = self._remote_hosts_needing_staging() + if remote_hosts_to_stage: + self.logger.info( + f"Staging collector script to " + f"{len(remote_hosts_to_stage)} remote host(s)..." ) + stage_results = self._stage_script_on_remote_hosts( + script_path, staging_dir, remote_hosts_to_stage + ) + failures = { + h: e for h, e in stage_results.items() if e + } + if failures: + raise RuntimeError( + "Failed to stage collector script on " + f"{len(failures)} host(s): {failures}. " + "Verify passwordless SSH from the launch host, or " + "set --cluster-collector-shared-staging / " + "MLPS_CLUSTER_COLLECTOR_SHARED_STAGING to a " + "directory visible on every node." + ) - # Read and parse the output if it exists - if os.path.exists(output_path): - with open(output_path, 'r') as f: - collected_data = json.load(f) - - # Check for MPI import error marker - if collected_data.get('_mpi_import_error'): - error_msg = collected_data.get('_error_message', 'mpi4py not available') - error_host = collected_data.get('_hostname', 'unknown') - raise RuntimeError( - f"MPI collection failed on host '{error_host}': {error_msg}. " - f"Ensure mpi4py is installed on all cluster nodes." 
- ) + # --- Build and run the mpirun command ---------------------------- + cmd = self._generate_mpi_command(script_path, output_path) + self.logger.info( + f"Running MPI collection across {len(unique_hosts)} host(s)" + ) + self.logger.debug(f"MPI command: {cmd}") + + # Silence OpenSSH X11-forwarding warnings that mpirun's rsh/ssh + # PLM emits when XAUTHORITY is not set on the launch host + # ('Authorization required, but no authorization protocol + # specified'). Reported in issue #303. + env = os.environ.copy() + env.pop("DISPLAY", None) # prevent SSH X11 forwarding handshake + env.pop("XAUTHORITY", None) # and its cookie lookup + env.setdefault( + "PLM_RSH_AGENT", + "ssh -o ForwardX11=no -o ForwardX11Trusted=no " + "-o StrictHostKeyChecking=accept-new", + ) - # Check for non-zero return code (other MPI errors) - if result.returncode != 0: - self.logger.warning( - f"MPI collection returned non-zero exit code: {result.returncode}\n" - f"stderr: {result.stderr}" - ) + try: + result = subprocess.run( + cmd, + shell=True, + capture_output=True, + text=True, + timeout=self.timeout, + env=env, + ) + except subprocess.TimeoutExpired: + raise RuntimeError( + f"MPI collection timed out after {self.timeout} seconds" + ) - self.logger.debug( - f"Successfully collected info from {len(collected_data)} hosts" - ) - return collected_data - else: - raise RuntimeError( - f"MPI collection did not produce output file. 
" - f"Return code: {result.returncode}, stderr: {result.stderr}" - ) + # --- Parse the output written by rank 0 -------------------------- + if os.path.exists(output_path): + with open(output_path, 'r') as f: + collected_data = json.load(f) - except subprocess.TimeoutExpired: + if collected_data.get('_mpi_import_error'): + error_msg = collected_data.get( + '_error_message', 'mpi4py not available' + ) + error_host = collected_data.get('_hostname', 'unknown') raise RuntimeError( - f"MPI collection timed out after {self.timeout} seconds" + f"MPI collection failed on host '{error_host}': " + f"{error_msg}. Ensure mpi4py is installed on all " + "cluster nodes." ) - except Exception as e: - raise RuntimeError(f"MPI collection failed: {e}") + + if result.returncode != 0: + self.logger.warning( + f"MPI collection returned non-zero exit code: " + f"{result.returncode}\nstderr: {result.stderr}" + ) + + self.logger.info( + f"MPI collection completed successfully " + f"({len(collected_data)} hosts reported)" + ) + return collected_data + + # No output file — surface staging + mpirun context together. + # The staged script is left in place on both launch and remote + # hosts so the failure can be diagnosed post-mortem. + staged_summary = ( + remote_hosts_to_stage if remote_hosts_to_stage + else "[launch host only]" + ) + raise RuntimeError( + "MPI collection did not produce output file. " + f"Return code: {result.returncode}. " + f"Staged on: {staged_summary}. " + f"Staged script (persisted for inspection): {script_path}. 
" + f"stderr: {result.stderr}" + ) def collect_local_only(self) -> Dict[str, Any]: """ @@ -1339,9 +1578,13 @@ def collect_cluster_info( hosts: List[str], mpi_bin: str, logger, + results_dir: str, allow_run_as_root: bool = False, timeout_seconds: int = 60, - fallback_to_local: bool = True + fallback_to_local: bool = True, + shared_staging_dir: Optional[str] = None, + shared_tmp_dir: Optional[str] = None, # deprecated, see note below + ssh_username: Optional[str] = None, ) -> Dict[str, Any]: """ High-level function to collect cluster information. @@ -1354,9 +1597,18 @@ def collect_cluster_info( hosts: List of hostnames/IPs to collect from. mpi_bin: MPI command to use. logger: Logger instance. + results_dir: Benchmark results directory. The helper script will be + staged under ``/collector-staging/`` and persists + after the run as a debuggable artifact. Required. allow_run_as_root: Whether to allow running as root. timeout_seconds: Timeout for MPI collection. fallback_to_local: If True, fall back to local collection on failure. + shared_staging_dir: Optional path visible on every node. If provided, + the collector skips SSH-based script staging. See + :class:`MPIClusterCollector` for details. + shared_tmp_dir: Deprecated alias for ``shared_staging_dir``. + ssh_username: Optional SSH username for remote script staging. + Defaults to the current user. Returns: Dictionary mapping hostname -> system_info dict. 
@@ -1366,8 +1618,12 @@ def collect_cluster_info( hosts=hosts, mpi_bin=mpi_bin, logger=logger, + results_dir=results_dir, allow_run_as_root=allow_run_as_root, - timeout_seconds=timeout_seconds + timeout_seconds=timeout_seconds, + shared_staging_dir=shared_staging_dir, + shared_tmp_dir=shared_tmp_dir, + ssh_username=ssh_username, ) metadata = { diff --git a/mlpstorage_py/environment/validators.py b/mlpstorage_py/environment/validators.py index f7f868da..27a1ed24 100755 --- a/mlpstorage_py/environment/validators.py +++ b/mlpstorage_py/environment/validators.py @@ -50,6 +50,9 @@ def __str__(self) -> str: parts.append(f"Host: {self.host}") return "\n".join(parts) +# Defensive: update_args() normalizes quoted/space-separated --hosts forms +# before we get here, so this should never fire in normal flow. Kept as a +# belt-and-suspenders guard against direct API callers that skip update_args. def validate_ssh_connectivity( hosts: List[str], @@ -103,6 +106,22 @@ def validate_ssh_connectivity( # Parse host:slots format (e.g., "node1:4" -> "node1") hostname = host_entry.split(':')[0].strip() + # Guard against malformed tokens that could only come from a parsing error + # upstream (see issue #322: users passing `--hosts='h1 h2'` with quotes get + # a single token containing whitespace). + if not hostname or any(ch.isspace() for ch in hostname): + results.append(( + host_entry, + False, + ( + f"Invalid host token {host_entry!r}: contains whitespace or is empty. " + "Hosts must be passed as separate arguments " + "(e.g. `--hosts h1 h2 h3`) or comma-separated " + "(e.g. `--hosts h1,h2,h3`), not as a single quoted string." 
+ ), + )) + continue + # Skip localhost entries if hostname.lower() in ('localhost', '127.0.0.1'): results.append((hostname, True, 'localhost (skipped)')) diff --git a/mlpstorage_py/tests/test_cluster_collector.py b/mlpstorage_py/tests/test_cluster_collector.py index 949e8944..384f8ae9 100755 --- a/mlpstorage_py/tests/test_cluster_collector.py +++ b/mlpstorage_py/tests/test_cluster_collector.py @@ -724,10 +724,11 @@ def test_collector_returns_valid_data_without_error_marker(self, mock_logger): } with tempfile.TemporaryDirectory() as tmpdir: - output_path = os.path.join(tmpdir, 'cluster_info.json') - with open(output_path, 'w') as f: - json.dump(valid_output, f) - + # Under the new implementation (issue #303 fix), the collector + # creates a uuid-named subdirectory inside its base tmp dir and + # writes cluster_info.json there. We exercise that path by + # supplying ``shared_tmp_dir`` and pinning the uuid so we know + # the final output path. import subprocess from unittest.mock import patch, MagicMock @@ -735,12 +736,23 @@ def test_collector_returns_valid_data_without_error_marker(self, mock_logger): mock_result.returncode = 0 mock_result.stderr = "" - with patch('subprocess.run', return_value=mock_result): - with patch.object(collector, '_write_collector_script'): - with patch.object(collector, '_generate_mpi_command', return_value="mpirun test"): - with patch('tempfile.TemporaryDirectory') as mock_tmpdir: - mock_tmpdir.return_value.__enter__.return_value = tmpdir - + with patch('mlpstorage_py.cluster_collector.uuid.uuid4') as mock_uuid: + mock_uuid.return_value.hex = 'abcdef012345' + working_dir = os.path.join(tmpdir, 'mlps_collector_abcdef012345') + os.makedirs(working_dir, exist_ok=True) + output_path = os.path.join(working_dir, 'cluster_info.json') + with open(output_path, 'w') as f: + json.dump(valid_output, f) + + collector.shared_tmp_dir = tmpdir + + with patch('mlpstorage_py.cluster_collector.subprocess.run', + return_value=mock_result): + with 
patch.object(collector, '_write_collector_script'): + with patch.object( + collector, '_generate_mpi_command', + return_value="mpirun test", + ): result = collector.collect() assert 'host1' in result diff --git a/mlpstorage_py/tests/test_mpi_cluster_collector_issue_303.py b/mlpstorage_py/tests/test_mpi_cluster_collector_issue_303.py new file mode 100644 index 00000000..eab0224a --- /dev/null +++ b/mlpstorage_py/tests/test_mpi_cluster_collector_issue_303.py @@ -0,0 +1,578 @@ +"""Regression tests for issue #303. + +MPIClusterCollector used to write the helper collector script into a +``tempfile.TemporaryDirectory()`` on the launch host only, then invoke +``mpirun`` with that absolute path. On clusters with node-local ``/tmp`` +the remote ranks could not find the script and ``mpirun`` aborted with +``[Errno 2] No such file or directory``. + +Review-driven redesign (PR #347): + +* **wolfgang-desalvador**: programmatic ``rm -rf`` over SSH is unacceptable. + The collector now stages under ``/collector-staging/`` and + never removes anything remotely. The staged script persists as a run + artifact for post-mortem. +* **russfellows**: staging progress emitted at INFO; remaining ``mkdir -p`` + path is single-quoted; ``num_client_hosts`` re-derive uses ``is None``. + +These tests cover the resulting code paths: + +* default "stage-and-run" path — SCPs the script to each remote host + before ``mpirun``; the script persists afterwards (no cleanup); +* ``shared_staging_dir`` opt-in — skips all SSH staging; +* partial staging failure — raises a descriptive error naming the bad host; +* no ``rm``/``rm -rf``/``rmdir``/``rmtree`` is ever invoked over SSH; +* staged-script path is emitted at INFO for debuggability; +* ``mkdir -p`` command single-quotes the staging path. + +Also covers the X11 env injection that silences the +``Authorization required, but no authorization protocol specified`` +noise reported in the original issue. 
+""" + +from __future__ import annotations + +import json +import logging +import os +import subprocess +from typing import List +from unittest import mock + +import pytest + +from mlpstorage_py import cluster_collector as cc +from mlpstorage_py.config import MPIRUN + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_fake_run(output_path_getter, write_output: bool = True): + """Build a fake ``subprocess.run`` that fakes ``mpirun`` by writing the + expected rank-0 output JSON to disk, and records every call. + + Parameters + ---------- + output_path_getter: callable returning the expected cluster_info.json path + at the time ``mpirun`` is invoked (resolved lazily so the per-run + staging path created inside ``collect()`` is honored). + write_output: when False, ``mpirun`` "succeeds" but produces no output + file — simulating a cluster where staging succeeded but mpirun itself + failed to run the script on every rank. + """ + calls: List[dict] = [] + + def fake_run(cmd, *args, **kwargs): + if isinstance(cmd, list): + argv = cmd + kind = argv[0] + else: + argv = cmd.split() + kind = "mpirun" if "mpirun" in cmd or "mpiexec" in cmd else argv[0] + + calls.append({ + "argv": argv, + "kind": kind, + "env": kwargs.get("env"), + "shell": kwargs.get("shell", False), + }) + + # Successful mpirun: write the aggregated JSON rank 0 would produce. 
+ if kind in ("mpirun", "mpiexec") and write_output: + output_path = output_path_getter() + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as f: + json.dump( + { + "host-a": {"hostname": "host-a", "total_memory_kb": 1024}, + "host-b": {"hostname": "host-b", "total_memory_kb": 2048}, + }, + f, + ) + + return subprocess.CompletedProcess(argv, 0, stdout="", stderr="") + + return fake_run, calls + + +def _collector(tmp_path, hosts, results_dir=None, **kwargs): + """ + Construct a collector for testing. + + ``results_dir`` defaults to ``tmp_path`` so every test gets a hermetic + results directory per the new staging design. Callers that want to test + a specific results_dir layout (absolute vs relative, shared, etc.) can + pass one explicitly. + """ + logger = logging.getLogger("test.cluster_collector") + logger.setLevel(logging.DEBUG) + if results_dir is None: + results_dir = str(tmp_path) + return cc.MPIClusterCollector( + hosts=hosts, + mpi_bin=MPIRUN, + logger=logger, + timeout_seconds=30, + results_dir=results_dir, + **kwargs, + ) + + +def _expected_staging_dir(results_dir) -> str: + """Canonical staging path per the Option-1 design wolfgang-desalvador asked for.""" + return os.path.join(os.path.abspath(str(results_dir)), "collector-staging") + + +def _expected_script_path(results_dir) -> str: + return os.path.join(_expected_staging_dir(results_dir), "mlps_collector.py") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestMPIStaging: + """Default path: the collector must SCP the script to remote hosts. + + Under the new design the staging root is ``/collector-staging/`` + — not a tempfile.gettempdir() subtree — and nothing is ever removed. 
+ """ + + def test_stages_script_on_each_remote_host(self, tmp_path, monkeypatch): + collector = _collector(tmp_path, hosts=["host-a:1", "host-b:1"]) + + # Make _is_localhost return True only for 'host-a' so host-b is staged. + monkeypatch.setattr(cc, "_is_localhost", + lambda h: h in ("host-a", "localhost", "127.0.0.1")) + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + result = collector.collect() + + # mpirun was invoked + mpi_calls = [c for c in calls if c["kind"] in ("mpirun", "mpiexec")] + assert len(mpi_calls) == 1, f"expected 1 mpirun call, got {calls}" + + # The script was staged to host-b (only remote host) via ssh+scp + ssh_calls = [c for c in calls if c["kind"] == "ssh"] + scp_calls = [c for c in calls if c["kind"] == "scp"] + assert any("host-b" in " ".join(c["argv"]) for c in ssh_calls), \ + "expected at least one ssh call targeting host-b" + assert any("host-b" in " ".join(c["argv"]) for c in scp_calls), \ + "expected at least one scp call targeting host-b" + + # Result shape unchanged from before the fix + assert "host-a" in result and "host-b" in result + + def test_staged_script_path_is_under_results_dir(self, tmp_path, monkeypatch): + """Staging root lives under results_dir, not under tempfile.gettempdir().""" + collector = _collector(tmp_path, hosts=["host-a:1", "host-b:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + + # Every ssh and scp call must reference a path rooted under results_dir. 
+ expected_root = _expected_staging_dir(tmp_path) + + for c in calls: + if c["kind"] not in ("ssh", "scp"): + continue + joined = " ".join(c["argv"]) + assert expected_root in joined, \ + f"{c['kind']} call must reference results_dir staging path; " \ + f"got: {joined}" + + def test_staging_path_is_absolute_even_with_relative_results_dir( + self, tmp_path, monkeypatch + ): + """mpirun needs the same absolute path on every node; relative inputs + must be resolved at construction time.""" + monkeypatch.chdir(tmp_path) + (tmp_path / "rel-results").mkdir() + + collector = _collector( + tmp_path, + hosts=["host-a:1"], + results_dir="rel-results", + ) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_path = os.path.join( + _expected_staging_dir(tmp_path / "rel-results"), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + + mpi_call = next(c for c in calls if c["kind"] in ("mpirun", "mpiexec")) + joined = " ".join(mpi_call["argv"]) + # Absolute path to the resolved results_dir must appear in the mpirun argv. + assert str((tmp_path / "rel-results").resolve()) in joined, \ + f"mpirun must use an absolute staging path; got: {joined}" + + def test_single_localhost_skips_staging(self, tmp_path, monkeypatch): + """A localhost-only invocation must not SSH anywhere.""" + collector = _collector(tmp_path, hosts=["127.0.0.1:1"]) + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + + assert not any(c["kind"] in ("ssh", "scp") for c in calls), \ + "localhost-only run must not invoke ssh or scp" + + +class TestNoRemoteCleanup: + """wolfgang-desalvador: no programmatic destructive command over SSH. 
+ + The staged script must persist after the run (on both launch and remote + hosts) and no ``rm``/``rm -rf``/``rmdir``/``rmtree`` command may ever be + issued to any host — success or failure. + """ + + _FORBIDDEN_TOKENS = ("rm -rf", " rm ", "rmdir", "rmtree") + + def _assert_no_destructive_calls(self, calls): + for c in calls: + joined = " " + " ".join(c["argv"]) + " " + for tok in self._FORBIDDEN_TOKENS: + assert tok not in joined, ( + f"forbidden destructive command {tok.strip()!r} " + f"appeared in {c['kind']} call: {joined!r}" + ) + + def test_no_rm_on_successful_run(self, tmp_path, monkeypatch): + collector = _collector(tmp_path, hosts=["host-a:1", "host-b:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + self._assert_no_destructive_calls(calls) + + def test_no_rm_on_mpi_failure(self, tmp_path, monkeypatch): + """Even when mpirun fails, no cleanup rm is emitted.""" + collector = _collector(tmp_path, hosts=["host-a:1", "host-b:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + recorded: List[dict] = [] + + def failing_mpirun(cmd, *args, **kwargs): + argv = cmd if isinstance(cmd, list) else cmd.split() + joined = " ".join(argv) + kind = "mpirun" if "mpirun" in joined or "mpiexec" in joined else argv[0] + recorded.append({ + "argv": argv, + "kind": kind, + "env": kwargs.get("env"), + "shell": kwargs.get("shell", False), + }) + if kind in ("mpirun", "mpiexec"): + return subprocess.CompletedProcess( + argv, 1, stdout="", stderr="mpirun: simulated failure" + ) + return subprocess.CompletedProcess(argv, 0, stdout="", stderr="") + + monkeypatch.setattr(cc.subprocess, "run", failing_mpirun) + + # collect() may raise or return a partial result; either is acceptable + # as long as no destructive 
call leaks out. + try: + collector.collect() + except Exception: + pass + + self._assert_no_destructive_calls(recorded) + + def test_staged_script_persists_after_run(self, tmp_path, monkeypatch): + """After a successful collect(), the staged script is still on disk.""" + collector = _collector(tmp_path, hosts=["host-a:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, _calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + + script_path = _expected_script_path(tmp_path) + assert os.path.isfile(script_path), ( + f"staged collector script must persist as a run artifact; " + f"expected at {script_path}" + ) + + def test_rerun_is_idempotent(self, tmp_path, monkeypatch): + """Two consecutive collects against the same results_dir must not + fail on pre-existing staging dir or stale script file.""" + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, _calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + _collector(tmp_path, hosts=["host-a:1"]).collect() + # Second collector reuses the same staging dir; must not raise. 
+ _collector(tmp_path, hosts=["host-a:1"]).collect() + + assert os.path.isfile(_expected_script_path(tmp_path)) + + +class TestMkdirQuoting: + """russfellows: the remaining ``mkdir -p`` path must be single-quoted + when it appears inside an ssh shell command string.""" + + def test_remote_mkdir_path_is_single_quoted(self, tmp_path, monkeypatch): + collector = _collector(tmp_path, hosts=["host-a:1", "host-b:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + + ssh_mkdir_calls = [ + c for c in calls + if c["kind"] == "ssh" + and any("mkdir" in a for a in c["argv"]) + ] + assert ssh_mkdir_calls, \ + "expected at least one ssh mkdir call for the remote staging dir" + + expected = _expected_staging_dir(tmp_path) + quoted = f"'{expected}'" + assert any( + any(quoted in a for a in c["argv"]) for c in ssh_mkdir_calls + ), ( + f"staging path must be single-quoted in ssh mkdir args; " + f"expected substring {quoted!r} in one of {ssh_mkdir_calls!r}" + ) + + +class TestSharedStagingDir: + """Opt-in fast path: when shared_staging_dir is set, no SSH staging at all. + + (Formerly ``shared_tmp_dir``; renamed to reflect that staging is no longer + a tempdir concept.) 
+ """ + + def test_shared_staging_dir_skips_staging(self, tmp_path, monkeypatch): + shared = tmp_path / "shared_scratch" + shared.mkdir() + + collector = _collector( + tmp_path, + hosts=["host-a:1", "host-b:1"], + shared_staging_dir=str(shared), + ) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_holder = {} + original_makedirs = os.makedirs + + def spy_makedirs(path, *a, **kw): + if "output_path" not in output_holder and str(shared) in path: + output_holder["output_path"] = os.path.join( + path, "cluster_info.json" + ) + return original_makedirs(path, *a, **kw) + + fake_run, calls = _make_fake_run( + lambda: output_holder["output_path"] + ) + monkeypatch.setattr(cc.os, "makedirs", spy_makedirs) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + + # Zero ssh/scp calls when a shared staging dir is provided + assert not any(c["kind"] in ("ssh", "scp") for c in calls), \ + f"shared_staging_dir path must not SSH; got {calls}" + + # The working dir must live under the shared path + mpi_call = next( + c for c in calls if c["kind"] in ("mpirun", "mpiexec") + ) + joined = " ".join(mpi_call["argv"]) + assert str(shared) in joined, \ + f"mpirun command must use shared_staging_dir path; got: {joined}" + + +class TestStagingFailure: + """Staging failure must raise a clear error naming the bad host.""" + + def test_stage_failure_raises_with_host_info(self, tmp_path, monkeypatch): + collector = _collector(tmp_path, hosts=["host-a:1", "bad-host:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + def fake_run(cmd, *args, **kwargs): + # Every ssh/scp to bad-host fails + argv = cmd if isinstance(cmd, list) else cmd.split() + if argv[0] in ("ssh", "scp") and any( + "bad-host" in a for a in argv + ): + return subprocess.CompletedProcess( + argv, 255, stdout="", + stderr="ssh: connect to host bad-host port 22: " + "Connection refused", + ) + # mpirun should never be reached in this test + if argv[0] in 
("mpirun", "mpiexec"): + pytest.fail( + "mpirun must not run when staging failed on any host" + ) + return subprocess.CompletedProcess(argv, 0, stdout="", stderr="") + + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + with pytest.raises(RuntimeError) as excinfo: + collector.collect() + + msg = str(excinfo.value) + assert "bad-host" in msg, f"error must name the failing host; got: {msg}" + assert "stage" in msg.lower() or "staging" in msg.lower() \ + or "passwordless ssh" in msg.lower(), \ + f"error must mention staging/SSH; got: {msg}" + + +class TestLoggingVisibility: + """russfellows: staging progress and staged-script path must be visible + at the default INFO log level, not DEBUG.""" + + def test_staging_progress_logged_at_info(self, tmp_path, monkeypatch, caplog): + collector = _collector(tmp_path, hosts=["host-a:1", "host-b:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, _calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + with caplog.at_level(logging.INFO, logger="test.cluster_collector"): + collector.collect() + + info_msgs = [ + r.getMessage() for r in caplog.records if r.levelno == logging.INFO + ] + + assert any("stag" in m.lower() for m in info_msgs), ( + f"expected a 'staging' INFO line at default level; got: {info_msgs}" + ) + assert any("mpi" in m.lower() for m in info_msgs), ( + f"expected an MPI-related INFO line at default level; got: {info_msgs}" + ) + + def test_staged_script_path_logged_at_info(self, tmp_path, monkeypatch, caplog): + """Absolute staged-script path must appear in INFO output so users + can find it post-run for debugging.""" + collector = _collector(tmp_path, hosts=["host-a:1"]) + monkeypatch.setattr(cc, "_is_localhost", lambda h: h == "host-a") + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, 
_calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + with caplog.at_level(logging.INFO, logger="test.cluster_collector"): + collector.collect() + + script_path = _expected_script_path(tmp_path) + info_msgs = [ + r.getMessage() for r in caplog.records if r.levelno == logging.INFO + ] + assert any(script_path in m for m in info_msgs), ( + f"expected staged-script absolute path {script_path!r} in INFO " + f"logs; got: {info_msgs}" + ) + + +class TestConstructionPreconditions: + """results_dir is required under the new design; without it there is no + defensible staging location (tempdir-based staging is gone).""" + + def test_missing_results_dir_raises(self): + logger = logging.getLogger("test.cluster_collector") + with pytest.raises((ValueError, TypeError)): + cc.MPIClusterCollector( + hosts=["host-a:1"], + mpi_bin=MPIRUN, + logger=logger, + timeout_seconds=30, + results_dir=None, + ) + + +class TestX11Silence: + """The mpirun subprocess must receive an env that disables X11 forwarding.""" + + def test_plm_rsh_agent_disables_x11(self, tmp_path, monkeypatch): + collector = _collector(tmp_path, hosts=["127.0.0.1:1"]) + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + # Ensure the test environment does NOT pre-set PLM_RSH_AGENT, so + # we are verifying that the collector itself injects it. 
+ monkeypatch.delenv("PLM_RSH_AGENT", raising=False) + + collector.collect() + + mpi_call = next( + c for c in calls if c["kind"] in ("mpirun", "mpiexec") + ) + env = mpi_call["env"] + assert env is not None, "mpirun must be invoked with a custom env" + assert "PLM_RSH_AGENT" in env, \ + "PLM_RSH_AGENT must be set to silence X11 warnings" + assert "ForwardX11=no" in env["PLM_RSH_AGENT"], \ + f"PLM_RSH_AGENT must disable X11 forwarding; got {env['PLM_RSH_AGENT']!r}" + + def test_existing_plm_rsh_agent_is_preserved(self, tmp_path, monkeypatch): + """If the user has their own PLM_RSH_AGENT, don't clobber it.""" + collector = _collector(tmp_path, hosts=["127.0.0.1:1"]) + monkeypatch.setenv("PLM_RSH_AGENT", "ssh -i /custom/key") + + output_path = os.path.join( + _expected_staging_dir(tmp_path), "cluster_info.json" + ) + fake_run, calls = _make_fake_run(lambda: output_path) + monkeypatch.setattr(cc.subprocess, "run", fake_run) + + collector.collect() + + mpi_call = next( + c for c in calls if c["kind"] in ("mpirun", "mpiexec") + ) + assert mpi_call["env"]["PLM_RSH_AGENT"] == "ssh -i /custom/key", \ + "user-provided PLM_RSH_AGENT must be preserved" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 38637628..aa53855a 100755 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -483,6 +483,116 @@ def test_splits_comma_separated_hosts(self): update_args(args) assert args.hosts == ['host1', 'host2', 'host3'] + # ------------------------------------------------------------------- + # Regression tests for https://github.com/mlcommons/storage/issues/322 + # + # These exercise every form of `--hosts` the CLI can plausibly receive, + # including the forms that used to silently produce a single "host" + # containing whitespace and then crash `ssh`. 
+ # ------------------------------------------------------------------- + + def test_hosts_space_separated_list_unchanged(self): + """`--hosts h1 h2 h3` -> argparse nargs='+' gives a clean list; pass through.""" + args = argparse.Namespace( + hosts=['host1', 'host2', 'host3'], + params=None, + mpi_params=None, + ) + update_args(args) + assert args.hosts == ['host1', 'host2', 'host3'] + + def test_hosts_single_quoted_space_separated_string(self): + """`--hosts 'h1 h2 h3'` -> one token with spaces must be split (issue #322 Sample 3).""" + args = argparse.Namespace( + hosts=['srt017-e0 srt018-e0'], + params=None, + mpi_params=None, + ) + update_args(args) + assert args.hosts == ['srt017-e0', 'srt018-e0'] + + def test_hosts_equals_quoted_space_separated(self): + """`--hosts='h1 h2 h3'` -> same single-token-with-spaces case (issue #322 Sample 3).""" + args = argparse.Namespace( + hosts=['host-a host-b host-c'], + params=None, + mpi_params=None, + ) + update_args(args) + assert args.hosts == ['host-a', 'host-b', 'host-c'] + + def test_hosts_mixed_comma_and_space(self): + """Accept mixed comma/space separators in a single token.""" + args = argparse.Namespace( + hosts=['host1, host2,host3 host4'], + params=None, + mpi_params=None, + ) + update_args(args) + assert args.hosts == ['host1', 'host2', 'host3', 'host4'] + + def test_hosts_mixed_list_and_internal_split(self): + """A list where some entries need splitting and others don't.""" + args = argparse.Namespace( + hosts=['h1', 'h2 h3', 'h4,h5'], + params=None, + mpi_params=None, + ) + update_args(args) + assert args.hosts == ['h1', 'h2', 'h3', 'h4', 'h5'] + + def test_hosts_preserves_slots_suffix(self): + """`host:N` slot notation must survive the split.""" + args = argparse.Namespace( + hosts=['host1:2 host2:4'], + params=None, + mpi_params=None, + ) + update_args(args) + assert args.hosts == ['host1:2', 'host2:4'] + + def test_hosts_strips_stray_whitespace_and_empty_tokens(self): + """Multiple spaces and 
leading/trailing whitespace don't produce empty entries."""
+        args = argparse.Namespace(
+            hosts=['  host1   host2 ,,, host3  '],
+            params=None,
+            mpi_params=None,
+        )
+        update_args(args)
+        assert args.hosts == ['host1', 'host2', 'host3']
+
+    def test_hosts_empty_after_parsing_exits(self):
+        """An input that normalizes to zero tokens is a user error; exit cleanly."""
+        args = argparse.Namespace(
+            hosts=['  ,,,   '],
+            params=None,
+            mpi_params=None,
+        )
+        with pytest.raises(SystemExit):
+            update_args(args)
+
+    def test_num_client_hosts_derived_when_none(self):
+        """When argparse leaves num_client_hosts=None (user didn't pass it), derive from hosts."""
+        args = argparse.Namespace(
+            hosts=['h1', 'h2', 'h3'],
+            num_client_hosts=None,
+            params=None,
+            mpi_params=None,
+        )
+        update_args(args)
+        assert args.num_client_hosts == 3
+
+    def test_num_client_hosts_respected_when_set(self):
+        """An explicit --num-client-hosts must not be overwritten."""
+        args = argparse.Namespace(
+            hosts=['h1', 'h2'],
+            num_client_hosts=5,
+            params=None,
+            mpi_params=None,
+        )
+        update_args(args)
+        assert args.num_client_hosts == 5
+
     def test_sets_num_client_hosts_from_hosts(self):
         """Should set num_client_hosts from hosts length."""
         args = argparse.Namespace(
@@ -503,7 +613,12 @@ def test_sets_default_runtime_for_vectordb(self):
         )
         update_args(args)
         assert args.runtime is not None
-    
+
+    def test_num_client_hosts_zero_is_preserved(self):
+        """Regression: --num-client-hosts 0 must not be re-derived from len(hosts)."""
+        args = argparse.Namespace(hosts=['h1', 'h2', 'h3'], num_client_hosts=0, params=None, mpi_params=None)
+        update_args(args)
+        assert args.num_client_hosts == 0
 
 class TestApplyYamlConfigOverrides:
     """Tests for apply_yaml_config_overrides function."""
diff --git a/tests/unit/test_environment.py b/tests/unit/test_environment.py
index 9c875feb..f9e42cdf 100755
--- a/tests/unit/test_environment.py
+++ b/tests/unit/test_environment.py
@@ -431,6 +431,31 @@ def test_skips_127_0_0_1(self):
         assert 'skipped' in results[0][2].lower()
         mock_run.assert_not_called()
+
+    def test_rejects_whitespace_token_without_running_ssh(self):
+        """Regression for #322: a host token with whitespace must fail fast, not call ssh."""
+        with patch('shutil.which', return_value='/usr/bin/ssh'):
+            with patch('subprocess.run') as mock_run:
+                results = validate_ssh_connectivity(['srt017-e0 srt018-e0'])
+
+        assert len(results) == 1
+        host, ok, message = results[0]
+        assert host == 'srt017-e0 srt018-e0'
+        assert ok is False
+        # Message should point the user at the argparse pitfall,
+        # not leave them guessing about SSH.
+        assert 'whitespace' in message.lower() or 'invalid host token' in message.lower()
+        mock_run.assert_not_called()
+
+    def test_rejects_empty_token_without_running_ssh(self):
+        """An empty or whitespace-only host token should be rejected, not passed to ssh."""
+        with patch('shutil.which', return_value='/usr/bin/ssh'):
+            with patch('subprocess.run') as mock_run:
+                results = validate_ssh_connectivity(['   '])
+
+        assert len(results) == 1
+        assert results[0][1] is False
+        mock_run.assert_not_called()
 
     def test_parses_host_slots_format(self):
         """Should parse host:slots format correctly."""
         with patch('shutil.which', return_value='/usr/bin/ssh'):