diff --git a/src/mldebug/input_parser.py b/src/mldebug/input_parser.py
index 93e64d3..5a1e1c1 100644
--- a/src/mldebug/input_parser.py
+++ b/src/mldebug/input_parser.py
@@ -17,6 +17,7 @@
 
 from mldebug.arch import load_aie_arch, AIE_DEV_PHX, AIE_DEV_STX, AIE_DEV_TEL
 from mldebug.backend.core_dump_impl import CoreDumpFallbackReader
+from mldebug.backend.factory import BackendConfig, create_backend
 from mldebug.utils import LOGGER, cleanup_and_exit, input_with_timeout, is_aarch64, is_windows
 
 # Seconds to wait at interactive prompts before giving up and exiting.
@@ -256,13 +257,86 @@ def print_hw_context_table(current_contexts: dict[str, dict[str, str]]) -> None:
         LOGGER.log(f"{context:<12} {columns_str:<30} {context_data['pid']:<12} {context_data['status']:<12}")
 
 
+def _validate_contexts_with_read(contexts: dict, device: str, aie_iface) -> list[tuple[int, int]] | None:
+    """
+    Validate ALL contexts by reading the CORE_STATUS register of one AIE core tile.
+
+    A context passes when an "xrt" backend can be created for it and the register
+    read completes without raising; any exception marks the context as failed.
+
+    Args:
+        contexts: All hardware contexts from xrt-smi (context_id -> info; "pid" is read here)
+        device: Device name (for backend initialization)
+        aie_iface: Loaded AIE interface providing AIE_TILE_ROW_OFFSET and Core_registers.
+            If None, validation is skipped and None is returned.
+
+    Returns:
+        List of (context_id, pid) tuples that passed validation, or None if none passed.
+    """
+    if aie_iface is None:
+        # Without the interface there is no tile row / register address to probe;
+        # returning None lets the caller fall back to the manual prompt.
+        LOGGER.log("[WARNING] AIE interface not loaded; skipping context validation")
+        return None
+
+    # Use first AIE core tile for test read
+    # Tile layout: Row 0=Shim, Rows 1 to (OFFSET-1)=Memory, Rows OFFSET+=AIE cores
+    # For Telluride: (0, 3), For PHX/STX: (0, 2)
+    test_col = 0
+    test_row = aie_iface.AIE_TILE_ROW_OFFSET
+
+    # CORE_STATUS register - safe read-only register
+    # Device-specific addresses: Telluride=0x38004, PHX/STX=0x32004
+    test_reg = aie_iface.Core_registers["CORE_STATUS"]
+    test_tiles = [(test_col, test_row)]
+
+    valid_contexts = []
+    for ctx_id, ctx_info in contexts.items():
+        backend = None
+        try:
+            pid = int(ctx_info["pid"])
+            ctx = int(ctx_id)
+
+            config = BackendConfig(
+                tiles=test_tiles,
+                ctx_id=ctx,
+                pid=pid,
+                device=device,
+            )
+            backend = create_backend("xrt", config)
+
+            backend.read_register(test_col, test_row, test_reg)
+            valid_contexts.append((ctx, pid))
+
+        # TODO: catch device-specific errors (e.g. EBUSY from XRT) instead of Exception
+        except Exception as e:
+            LOGGER.log(f"[DEBUG] Context {ctx_id} failed validation: {type(e).__name__}: {e}")
+            continue
+
+        # Drop our reference to the test backend before probing the next context.
+        # NOTE(review): this only unbinds the name; if the backend exposes an explicit
+        # close/teardown call, invoke it here instead of relying on finalization.
+        finally:
+            del backend
+
+    if not valid_contexts:
+        LOGGER.log("[WARNING] No contexts passed validation")
+        return None
+    return valid_contexts
+
+
 def check_hw_context(args) -> tuple[int, int]:
     """
-    Returns (ctx_id, pid) from xrt-smi, prompting the user as a fallback.
-    Manual prompts time out after ``HW_CONTEXT_INPUT_TIMEOUT_S`` seconds and
-    call ``cleanup_and_exit(args, 1)`` on failure / timeout.
+    Returns (ctx_id, pid) from xrt-smi, prompting the user when selection is ambiguous.
+
+    1. If only one context exists, auto-select it.
+    2. If multiple exist, validate all (Active and Idle) with a CORE_STATUS register read.
+    3. If exactly one context passes validation, auto-select it.
+    4. Otherwise prompt the user: among all contexts when none passed, among the
+       passing ones when several did. Prompts time out after
+       ``HW_CONTEXT_INPUT_TIMEOUT_S`` seconds; invalid input or timeout calls
+       ``cleanup_and_exit(args, 1)``.
     """
     device = args.device
+    aie_iface = getattr(args, "aie_iface", None)
     filename = "xrt-smi_output.json"
     use_shell = is_windows()
 
@@ -290,14 +364,23 @@ def check_hw_context(args) -> tuple[int, int]:
         if not current_contexts:
             print("Warning: xrt-smi could find no applications running. Please launch an application to use MLDebugger.")
             raise FileNotFoundError
+
+        # Path 1: Single context found -> auto-select it
         if len(current_contexts) == 1:
             ctx = int(list(current_contexts.keys())[0])
             pid = int(list(current_contexts.values())[0]["pid"])
-        else:
+            return ctx, pid
+
+        # Path 2: Multiple contexts found -> validate all with register read test
+        LOGGER.log(f"[INFO] Found {len(current_contexts)} hardware context(s). Validating with register read test...")
+        valid_contexts = _validate_contexts_with_read(current_contexts, device, aie_iface)
+
+        # Path 2a: No contexts passed validation -> prompt user for input
+        if valid_contexts is None:
             print_hw_context_table(current_contexts)
             # Ask user
             selected_context_id = input_with_timeout(
-                "Multiple Contexts Found. Please enter the Context ID you want to select: ",
+                "No Contexts passed validation. Please enter the Context ID you want to select: ",
                 HW_CONTEXT_INPUT_TIMEOUT_S,
             )
             if selected_context_id in current_contexts:
@@ -306,6 +389,30 @@ def check_hw_context(args) -> tuple[int, int]:
             else:
                 LOGGER.log("Could not find the provided context, Exiting now.")
                 cleanup_and_exit(args, 1)
+            return ctx, pid
+
+        # Path 2b: Single valid context found -> auto-select it
+        if len(valid_contexts) == 1:
+            return valid_contexts[0]
+
+        # Path 2c: Multiple valid contexts found -> prompt user for input
+        valid_ids = {str(ctx) for ctx, _ in valid_contexts}
+        valid_only = {k: v for k, v in current_contexts.items() if str(k) in valid_ids}
+        print_hw_context_table(valid_only)
+        # Ask user
+        selected_context_id = input_with_timeout(
+            f"{len(valid_contexts)} Contexts passed validation. "
+            "Please enter the Context ID you want to select: ",
+            HW_CONTEXT_INPUT_TIMEOUT_S,
+        )
+        if selected_context_id in valid_only:
+            ctx = int(selected_context_id)
+            pid = int(valid_only[selected_context_id]["pid"])
+        else:
+            LOGGER.log(f"Context ID {selected_context_id} not found. Valid options: {', '.join(valid_only)}")
+            cleanup_and_exit(args, 1)
+        return ctx, pid
+
     except (FileNotFoundError, subprocess.CalledProcessError, json.JSONDecodeError):
         LOGGER.log(
             f"Error with xrt-smi. Please enter ctx, pid manually "