diff --git a/src/ucode/agents/claude.py b/src/ucode/agents/claude.py index 85cbfc3..d5bcd6f 100644 --- a/src/ucode/agents/claude.py +++ b/src/ucode/agents/claude.py @@ -7,6 +7,7 @@ import shutil import subprocess from pathlib import Path +from typing import cast from ucode.agent_updates import available_npm_package_update from ucode.config_io import ( @@ -63,19 +64,27 @@ def _resolve_web_search_model(state: dict) -> str | None: WEB_SEARCH_MCP_NAME = "web_search" _CLAUDE_MODEL_RE = re.compile(r"^databricks-claude-(opus|sonnet)-(\d+)-(\d+)(.*)$") -# Env keys consumed by the MLflow Claude tracing plugin. Written into the -# settings `env` block; the plugin runtime (installed separately) reads them. +# Env keys the MLflow Stop hook reads to route traces. Written into the +# settings `env` block alongside the hook itself. CLAUDE_TRACING_ENV_KEYS = ( "MLFLOW_CLAUDE_TRACING_ENABLED", "MLFLOW_TRACKING_URI", "MLFLOW_EXPERIMENT_ID", + "MLFLOW_TRACING_SQL_WAREHOUSE_ID", ) -CLAUDE_TRACING_MARKETPLACE = "mlflow/mlflow" -CLAUDE_TRACING_PLUGIN = "mlflow-tracing@mlflow-plugins" -# The plugin runtime shells out to the `mlflow` CLI, so it must be on PATH at -# this minimum version. ucode installs/upgrades it via `uv tool`. -MLFLOW_CLI_SPEC = "mlflow[databricks]>=3.4" -MINIMUM_MLFLOW_VERSION = (3, 4) +CLAUDE_TRACING_STOP_HOOK_SUFFIX = " autolog claude stop-hook" +# Tracing is driven by an `mlflow autolog claude stop-hook` Stop hook, run by +# the `mlflow` CLI on each session end. Pin to 3.11.x: 3.12 dropped the Unity +# Catalog trace-write path, so traces silently land in the classic store +# instead of the experiment's UC table. ucode installs this via `uv tool` at +# `configure tracing` time (where UV_INDEX_URL is set), then writes the hook +# with the resolved absolute path — so the hook needs no uv or index at run +# time, and can't be shadowed by a project venv's mlflow. +MLFLOW_CLI_SPEC = "mlflow[databricks]>=3.11,<3.12" +MINIMUM_MLFLOW_VERSION = (3, 11) +# Upper bound (exclusive) — an installed mlflow at or above this is too new and +# must be replaced, not just left alone. +MAXIMUM_MLFLOW_VERSION = (3, 12) def _web_search_mcp_entry(workspace: str, search_model: str, profile: str | None = None) -> dict: @@ -215,18 +224,31 @@ def write_tool_config(state: dict, model: str) -> dict: profile=state.get("profile"), ) tracing_env_vars = tracing_env(state, "claude") + stop_hook_command = claude_tracing_stop_hook_command() if tracing_env_vars else None if tracing_env_vars: overlay["env"]["MLFLOW_CLAUDE_TRACING_ENABLED"] = "true" overlay["env"].update(tracing_env_vars) managed_keys = managed_keys + [["env", key] for key in CLAUDE_TRACING_ENV_KEYS] + if stop_hook_command: + managed_keys = managed_keys + [["hooks", "Stop"]] + else: + print_warning( + "MLflow tracing env was written, but the `mlflow` CLI could not be located " + "to install the Claude Stop hook — traces won't be emitted. Re-run " + "`ucode configure tracing`." + ) existing = read_json_safe(CLAUDE_SETTINGS_PATH) merged = deep_merge_dict(existing, overlay) + if tracing_env_vars and stop_hook_command: + _upsert_tracing_stop_hook(merged, stop_hook_command) if not tracing_env_vars: env_block = merged.get("env") if isinstance(env_block, dict): for key in CLAUDE_TRACING_ENV_KEYS: env_block.pop(key, None) + # Strip only ucode's tracing Stop hook so user hooks stay intact. + _remove_tracing_stop_hook(merged) write_json_file(CLAUDE_SETTINGS_PATH, merged) if web_search_model: @@ -237,15 +259,67 @@ def write_tool_config(state: dict, model: str) -> dict: return state +def _is_tracing_stop_hook(hook: object) -> bool: + if not isinstance(hook, dict): + return False + hook = cast(dict, hook) + if hook.get("type") != "command": + return False + command = hook.get("command") + return isinstance(command, str) and command.endswith(CLAUDE_TRACING_STOP_HOOK_SUFFIX) + + +def _remove_tracing_stop_hook(settings: dict) -> None: + hooks = settings.get("hooks") + if not isinstance(hooks, dict): + return + stop_entries = hooks.get("Stop") + if not isinstance(stop_entries, list): + return + + cleaned_entries = [] + for entry in stop_entries: + if not isinstance(entry, dict): + cleaned_entries.append(entry) + continue + hook_list = entry.get("hooks") + if not isinstance(hook_list, list): + cleaned_entries.append(entry) + continue + cleaned_hooks = [hook for hook in hook_list if not _is_tracing_stop_hook(hook)] + if cleaned_hooks: + cleaned_entry = dict(entry) + cleaned_entry["hooks"] = cleaned_hooks + cleaned_entries.append(cleaned_entry) + + if cleaned_entries: + hooks["Stop"] = cleaned_entries + else: + hooks.pop("Stop", None) + if not hooks: + settings.pop("hooks", None) + + +def _upsert_tracing_stop_hook(settings: dict, command: str) -> None: + _remove_tracing_stop_hook(settings) + hooks = settings.get("hooks") + if not isinstance(hooks, dict): + hooks = {} + settings["hooks"] = hooks + stop_entries = hooks.get("Stop") + if not isinstance(stop_entries, list): + stop_entries = [] + hooks["Stop"] = stop_entries + stop_entries.append({"hooks": [{"type": "command", "command": command}]}) + + def ensure_tracing_runtime() -> bool: - """Ensure Claude's MLflow tracing runtime is ready: an `mlflow` CLI >= 3.4 on - PATH (the plugin shells out to it) and the MLflow Claude plugin installed. + """Ensure the MLflow tracing runtime is ready: a pinned `mlflow` CLI (3.11.x) + installed via `uv tool`, whose absolute path the Stop hook will call. - Best-effort — warns and returns False if a piece can't be set up, so + Best-effort — warns and returns False if it can't be set up, so `ucode configure tracing` can still finish for other agents.""" - if not _ensure_mlflow_cli(): - return False - return _install_claude_tracing_plugin() + return _ensure_mlflow_cli() def _parse_mlflow_version(text: str) -> tuple[int, int] | None: @@ -255,37 +329,76 @@ def _parse_mlflow_version(text: str) -> tuple[int, int] | None: return int(match.group(1)), int(match.group(2)) +def _uv_tool_mlflow_path() -> str | None: + """Absolute path to the `mlflow` installed by `uv tool`, or None. + + Resolved from `uv tool dir --bin` rather than ``shutil.which`` so a project + venv's (possibly wrong-versioned) mlflow can't shadow the one ucode pins — + the Stop hook must always run the uv-tool copy.""" + if not shutil.which("uv"): + return None + try: + result = subprocess.run( + ["uv", "tool", "dir", "--bin"], + check=False, + capture_output=True, + text=True, + timeout=30, + ) + except (OSError, subprocess.TimeoutExpired): + return None + bin_dir = (result.stdout or "").strip() + if result.returncode != 0 or not bin_dir: + return None + candidate = Path(bin_dir) / "mlflow" + return str(candidate) if candidate.exists() else None + + def _installed_mlflow_version() -> tuple[int, int] | None: - """The (major, minor) of the `mlflow` CLI on PATH, or None if absent.""" - if not shutil.which("mlflow"): + """The (major, minor) of the uv-tool `mlflow`, or None if absent.""" + path = _uv_tool_mlflow_path() + if not path: return None try: result = subprocess.run( - ["mlflow", "--version"], check=False, capture_output=True, text=True, timeout=30 + [path, "--version"], check=False, capture_output=True, text=True, timeout=30 ) except (OSError, subprocess.TimeoutExpired): return None return _parse_mlflow_version(result.stdout or result.stderr or "") +def claude_tracing_stop_hook_command() -> str | None: + """The Stop hook command string: the absolute uv-tool `mlflow` invoking its + `autolog claude stop-hook` handler. None when mlflow isn't installed. + + Using the absolute path means the hook needs neither `uv` nor a package + index at run time (the minimal env Claude runs hooks in lacks UV_INDEX_URL), + and can't be shadowed by another mlflow on PATH.""" + path = _uv_tool_mlflow_path() + if not path: + return None + return f"{path} autolog claude stop-hook" + + def _ensure_mlflow_cli() -> bool: - """Ensure an `mlflow` CLI >= 3.4 is on PATH, installing or upgrading it via - `uv tool` when needed.""" + """Ensure the pinned `mlflow` CLI (3.11.x) is installed via `uv tool`, + installing or replacing an out-of-range version when needed.""" current = _installed_mlflow_version() - if current and current >= MINIMUM_MLFLOW_VERSION: + if current and MINIMUM_MLFLOW_VERSION <= current < MAXIMUM_MLFLOW_VERSION: return True if not shutil.which("uv"): - verb = "upgrade" if current else "install" + verb = "replace" if current else "install" print_warning( - f"Claude tracing needs the `mlflow` CLI >= 3.4 on PATH, but `uv` is not " - f'available to {verb} it. Run `uv tool install "{MLFLOW_CLI_SPEC}"` ' - f'(or `pip install "{MLFLOW_CLI_SPEC}"`), then re-run `ucode configure tracing`.' + f"Claude tracing needs the `mlflow` CLI ({MLFLOW_CLI_SPEC}), but `uv` is not " + f'available to {verb} it. Run `uv tool install "{MLFLOW_CLI_SPEC}"`, then ' + "re-run `ucode configure tracing`." ) return False - print_note(f"{'Upgrading' if current else 'Installing'} the mlflow CLI ({MLFLOW_CLI_SPEC})...") - # --force replaces an existing (older) uv-managed mlflow tool in place. + print_note(f"{'Replacing' if current else 'Installing'} the mlflow CLI ({MLFLOW_CLI_SPEC})...") + # --force replaces an existing (out-of-range) uv-managed mlflow in place. cmd = ["uv", "tool", "install", MLFLOW_CLI_SPEC] if current: cmd.append("--force") @@ -295,54 +408,16 @@ def _ensure_mlflow_cli() -> bool: print_warning(f"Could not install the mlflow CLI automatically: {exc}") return False - if not shutil.which("mlflow"): + if not _uv_tool_mlflow_path(): print_warning( - "Installed mlflow, but `mlflow` is still not on PATH. Ensure your uv tool " - "bin directory (e.g. ~/.local/bin) is on PATH, then re-run `ucode configure tracing`." + "Installed mlflow via `uv tool`, but its binary could not be located. " + "Re-run `ucode configure tracing`." ) return False print_success("mlflow CLI ready") return True -def _install_claude_tracing_plugin() -> bool: - binary = SPEC["binary"] - if not shutil.which(binary): - print_warning("`claude` is not installed; skipping MLflow tracing plugin install.") - return False - commands = [ - [ - binary, - "plugin", - "marketplace", - "add", - CLAUDE_TRACING_MARKETPLACE, - "--sparse", - ".claude-plugin", - ], - [binary, "plugin", "install", CLAUDE_TRACING_PLUGIN], - ] - for cmd in commands: - try: - result = subprocess.run(cmd, check=False, capture_output=True, text=True, timeout=120) - except (OSError, subprocess.TimeoutExpired) as exc: - print_warning(f"Could not install the Claude MLflow plugin: {exc}") - return False - if result.returncode != 0: - output = (result.stderr or result.stdout or "").strip() - last = output.splitlines()[-1] if output else f"exit {result.returncode}" - # `marketplace add` / `install` are idempotent; treat "already - # added/installed" as success and keep going. Best-effort match - # against stderr — an upstream wording change would degrade this - # to a noisy warning on re-runs, but never corrupts state. - if "already" in last.lower(): - continue - print_warning(f"Claude MLflow plugin step failed: {last}") - return False - print_success("Claude MLflow tracing plugin installed") - return True - - def default_model(state: dict) -> str | None: claude_models = state.get("claude_models") or {} return claude_models.get("opus") or claude_models.get("sonnet") or claude_models.get("haiku") diff --git a/src/ucode/agents/codex.py b/src/ucode/agents/codex.py index b154338..3719eea 100644 --- a/src/ucode/agents/codex.py +++ b/src/ucode/agents/codex.py @@ -4,7 +4,6 @@ import os import re -import subprocess from pathlib import Path from ucode.agent_updates import available_npm_package_update @@ -23,8 +22,6 @@ ) from ucode.state import mark_tool_managed, save_state from ucode.telemetry import agent_version, ucode_version -from ucode.tracing import agent_tracing, apply_tracing_env -from ucode.ui import print_success, print_warning CODEX_CONFIG_DIR = Path.home() / ".codex" CODEX_PROFILE_NAME = "ucode" @@ -35,8 +32,6 @@ CODEX_MODEL_PROVIDER_NAME = "ucode-databricks" MINIMUM_CODEX_VERSION = (0, 134, 0) MINIMUM_CODEX_VERSION_TEXT = "0.134.0" -CODEX_TRACING_NOTIFY = ["mlflow-codex", "notify-hook"] -CODEX_TRACING_PACKAGE = "@mlflow/codex" SPEC: ToolSpec = { @@ -240,7 +235,6 @@ def write_tool_config(state: dict, model: str | None = None) -> dict: overlay = render_legacy_overlay(workspace, chosen_model, databricks_profile) doc = read_toml_safe(LEGACY_CODEX_CONFIG_PATH) deep_merge_dict(doc, overlay) - _apply_tracing_notify(doc, state) write_toml_file(LEGACY_CODEX_CONFIG_PATH, doc) state = mark_tool_managed(state, "codex", LEGACY_MANAGED_KEYS) save_state(state) @@ -251,55 +245,12 @@ def write_tool_config(state: dict, model: str | None = None) -> dict: overlay = render_overlay(workspace, chosen_model, databricks_profile) doc = read_toml_safe(CODEX_CONFIG_PATH) deep_merge_dict(doc, overlay) - _apply_tracing_notify(doc, state) write_toml_file(CODEX_CONFIG_PATH, doc) state = mark_tool_managed(state, "codex", MANAGED_KEYS) save_state(state) return state -def _apply_tracing_notify(doc: dict, state: dict) -> None: - """Set/clear the Codex ``notify`` hook that streams session traces to MLflow. - - Only ucode's own notify value is removed on disable, so a user-defined - ``notify`` is left intact. When enabling on top of a pre-existing user - ``notify``, warn before replacing — the prior value is in the backup file - but the user has to restore it manually.""" - if agent_tracing(state, "codex") is not None: - existing = doc.get("notify") - if existing is not None and list(existing) != CODEX_TRACING_NOTIFY: - print_warning( - f"Codex `notify` is already set to {existing!r}; replacing it with the " - "MLflow tracing hook. The previous value is preserved in the Codex " - "config backup — restore it manually if you need both." - ) - doc["notify"] = list(CODEX_TRACING_NOTIFY) - elif list(doc.get("notify") or []) == CODEX_TRACING_NOTIFY: - doc.pop("notify", None) - - -def ensure_tracing_dependency() -> bool: - """Install the `@mlflow/codex` npm package that provides the `mlflow-codex` - notify-hook binary. Best-effort: warns and returns False on failure.""" - import shutil - - if shutil.which("mlflow-codex"): - return True - if not shutil.which("npm"): - print_warning( - f"`npm` is not available to install {CODEX_TRACING_PACKAGE}; " - "Codex tracing will be inactive until it is installed." - ) - return False - try: - subprocess.run(["npm", "install", "-g", CODEX_TRACING_PACKAGE], check=True, timeout=300) - except (subprocess.CalledProcessError, subprocess.TimeoutExpired): - print_warning(f"Could not install {CODEX_TRACING_PACKAGE}; Codex tracing will be inactive.") - return False - print_success("Codex MLflow tracing hook installed") - return True - - def default_model(state: dict) -> str | None: """Pick the newest GPT model when multiple are available. @@ -328,10 +279,6 @@ def launch(state: dict, tool_args: list[str]) -> None: workspace = state.get("workspace") if workspace: os.environ["OAUTH_TOKEN"] = get_databricks_token(workspace, state.get("profile")) - # The notify hook subprocess Codex spawns inherits this env, so MLflow - # routing flows through to it without writing a separate tracing config. - # When tracing is off this also clears any stale outer-shell value. - apply_tracing_env(os.environ, state, "codex") os.execvp(binary, [binary, "--profile", CODEX_PROFILE_NAME, *tool_args]) diff --git a/src/ucode/agents/opencode.py b/src/ucode/agents/opencode.py index 8d05e9c..8792625 100644 --- a/src/ucode/agents/opencode.py +++ b/src/ucode/agents/opencode.py @@ -23,14 +23,12 @@ ) from ucode.state import mark_tool_managed, save_state from ucode.telemetry import agent_version, ucode_version -from ucode.tracing import agent_tracing, apply_tracing_env OPENCODE_XDG_CONFIG_HOME = APP_DIR / "opencode-xdg" OPENCODE_CONFIG_DIR = OPENCODE_XDG_CONFIG_HOME / "opencode" OPENCODE_CONFIG_PATH = OPENCODE_CONFIG_DIR / "opencode.json" OPENCODE_BACKUP_PATH = APP_DIR / "opencode-config.backup.json" OPENCODE_MCP_AUTH_HEADER_VALUE = "Bearer {env:OAUTH_TOKEN}" -OPENCODE_TRACING_PLUGIN = "@mlflow/opencode" SPEC: ToolSpec = { "binary": "opencode", @@ -152,29 +150,12 @@ def write_tool_config( for stale in ("databricks-anthropic", "databricks-google", "databricks-openai"): providers.pop(stale, None) merged = deep_merge_dict(existing, overlay) - _apply_tracing_plugin(merged, state) write_json_file(OPENCODE_CONFIG_PATH, merged) state = mark_tool_managed(state, "opencode", managed_keys) save_state(state) return state, token -def _apply_tracing_plugin(config: dict, state: dict) -> None: - """Add/remove the MLflow plugin in opencode.json's top-level ``plugin`` list - to match the current tracing state, leaving any user plugins untouched. - OpenCode auto-installs listed npm plugins at startup.""" - plugins = config.get("plugin") - plugins = ( - [p for p in plugins if p != OPENCODE_TRACING_PLUGIN] if isinstance(plugins, list) else [] - ) - if agent_tracing(state, "opencode") is not None: - plugins.append(OPENCODE_TRACING_PLUGIN) - if plugins: - config["plugin"] = plugins - else: - config.pop("plugin", None) - - def build_mcp_server_entry(url: str) -> dict: return { "type": "remote", @@ -239,10 +220,6 @@ def build_runtime_env(token: str, state: dict | None = None) -> dict[str, str]: env = os.environ.copy() env["OAUTH_TOKEN"] = token env["XDG_CONFIG_HOME"] = str(OPENCODE_XDG_CONFIG_HOME) - if state is not None: - # apply_tracing_env clears the MLflow vars when tracing is off, so a - # stale outer-shell value can't leak through. - apply_tracing_env(env, state, "opencode") return env diff --git a/src/ucode/cli.py b/src/ucode/cli.py index 1dd5295..f32bddf 100644 --- a/src/ucode/cli.py +++ b/src/ucode/cli.py @@ -47,7 +47,7 @@ purge_cross_workspace_mcp_residue, revert_mcp_configs, ) -from ucode.state import STATE_PATH, clear_state, load_state, save_state +from ucode.state import STATE_PATH, clear_state, load_full_state, load_state, save_state from ucode.tracing import configure_tracing_command from ucode.ui import ( console, @@ -394,12 +394,16 @@ def status() -> int: if tracing.get("enabled"): print_kv("MLflow tracing", "enabled") print_kv("Tracking URI", str(tracing.get("tracking_uri") or "unknown")) - for tool, entry in (tracing.get("agents") or {}).items(): - if isinstance(entry, dict): - print_kv( - f"{tool} experiment", - f"{entry.get('experiment_name')} (id {entry.get('experiment_id')})", - ) + print_kv( + "Experiment", + f"{tracing.get('experiment_name')} (id {tracing.get('experiment_id')})", + ) + uc_destination = tracing.get("uc_destination") + if uc_destination: + print_kv("Unity Catalog", str(uc_destination)) + sql_warehouse_id = tracing.get("sql_warehouse_id") + if sql_warehouse_id: + print_kv("SQL warehouse", str(sql_warehouse_id)) else: print_kv("MLflow tracing", "disabled") @@ -607,6 +611,13 @@ def configure( help="Configure a comma-separated list of workspaces without prompting.", ), ] = None, + tracing: Annotated[ + bool, + typer.Option( + "--tracing", + help="Also enable MLflow tracing for the configured workspace(s).", + ), + ] = False, ) -> None: """Configure workspace URL and AI Gateway.""" if ctx.invoked_subcommand is not None: @@ -639,6 +650,16 @@ def configure( configure_workspace_command() else: configure_workspace_command(workspaces=workspace_entries) + if tracing: + # The workspaces were just configured, so enable tracing for them + # directly instead of re-prompting. Fall back to the workspace that + # `configure_workspace_command` made current (the interactive pick). + tracing_workspaces = workspace_entries + if tracing_workspaces is None: + current = load_full_state().get("current_workspace") + tracing_workspaces = [(current, None)] if current else None + if tracing_workspaces: + configure_tracing_command(workspaces=tracing_workspaces) except RuntimeError as exc: print_err(str(exc)) raise typer.Exit(1) from None diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index 6b4a4f7..90c1808 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -17,7 +17,7 @@ from typing import Literal, cast, overload from urllib import error as urllib_error from urllib import request as urllib_request -from urllib.parse import quote, urlparse +from urllib.parse import urlparse from databricks.sql.exc import ServerOperationError @@ -300,56 +300,112 @@ def get_current_user_name(workspace: str, token: str) -> str | None: return None -def _mlflow_experiment_id_from_payload(payload: dict | list | None) -> str | None: - """Pull an experiment_id from a get-by-name (`{"experiment": {...}}`) or - create (`{"experiment_id": ...}`) MLflow REST response.""" - if not isinstance(payload, dict): - return None - experiment = payload.get("experiment") - if isinstance(experiment, dict) and experiment.get("experiment_id"): - return str(experiment["experiment_id"]) - if payload.get("experiment_id"): - return str(payload["experiment_id"]) +# Experiment tag Databricks sets when an experiment's traces are written to a +# Unity Catalog table. Its value is the UC destination, e.g. +# "my_catalog.my_schema.my_table". A plain (file/DBFS-backed) experiment does +# not carry this tag, so its presence is our signal that traces land in UC. +UC_TRACE_DESTINATION_TAG = "mlflow.experiment.databricksTraceDestinationPath" + + +def _experiment_tags(experiment: dict) -> dict[str, str | None]: + """Flatten an experiment's ``tags`` list ([{key, value}, ...]) into a dict.""" + out: dict[str, str | None] = {} + tags = experiment.get("tags") + if isinstance(tags, list): + for tag in tags: + if isinstance(tag, dict) and isinstance(tag.get("key"), str): + out[tag["key"]] = tag.get("value") + return out + + +def _uc_trace_destination(experiment: dict) -> str | None: + """The Unity Catalog destination (``catalog.schema.table``) an experiment + logs traces to, or None when it isn't UC-backed. Any three-part UC name + qualifies — the specific catalog/schema/table is not constrained.""" + value = _experiment_tags(experiment).get(UC_TRACE_DESTINATION_TAG) + if isinstance(value, str): + parts = value.split(".") + if len(parts) == 3 and all(parts): + return value return None -def get_or_create_mlflow_experiment( - workspace: str, token: str, name: str -) -> tuple[str | None, str | None]: - """Resolve the numeric MLflow experiment id for ``name`` in this workspace, - creating the experiment if it doesn't exist. Returns (experiment_id, reason); - reason is None on success, otherwise describes the failure.""" +def find_uc_backed_experiment( + workspace: str, token: str, leaf_name: str +) -> tuple[dict | None, str | None]: + """Find an existing experiment whose final path segment is ``leaf_name`` and + whose traces are backed by Unity Catalog. + + Returns (experiment, reason). On success ``experiment`` is + ``{"experiment_id", "experiment_name", "uc_destination"}`` and reason is + None. On failure ``experiment`` is None and reason explains why (no such + experiment, or it exists but isn't UC-backed) so the caller can tell the + user to create one.""" hostname = workspace_hostname(workspace) - encoded = quote(name, safe="") - payload, get_reason = _http_get_json( - f"https://{hostname}/api/2.0/mlflow/experiments/get-by-name?experiment_name={encoded}", + # Leaf-match in the filter (anything ending in the name), then confirm the + # exact leaf segment in Python so "/Users//ucode-traces" matches but + # "team-ucode-traces" does not. + safe_leaf = leaf_name.replace("'", "") + payload, reason = _http_post_json( + f"https://{hostname}/api/2.0/mlflow/experiments/search", token, + {"filter": f"name LIKE '%{safe_leaf}'", "max_results": 1000}, ) - existing = _mlflow_experiment_id_from_payload(payload) - if existing: - return existing, None + if not isinstance(payload, dict): + return None, reason or "could not search MLflow experiments" + + experiments = payload.get("experiments") + named = [ + exp + for exp in (experiments if isinstance(experiments, list) else []) + if isinstance(exp, dict) + and str(exp.get("name") or "").rsplit("/", 1)[-1] == leaf_name + and exp.get("experiment_id") + ] + if not named: + return None, f"no experiment named '{leaf_name}' exists on this workspace" + + for exp in named: + dest = _uc_trace_destination(exp) + if dest: + return { + "experiment_id": str(exp["experiment_id"]), + "experiment_name": str(exp.get("name") or leaf_name), + "uc_destination": dest, + }, None - created, create_reason = _http_post_json( - f"https://{hostname}/api/2.0/mlflow/experiments/create", - token, - {"name": name}, + return ( + None, + f"experiment '{leaf_name}' exists but its traces are not backed by Unity Catalog", ) - created_id = _mlflow_experiment_id_from_payload(created) - if created_id: - return created_id, None - - # A concurrent create (or a get that failed transiently) can leave us here - # even though the experiment now exists — re-fetch once before giving up. - if create_reason and "RESOURCE_ALREADY_EXISTS" in create_reason: - payload, _ = _http_get_json( - f"https://{hostname}/api/2.0/mlflow/experiments/get-by-name?experiment_name={encoded}", - token, - ) - retried = _mlflow_experiment_id_from_payload(payload) - if retried: - return retried, None - return None, create_reason or get_reason or "could not resolve MLflow experiment" + +def resolve_sql_warehouse_id(workspace: str, token: str) -> tuple[str | None, str | None]: + """Pick a SQL warehouse for writing traces to a UC-backed experiment. + + Writing traces to a Unity Catalog table requires a SQL warehouse + (``MLFLOW_TRACING_SQL_WAREHOUSE_ID``); without one the MLflow exporter + silently drops them. We prefer a RUNNING warehouse so the first trace isn't + blocked on a cold start, falling back to any existing warehouse (a stopped + one auto-starts on first query). Returns (warehouse_id, reason); reason is + None on success, else explains why none could be resolved.""" + hostname = workspace_hostname(workspace) + payload, reason = _http_get_json(f"https://{hostname}/api/2.0/sql/warehouses", token) + if not isinstance(payload, dict): + return None, reason or "could not list SQL warehouses" + + warehouses = payload.get("warehouses") + warehouses = ( + [w for w in warehouses if isinstance(w, dict) and w.get("id")] + if isinstance(warehouses, list) + else [] + ) + if not warehouses: + return None, "no SQL warehouse exists on this workspace" + + running = next((w for w in warehouses if str(w.get("state")).upper() == "RUNNING"), None) + chosen = running or warehouses[0] + return str(chosen["id"]), None @overload diff --git a/src/ucode/tracing.py b/src/ucode/tracing.py index bf74011..55d99e1 100644 --- a/src/ucode/tracing.py +++ b/src/ucode/tracing.py @@ -1,13 +1,16 @@ -"""MLflow tracing: route coding-agent sessions to a Databricks experiment. +"""MLflow tracing: route Claude Code sessions to a Databricks experiment. -ucode points each agent's MLflow integration at a Databricks-hosted experiment -(tracking URI ``databricks``/``databricks://`` + a numeric experiment -id), reusing the workspace auth ucode already configures in ``~/.databrickscfg``. -Traces land in the experiment's default MLflow trace store (the MLflow Traces -UI), not Unity Catalog. +ucode points Claude Code's MLflow integration at a single, pre-provisioned +experiment named ``ucode-traces`` whose traces are stored in a Unity Catalog +table. ucode asserts the experiment exists and is UC-backed (it does not create +it), resolves a SQL warehouse for UC writes, and persists the tracking URI, +experiment id, and warehouse id. Auth reuses the workspace profile ucode +already configures in ``~/.databrickscfg``. -Scope: Claude Code, OpenCode, Codex. Gemini's exporter is OTLP-only and is not -wired here yet. +Scope: Claude Code only. Its `mlflow autolog claude` Stop hook writes traces to +the UC table. Codex and OpenCode used the `@mlflow/codex`/`@mlflow/opencode` JS +clients, which only reach the classic (non-UC) trace store, so their tracing was +removed. Gemini's exporter is OTLP-only and was never wired here. This module must not import ``mlflow`` (the heavy optional dependency) or ``ucode.agents`` at import time: agents import the small helpers here, and the @@ -16,13 +19,11 @@ from __future__ import annotations -from collections.abc import MutableMapping - from ucode.databricks import ( ensure_databricks_auth, - get_current_user_name, + find_uc_backed_experiment, get_databricks_token, - get_or_create_mlflow_experiment, + resolve_sql_warehouse_id, ) from ucode.state import hydrate_state, load_full_state, save_state, set_current_workspace from ucode.ui import ( @@ -30,23 +31,23 @@ print_note, print_section, print_success, - print_warning, prompt_for_workspace, spinner, ) -# Agents whose MLflow integration routes to a Databricks tracking URI. Gemini is -# excluded: its exporter is OTLP-only and needs a separate endpoint. -TRACING_AGENTS: tuple[str, ...] = ("claude", "opencode", "codex") +# Agents whose MLflow integration routes to a Databricks tracking URI. Only +# Claude Code is supported: its `mlflow autolog claude` Stop hook writes traces +# to the experiment's Unity Catalog table. Codex and OpenCode used the +# `@mlflow/codex`/`@mlflow/opencode` JS clients, which only reach the classic +# (non-UC) trace store, so their tracing was removed. +TRACING_AGENTS: tuple[str, ...] = ("claude",) -# Per-agent experiment-name slug, so each agent's sessions land in their own -# experiment (e.g. `/Users//ucode-claude-code-traces`) rather than a -# single shared one. -AGENT_EXPERIMENT_SLUG: dict[str, str] = { - "claude": "claude-code", - "codex": "codex", - "opencode": "opencode", -} +# Leaf name of the experiment every agent and every user routes to. ucode does +# not create it: an admin provisions an experiment with this name whose traces +# are backed by Unity Catalog, and ucode asserts it exists. The full path can be +# anything (e.g. `/Users//ucode-traces` or `/Shared/ucode-traces`) — only +# the final segment must match, and the traces must land in a UC table. +EXPERIMENT_LEAF_NAME = "ucode-traces" def tracking_uri_for_state(state: dict) -> str: @@ -57,12 +58,39 @@ def tracking_uri_for_state(state: dict) -> str: return f"databricks://{profile}" if profile else "databricks" -def experiment_name(tool: str, user_name: str | None) -> str: - """Per-agent experiment path. Per-user (`/Users//...`) when the - current user resolves, else a shared (`/Shared/...`) path.""" - slug = AGENT_EXPERIMENT_SLUG[tool] - base = f"/Users/{user_name}" if user_name else "/Shared" - return f"{base}/ucode-{slug}-traces" +def experiment_name() -> str: + """The leaf name of the shared, UC-backed experiment ucode requires.""" + return EXPERIMENT_LEAF_NAME + + +def _missing_experiment_error(name: str, reason: str | None) -> str: + """Actionable error when no UC-backed ``ucode-traces`` experiment is found. + ucode no longer creates the experiment, so the fix is for an admin to set + one up with Unity Catalog trace storage.""" + detail = f" ({reason})" if reason else "" + return ( + f"No Unity Catalog-backed MLflow experiment named '{name}' was found on this " + f"workspace{detail}.\n" + "ucode does not create it — an admin must provision one whose traces are " + "stored in Unity Catalog:\n" + f" 1. In the workspace, create an MLflow experiment named '{name}'.\n" + " 2. Configure its trace storage to a Unity Catalog table " + "(any catalog.schema.table you have access to).\n" + " 3. Re-run `ucode configure tracing`." + ) + + +def _missing_warehouse_error(reason: str | None) -> str: + """Actionable error when no SQL warehouse can back UC trace writes. The + warehouse is mandatory: traces to a UC table are silently dropped without + ``MLFLOW_TRACING_SQL_WAREHOUSE_ID``.""" + detail = f" ({reason})" if reason else "" + return ( + f"No SQL warehouse is available to write traces to Unity Catalog{detail}.\n" + "Writing traces to a UC-backed experiment requires a SQL warehouse:\n" + " 1. Create a SQL warehouse in the workspace (SQL > SQL Warehouses).\n" + " 2. Re-run `ucode configure tracing`." + ) def tracing_config(state: dict) -> dict | None: @@ -74,47 +102,41 @@ def tracing_config(state: dict) -> dict | None: def agent_tracing(state: dict, tool: str) -> dict | None: - """The resolved per-agent tracing entry ({experiment_id, experiment_name}) - for ``tool``, or None when tracing is off or that agent has no experiment.""" + """The resolved shared tracing entry ({experiment_id, experiment_name}) when + ``tool`` should be traced, else None. All tracing-capable agents share one + experiment, so this returns the same entry for each; it still gates per-tool + (Gemini/Copilot/Pi never trace) and on the experiment having resolved.""" cfg = tracing_config(state) - if not cfg: + if not cfg or tool not in TRACING_AGENTS: return None - entry = (cfg.get("agents") or {}).get(tool) - if isinstance(entry, dict) and entry.get("experiment_id"): - return entry - return None + if not cfg.get("experiment_id"): + return None + return { + "experiment_id": cfg["experiment_id"], + "experiment_name": cfg.get("experiment_name"), + } def tracing_env(state: dict, tool: str) -> dict[str, str]: """MLflow env vars for one agent. Empty when tracing is disabled for it. The tracking URI carries the profile, so auth resolves from ``~/.databrickscfg`` - without extra vars.""" + without extra vars. + + ``MLFLOW_TRACING_SQL_WAREHOUSE_ID`` is required for the experiment's + Unity Catalog trace table: without it the MLflow exporter silently drops + every trace.""" cfg = tracing_config(state) entry = agent_tracing(state, tool) if not cfg or not entry: return {} - return { + env = { "MLFLOW_TRACKING_URI": str(cfg["tracking_uri"]), "MLFLOW_EXPERIMENT_ID": str(entry["experiment_id"]), } - - -# Keys ``tracing_env`` produces — also the set we actively clear when tracing -# is off, so a stale value already in the outer shell can't leak into the -# agent subprocess and route traces somewhere unintended. -TRACING_ENV_KEYS: tuple[str, ...] = ("MLFLOW_TRACKING_URI", "MLFLOW_EXPERIMENT_ID") - - -def apply_tracing_env(env: MutableMapping[str, str], state: dict, tool: str) -> None: - """Set MLflow tracing vars on ``env`` when tracing is on for ``tool``; - actively remove them when it's off, so an outer-shell value doesn't bleed - into the agent subprocess.""" - new = tracing_env(state, tool) - if new: - env.update(new) - return - for key in TRACING_ENV_KEYS: - env.pop(key, None) + warehouse_id = cfg.get("sql_warehouse_id") + if warehouse_id: + env["MLFLOW_TRACING_SQL_WAREHOUSE_ID"] = str(warehouse_id) + return env def disable_tracing(state: dict) -> dict: @@ -175,8 +197,7 @@ def _select_tracing_workspace(*, only_enabled: bool = False) -> dict: candidates = _tracing_capable_workspaces(full) if not candidates: raise RuntimeError( - "No tracing-capable agents are configured. Run `ucode configure` for " - "Claude Code, OpenCode, or Codex first." + "Claude Code is not configured. Run `ucode configure` for Claude Code first." ) current = full.get("current_workspace") @@ -220,18 +241,19 @@ def _rewrite_agent_configs(state: dict) -> dict: def _install_agent_tracing_deps(state: dict) -> None: - """One-time, per-agent dependency installs (Claude plugin + mlflow CLI, - Codex npm package), only for agents whose experiment resolved. OpenCode's - plugin is auto-installed by OpenCode from the ``plugin`` list.""" - from ucode.agents import claude, codex + """Install the Claude tracing runtime (pinned mlflow CLI) when Claude is + configured on this workspace and has tracing on. Claude is the only + tracing-capable agent.""" + from ucode.agents import claude - if agent_tracing(state, "claude"): + configured = _configured_tracing_agents(state) + if "claude" in configured and agent_tracing(state, "claude"): claude.ensure_tracing_runtime() - if agent_tracing(state, "codex"): - codex.ensure_tracing_dependency() -def configure_tracing_command(disable: bool = False) -> int: +def configure_tracing_command( + disable: bool = False, workspaces: list[tuple[str, str | None]] | None = None +) -> int: # `save_state` (called by us and by every agent config writer underneath # us) flips `current_workspace` to the workspace it's saving. Tracing can # be configured on a non-current workspace, so snapshot here and restore @@ -239,16 +261,43 @@ def configure_tracing_command(disable: bool = False) -> int: # `ucode launch` targets. original_current = load_full_state().get("current_workspace") try: + if workspaces is not None: + return _enable_tracing_for_workspaces(workspaces) return _configure_tracing(disable=disable) finally: set_current_workspace(original_current) +def _enable_tracing_for_workspaces(workspaces: list[tuple[str, str | None]]) -> int: + """Enable tracing for an explicit set of (url, profile) workspaces without + prompting — used by `configure --tracing`, which already knows which + workspace(s) it just set up. Workspaces with no tracing-capable agent are + skipped with a note rather than treated as an error.""" + full = load_full_state() + enabled_any = False + for workspace, profile in workspaces: + state = _hydrate_workspace_entry(full, workspace, profile) + if not _configured_tracing_agents(state): + print_section("MLflow Tracing") + print_note(f"{workspace}: no tracing-capable agents configured — skipping.") + continue + _enable_tracing_for_state(state) + enabled_any = True + return 0 if enabled_any else 1 + + def _configure_tracing(disable: bool) -> int: if disable: return _disable_tracing_command() state = _select_tracing_workspace() + _enable_tracing_for_state(state) + return 0 + + +def _enable_tracing_for_state(state: dict) -> dict: + """Resolve the shared experiment, persist tracing config, install deps, and + rewrite agent configs for one already-selected, hydrated workspace state.""" workspace = state["workspace"] configured = _configured_tracing_agents(state) profile = state.get("profile") @@ -260,37 +309,48 @@ def _configure_tracing(disable: bool) -> int: # Running `ucode configure tracing` is itself the opt-in, so there's no # confirmation prompt; `--disable` is the explicit way back off. token = get_databricks_token(workspace, profile) - user_name = get_current_user_name(workspace, token) - - agents_cfg: dict[str, dict] = {} - for tool in configured: - name = experiment_name(tool, user_name) - with spinner(f"Resolving MLflow experiment for {tool}..."): - exp_id, reason = get_or_create_mlflow_experiment(workspace, token, name) - if not exp_id: - print_warning(f"{tool}: could not resolve experiment {name}: {reason}") - continue - agents_cfg[tool] = {"experiment_id": exp_id, "experiment_name": name} - if not agents_cfg: - raise RuntimeError("Could not resolve an MLflow experiment for any configured agent.") + # ucode does not create the experiment: an admin must have already + # provisioned a `ucode-traces` experiment whose traces are backed by Unity + # Catalog. Assert it exists (and is UC-backed), failing with setup steps if + # not, so every agent and user routes to the same UC-backed sink. + name = experiment_name() + with spinner("Looking for the ucode-traces experiment..."): + experiment, reason = find_uc_backed_experiment(workspace, token, name) + if not experiment: + raise RuntimeError(_missing_experiment_error(name, reason)) + + # A UC-backed experiment needs a SQL warehouse to write traces — without + # `MLFLOW_TRACING_SQL_WAREHOUSE_ID` the exporter silently drops them — so a + # warehouse is mandatory, not optional. + with spinner("Resolving a SQL warehouse for trace storage..."): + warehouse_id, wh_reason = resolve_sql_warehouse_id(workspace, token) + if not warehouse_id: + raise RuntimeError(_missing_warehouse_error(wh_reason)) state["tracing"] = { "enabled": True, "tracking_uri": tracking_uri_for_state(state), - "agents": agents_cfg, + "experiment_id": experiment["experiment_id"], + "experiment_name": experiment["experiment_name"], + "uc_destination": experiment["uc_destination"], + "sql_warehouse_id": warehouse_id, } save_state(state) print_kv("Tracking URI", str(state["tracing"]["tracking_uri"])) - for tool, entry in agents_cfg.items(): - print_kv(f"{tool} experiment", f"{entry['experiment_name']} (id {entry['experiment_id']})") + print_kv( + "Experiment", + f"{experiment['experiment_name']} (id {experiment['experiment_id']})", + ) + print_kv("Unity Catalog", experiment["uc_destination"]) + print_kv("SQL warehouse", warehouse_id) _install_agent_tracing_deps(state) state = _rewrite_agent_configs(state) - print_success(f"Tracing configured for: {', '.join(agents_cfg)}") - return 0 + print_success(f"Tracing configured for: {', '.join(configured)}") + return state def _disable_tracing_command() -> int: diff --git a/tests/test_e2e_tracing.py b/tests/test_e2e_tracing.py index 7129791..8f0890c 100644 --- a/tests/test_e2e_tracing.py +++ b/tests/test_e2e_tracing.py @@ -4,11 +4,12 @@ UCODE_TEST_WORKSPACE=https://your-workspace.databricks.com uv run pytest tests/test_e2e_tracing.py -v The flow mirrors `ucode configure tracing` + a real agent run: - 1. Resolve (get-or-create) the per-agent MLflow experiment in the workspace. - 2. Enable tracing in state and write the agent's config (env + plugin). - 3. Install the agent's tracing runtime (Claude plugin + mlflow CLI). - 4. Launch the agent headless with a trivial prompt so it emits a trace. - 5. Poll the experiment via the MLflow SDK until a NEW trace id appears. + 1. Find the shared, UC-backed `ucode-traces` experiment in the workspace. + 2. Resolve a SQL warehouse (required to write traces to the UC table). + 3. Enable tracing in state and write the agent's config (env + plugin). + 4. Install the agent's tracing runtime (Claude plugin + mlflow CLI). + 5. Launch the agent headless with a trivial prompt so it emits a trace. + 6. Poll the experiment via the MLflow SDK until a NEW trace id appears. Skipped automatically unless UCODE_TEST_WORKSPACE is set, the `claude` binary is installed, `mlflow` is importable, and the tracing runtime can be set up. @@ -26,7 +27,7 @@ import pytest from ucode import tracing -from ucode.databricks import get_current_user_name, get_or_create_mlflow_experiment +from ucode.databricks import find_uc_backed_experiment, resolve_sql_warehouse_id # How long to wait for an emitted trace to show up in the experiment. Trace # ingestion is asynchronous, so we poll. @@ -91,13 +92,21 @@ def test_claude_session_lands_a_trace(self, tmp_path, monkeypatch, e2e_state, e2 monkeypatch.setattr(claude, "CLAUDE_SETTINGS_PATH", tmp_path / "ucode-settings.json") monkeypatch.setattr(claude, "CLAUDE_BACKUP_PATH", tmp_path / "claude-settings.backup.json") - # Resolve the real per-agent experiment for this user. - user_name = get_current_user_name(e2e_workspace, token) - experiment_name = tracing.experiment_name("claude", user_name) - experiment_id, reason = get_or_create_mlflow_experiment( - e2e_workspace, token, experiment_name - ) - assert experiment_id, f"could not resolve experiment {experiment_name}: {reason}" + # Find the shared, UC-backed `ucode-traces` experiment. ucode no longer + # creates it, so this workspace must already have one provisioned. + leaf_name = tracing.experiment_name() + experiment, reason = find_uc_backed_experiment(e2e_workspace, token, leaf_name) + if not experiment: + pytest.skip(f"no UC-backed '{leaf_name}' experiment on this workspace: {reason}") + experiment_id = experiment["experiment_id"] + experiment_name = experiment["experiment_name"] + + # A UC-backed experiment needs a SQL warehouse, or traces are silently + # dropped (and the verification client can't read them back). + warehouse_id, wh_reason = resolve_sql_warehouse_id(e2e_workspace, token) + if not warehouse_id: + pytest.skip(f"no SQL warehouse for UC trace storage: {wh_reason}") + monkeypatch.setenv("MLFLOW_TRACING_SQL_WAREHOUSE_ID", warehouse_id) state = { **e2e_state, @@ -105,12 +114,10 @@ def test_claude_session_lands_a_trace(self, tmp_path, monkeypatch, e2e_state, e2 "tracing": { "enabled": True, "tracking_uri": tracing.tracking_uri_for_state({"workspace": e2e_workspace}), - "agents": { - "claude": { - "experiment_id": experiment_id, - "experiment_name": experiment_name, - } - }, + "experiment_id": experiment_id, + "experiment_name": experiment_name, + "uc_destination": experiment["uc_destination"], + "sql_warehouse_id": warehouse_id, }, } diff --git a/tests/test_tracing.py b/tests/test_tracing.py index ec7fd0b..9170f5f 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -9,25 +9,12 @@ import ucode.databricks as databricks from ucode import tracing -from ucode.agents import claude, codex, opencode +from ucode.agents import claude WS = "https://example.databricks.com" -_AGENT_EXPERIMENTS = { - "claude": { - "experiment_id": "111", - "experiment_name": "/Users/me@example.com/ucode-claude-code-traces", - }, - "codex": { - "experiment_id": "222", - "experiment_name": "/Users/me@example.com/ucode-codex-traces", - }, - "opencode": { - "experiment_id": "333", - "experiment_name": "/Users/me@example.com/ucode-opencode-traces", - }, -} +SHARED_EXPERIMENT_ID = "111" def _enabled_state(profile: str | None = None) -> dict: @@ -37,7 +24,10 @@ def _enabled_state(profile: str | None = None) -> dict: "tracing": { "enabled": True, "tracking_uri": f"databricks://{profile}" if profile else "databricks", - "agents": {tool: dict(entry) for tool, entry in _AGENT_EXPERIMENTS.items()}, + "experiment_id": SHARED_EXPERIMENT_ID, + "experiment_name": "/Shared/ucode-traces", + "uc_destination": "main.default.ucode_traces", + "sql_warehouse_id": "wh123", }, } @@ -64,15 +54,21 @@ def test_returns_cfg_when_enabled(self): class TestAgentTracing: - def test_returns_agent_entry(self): + def test_returns_shared_entry(self): entry = tracing.agent_tracing(_enabled_state(), "claude") assert entry["experiment_id"] == "111" + def test_none_for_non_tracing_agents(self): + # Claude is the only tracing-capable agent now. + state = _enabled_state() + for tool in ("codex", "opencode", "gemini"): + assert tracing.agent_tracing(state, tool) is None + def test_none_when_disabled(self): assert tracing.agent_tracing({}, "claude") is None - def test_none_when_agent_absent(self): - state = {"tracing": {"enabled": True, "tracking_uri": "databricks", "agents": {}}} + def test_none_when_experiment_unresolved(self): + state = {"tracing": {"enabled": True, "tracking_uri": "databricks"}} assert tracing.agent_tracing(state, "claude") is None @@ -80,183 +76,158 @@ class TestTracingEnv: def test_empty_when_disabled(self): assert tracing.tracing_env({}, "claude") == {} - def test_per_agent_uri_and_experiment(self): - env = tracing.tracing_env(_enabled_state("p"), "codex") + def test_uri_and_experiment(self): + env = tracing.tracing_env(_enabled_state("p"), "claude") assert env == { "MLFLOW_TRACKING_URI": "databricks://p", - "MLFLOW_EXPERIMENT_ID": "222", + "MLFLOW_EXPERIMENT_ID": "111", + "MLFLOW_TRACING_SQL_WAREHOUSE_ID": "wh123", } - def test_distinct_experiment_per_agent(self): + def test_empty_for_non_claude_agents(self): + # Only Claude is tracing-capable; codex/opencode get nothing. + state = _enabled_state() + assert tracing.tracing_env(state, "codex") == {} + assert tracing.tracing_env(state, "opencode") == {} + + def test_includes_sql_warehouse_id(self): + env = tracing.tracing_env(_enabled_state(), "claude") + assert env["MLFLOW_TRACING_SQL_WAREHOUSE_ID"] == "wh123" + + def test_omits_warehouse_when_absent(self): state = _enabled_state() - assert tracing.tracing_env(state, "claude")["MLFLOW_EXPERIMENT_ID"] == "111" - assert tracing.tracing_env(state, "opencode")["MLFLOW_EXPERIMENT_ID"] == "333" + del state["tracing"]["sql_warehouse_id"] + assert "MLFLOW_TRACING_SQL_WAREHOUSE_ID" not in tracing.tracing_env(state, "claude") class TestExperimentName: - def test_per_user_claude_slug(self): - assert ( - tracing.experiment_name("claude", "me@example.com") - == "/Users/me@example.com/ucode-claude-code-traces" - ) + def test_leaf_name(self): + assert tracing.experiment_name() == "ucode-traces" - def test_per_user_codex_slug(self): - assert ( - tracing.experiment_name("codex", "me@example.com") - == "/Users/me@example.com/ucode-codex-traces" - ) - def test_shared_fallback_when_no_user(self): - assert tracing.experiment_name("opencode", None) == "/Shared/ucode-opencode-traces" +def _experiment(name: str, exp_id: str, uc_destination: str | None) -> dict: + tags = [{"key": "mlflow.experiment.sourceName", "value": name}] + if uc_destination is not None: + tags.append({"key": databricks.UC_TRACE_DESTINATION_TAG, "value": uc_destination}) + return {"experiment_id": exp_id, "name": name, "tags": tags} -class TestGetOrCreateExperiment: - def test_existing_returns_id(self): - with patch.object( - databricks, - "_http_get_json", - return_value=({"experiment": {"experiment_id": "42"}}, None), - ): - exp, reason = databricks.get_or_create_mlflow_experiment(WS, "tok", "/Shared/x") - assert exp == "42" +class TestFindUcBackedExperiment: + def test_returns_uc_backed_match(self): + payload = { + "experiments": [ + _experiment("/Users/me@example.com/ucode-traces", "42", "main.default.ucode_traces") + ] + } + with patch.object(databricks, "_http_post_json", return_value=(payload, None)): + exp, reason = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") assert reason is None + assert exp == { + "experiment_id": "42", + "experiment_name": "/Users/me@example.com/ucode-traces", + "uc_destination": "main.default.ucode_traces", + } - def test_creates_when_missing(self): - with ( - patch.object( - databricks, - "_http_get_json", - return_value=(None, "HTTP 404 RESOURCE_DOES_NOT_EXIST"), - ), - patch.object( - databricks, "_http_post_json", return_value=({"experiment_id": "99"}, None) - ) as post, - ): - exp, reason = databricks.get_or_create_mlflow_experiment(WS, "tok", "/Shared/x") - assert exp == "99" - assert reason is None - post.assert_called_once() + def test_any_catalog_schema_table_qualifies(self): + payload = {"experiments": [_experiment("/Shared/ucode-traces", "7", "cat.sch.tbl")]} + with patch.object(databricks, "_http_post_json", return_value=(payload, None)): + exp, _ = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") + assert exp["uc_destination"] == "cat.sch.tbl" + + def test_none_when_no_experiment(self): + with patch.object(databricks, "_http_post_json", return_value=({"experiments": []}, None)): + exp, reason = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") + assert exp is None + assert "no experiment named 'ucode-traces'" in reason + + def test_none_when_match_not_uc_backed(self): + payload = {"experiments": [_experiment("/Shared/ucode-traces", "9", None)]} + with patch.object(databricks, "_http_post_json", return_value=(payload, None)): + exp, reason = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") + assert exp is None + assert "not backed by Unity Catalog" in reason - def test_refetches_on_already_exists_race(self): - get_results = iter( - [ - (None, "HTTP 404 RESOURCE_DOES_NOT_EXIST"), - ({"experiment": {"experiment_id": "7"}}, None), + def test_rejects_non_three_part_destination(self): + payload = {"experiments": [_experiment("/Shared/ucode-traces", "9", "main.default")]} + with patch.object(databricks, "_http_post_json", return_value=(payload, None)): + exp, reason = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") + assert exp is None + assert "not backed by Unity Catalog" in reason + + def test_leaf_match_excludes_substring_names(self): + # "team-ucode-traces" ends with the leaf as a substring but is a + # different experiment — only an exact final path segment counts. + payload = {"experiments": [_experiment("/Shared/team-ucode-traces", "1", "c.s.t")]} + with patch.object(databricks, "_http_post_json", return_value=(payload, None)): + exp, reason = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") + assert exp is None + assert "no experiment named 'ucode-traces'" in reason + + def test_prefers_uc_backed_over_plain_duplicate(self): + payload = { + "experiments": [ + _experiment("/Users/a@x.com/ucode-traces", "1", None), + _experiment("/Shared/ucode-traces", "2", "main.default.tbl"), ] - ) - with ( - patch.object( - databricks, "_http_get_json", side_effect=lambda *a, **k: next(get_results) - ), - patch.object( - databricks, - "_http_post_json", - return_value=(None, "HTTP 400: RESOURCE_ALREADY_EXISTS"), - ), - ): - exp, reason = databricks.get_or_create_mlflow_experiment(WS, "tok", "/Shared/x") - assert exp == "7" - assert reason is None + } + with patch.object(databricks, "_http_post_json", return_value=(payload, None)): + exp, _ = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") + assert exp["experiment_id"] == "2" - def test_returns_reason_on_failure(self): - with ( - patch.object(databricks, "_http_get_json", return_value=(None, "HTTP 403 Forbidden")), - patch.object(databricks, "_http_post_json", return_value=(None, "HTTP 403 Forbidden")), - ): - exp, reason = databricks.get_or_create_mlflow_experiment(WS, "tok", "/Shared/x") + def test_returns_reason_on_search_failure(self): + with patch.object(databricks, "_http_post_json", return_value=(None, "HTTP 403 Forbidden")): + exp, reason = databricks.find_uc_backed_experiment(WS, "tok", "ucode-traces") assert exp is None assert "403" in reason -class TestOpencodeTracingPlugin: - def test_added_when_enabled(self): - config: dict = {} - opencode._apply_tracing_plugin(config, _enabled_state()) - assert config["plugin"] == ["@mlflow/opencode"] - - def test_not_duplicated(self): - config = {"plugin": ["@mlflow/opencode"]} - opencode._apply_tracing_plugin(config, _enabled_state()) - assert config["plugin"] == ["@mlflow/opencode"] - - def test_removed_when_disabled(self): - config = {"plugin": ["@mlflow/opencode"]} - opencode._apply_tracing_plugin(config, {}) - assert "plugin" not in config - - def test_preserves_user_plugins_when_enabled(self): - config = {"plugin": ["user/thing"]} - opencode._apply_tracing_plugin(config, _enabled_state()) - assert config["plugin"] == ["user/thing", "@mlflow/opencode"] - - def test_preserves_user_plugins_when_disabled(self): - config = {"plugin": ["user/thing", "@mlflow/opencode"]} - opencode._apply_tracing_plugin(config, {}) - assert config["plugin"] == ["user/thing"] - - -class TestCodexTracingNotify: - def test_set_when_enabled(self): - doc: dict = {} - codex._apply_tracing_notify(doc, _enabled_state()) - assert doc["notify"] == ["mlflow-codex", "notify-hook"] - - def test_cleared_when_disabled(self): - doc = {"notify": ["mlflow-codex", "notify-hook"]} - codex._apply_tracing_notify(doc, {}) - assert "notify" not in doc - - def test_preserves_user_notify_when_disabled(self): - doc = {"notify": ["user-notify"]} - codex._apply_tracing_notify(doc, {}) - assert doc["notify"] == ["user-notify"] - - def test_warns_when_overwriting_user_notify(self): - doc = {"notify": ["user-hook"]} - with patch.object(codex, "print_warning") as warn: - codex._apply_tracing_notify(doc, _enabled_state()) - assert doc["notify"] == ["mlflow-codex", "notify-hook"] - warn.assert_called_once() - msg = warn.call_args[0][0] - assert "user-hook" in msg - assert "backup" in msg.lower() - - def test_no_warn_when_already_ucode_notify(self): - doc = {"notify": ["mlflow-codex", "notify-hook"]} - with patch.object(codex, "print_warning") as warn: - codex._apply_tracing_notify(doc, _enabled_state()) - warn.assert_not_called() - - -class TestApplyTracingEnv: - def test_sets_keys_when_enabled(self): - env: dict[str, str] = {} - tracing.apply_tracing_env(env, _enabled_state(), "codex") - assert env == {"MLFLOW_TRACKING_URI": "databricks", "MLFLOW_EXPERIMENT_ID": "222"} - - def test_clears_stale_keys_when_disabled(self): - env = { - "MLFLOW_TRACKING_URI": "databricks://stale", - "MLFLOW_EXPERIMENT_ID": "9999", - "UNRELATED": "keep-me", +class TestResolveSqlWarehouseId: + def test_prefers_running_warehouse(self): + payload = { + "warehouses": [ + {"id": "stopped1", "state": "STOPPED"}, + {"id": "running1", "state": "RUNNING"}, + ] } - tracing.apply_tracing_env(env, {}, "codex") - assert "MLFLOW_TRACKING_URI" not in env - assert "MLFLOW_EXPERIMENT_ID" not in env - assert env["UNRELATED"] == "keep-me" + with patch.object(databricks, "_http_get_json", return_value=(payload, None)): + wh, reason = databricks.resolve_sql_warehouse_id(WS, "tok") + assert wh == "running1" + assert reason is None - def test_overwrites_stale_keys_when_enabled(self): - env = {"MLFLOW_TRACKING_URI": "databricks://stale", "MLFLOW_EXPERIMENT_ID": "9999"} - tracing.apply_tracing_env(env, _enabled_state("p"), "opencode") - assert env["MLFLOW_TRACKING_URI"] == "databricks://p" - assert env["MLFLOW_EXPERIMENT_ID"] == "333" + def test_falls_back_to_first_when_none_running(self): + payload = { + "warehouses": [ + {"id": "stopped1", "state": "STOPPED"}, + {"id": "stopped2", "state": "STOPPED"}, + ] + } + with patch.object(databricks, "_http_get_json", return_value=(payload, None)): + wh, reason = databricks.resolve_sql_warehouse_id(WS, "tok") + assert wh == "stopped1" + assert reason is None + + def test_none_when_no_warehouses(self): + with patch.object(databricks, "_http_get_json", return_value=({"warehouses": []}, None)): + wh, reason = databricks.resolve_sql_warehouse_id(WS, "tok") + assert wh is None + assert "no SQL warehouse" in reason + + def test_returns_reason_on_failure(self): + with patch.object(databricks, "_http_get_json", return_value=(None, "HTTP 403 Forbidden")): + wh, reason = databricks.resolve_sql_warehouse_id(WS, "tok") + assert wh is None + assert "403" in reason class TestClaudeTracingEnv: + STOP_HOOK_CMD = "/uv/bin/mlflow autolog claude stop-hook" + def _write(self, state: dict, tmp_path, monkeypatch) -> dict: settings = tmp_path / "ucode-settings.json" monkeypatch.setattr(claude, "CLAUDE_SETTINGS_PATH", settings) monkeypatch.setattr(claude, "CLAUDE_BACKUP_PATH", tmp_path / "backup.json") + # Pin the resolved hook command so tests don't depend on a real uv/mlflow. + monkeypatch.setattr(claude, "claude_tracing_stop_hook_command", lambda: self.STOP_HOOK_CMD) claude.write_tool_config(state, "databricks-claude-opus-4-7") return json.loads(settings.read_text()) @@ -266,11 +237,68 @@ def test_injects_mlflow_env_when_enabled(self, tmp_path, monkeypatch): assert env["MLFLOW_CLAUDE_TRACING_ENABLED"] == "true" assert env["MLFLOW_TRACKING_URI"] == "databricks" assert env["MLFLOW_EXPERIMENT_ID"] == "111" + assert env["MLFLOW_TRACING_SQL_WAREHOUSE_ID"] == "wh123" + + def test_writes_stop_hook_when_enabled(self, tmp_path, monkeypatch): + state = {**_enabled_state(), "claude_models": {}} + settings = self._write(state, tmp_path, monkeypatch) + hooks = settings["hooks"]["Stop"] + assert hooks[0]["hooks"][0]["command"] == self.STOP_HOOK_CMD + + def test_preserves_user_hooks_when_enabled(self, tmp_path, monkeypatch): + settings = tmp_path / "ucode-settings.json" + settings.write_text( + json.dumps( + { + "hooks": { + "Stop": [{"hooks": [{"type": "command", "command": "user-stop"}]}], + "PreToolUse": [{"hooks": [{"type": "command", "command": "user-pre"}]}], + } + } + ) + ) + state = {**_enabled_state(), "claude_models": {}} + doc = self._write(state, tmp_path, monkeypatch) + + stop_commands = [ + hook["command"] for entry in doc["hooks"]["Stop"] for hook in entry["hooks"] + ] + assert stop_commands == ["user-stop", self.STOP_HOOK_CMD] + assert doc["hooks"]["PreToolUse"][0]["hooks"][0]["command"] == "user-pre" + + def test_updates_existing_tracing_hook_when_enabled(self, tmp_path, monkeypatch): + settings = tmp_path / "ucode-settings.json" + settings.write_text( + json.dumps( + { + "hooks": { + "Stop": [ + { + "hooks": [ + { + "type": "command", + "command": "/old/bin/mlflow autolog claude stop-hook", + } + ] + } + ] + } + } + ) + ) + state = {**_enabled_state(), "claude_models": {}} + doc = self._write(state, tmp_path, monkeypatch) + + stop_commands = [ + hook["command"] for entry in doc["hooks"]["Stop"] for hook in entry["hooks"] + ] + assert stop_commands == [self.STOP_HOOK_CMD] def test_no_mlflow_env_when_disabled(self, tmp_path, monkeypatch): state = {"workspace": WS, "claude_models": {}} - env = self._write(state, tmp_path, monkeypatch).get("env", {}) - assert "MLFLOW_TRACKING_URI" not in env + settings = self._write(state, tmp_path, monkeypatch) + assert "MLFLOW_TRACKING_URI" not in settings.get("env", {}) + assert "hooks" not in settings def test_strips_stale_keys_when_disabled(self, tmp_path, monkeypatch): settings = tmp_path / "ucode-settings.json" @@ -281,7 +309,22 @@ def test_strips_stale_keys_when_disabled(self, tmp_path, monkeypatch): "MLFLOW_CLAUDE_TRACING_ENABLED": "true", "MLFLOW_TRACKING_URI": "databricks", "MLFLOW_EXPERIMENT_ID": "1", - } + "MLFLOW_TRACING_SQL_WAREHOUSE_ID": "old-wh", + }, + "hooks": { + "Stop": [ + { + "hooks": [ + {"type": "command", "command": "user-stop"}, + { + "type": "command", + "command": "/old/bin/mlflow autolog claude stop-hook", + }, + ] + } + ], + "PreToolUse": [{"hooks": [{"type": "command", "command": "user-pre"}]}], + }, } ) ) @@ -290,10 +333,14 @@ def test_strips_stale_keys_when_disabled(self, tmp_path, monkeypatch): claude.write_tool_config( {"workspace": WS, "claude_models": {}}, "databricks-claude-opus-4-7" ) - env = json.loads(settings.read_text())["env"] + doc = json.loads(settings.read_text()) + env = doc["env"] assert "MLFLOW_TRACKING_URI" not in env assert "MLFLOW_EXPERIMENT_ID" not in env assert "MLFLOW_CLAUDE_TRACING_ENABLED" not in env + assert "MLFLOW_TRACING_SQL_WAREHOUSE_ID" not in env + assert doc["hooks"]["Stop"][0]["hooks"][0]["command"] == "user-stop" + assert doc["hooks"]["PreToolUse"][0]["hooks"][0]["command"] == "user-pre" class TestSelectTracingWorkspace: @@ -302,15 +349,15 @@ def _full(self) -> dict: "current_workspace": "https://a.databricks.com", "workspaces": { "https://a.databricks.com": {"available_tools": ["claude"], "profile": "pa"}, - "https://b.databricks.com": {"available_tools": ["codex"], "profile": "pb"}, - # gemini isn't a tracing-capable agent → excluded from candidates - "https://c.databricks.com": {"available_tools": ["gemini"]}, + "https://b.databricks.com": {"available_tools": ["claude"], "profile": "pb"}, + # codex/gemini aren't tracing-capable agents → excluded from candidates + "https://c.databricks.com": {"available_tools": ["codex", "gemini"]}, }, } def test_raises_when_none_configured(self): with patch.object(tracing, "load_full_state", return_value={"workspaces": {}}): - with pytest.raises(RuntimeError, match="No tracing-capable"): + with pytest.raises(RuntimeError, match="Claude Code is not configured"): tracing._select_tracing_workspace() def test_lists_current_first_and_excludes_non_tracing(self): @@ -341,7 +388,7 @@ def test_returns_picked_workspace_state(self): state = tracing._select_tracing_workspace() assert state["workspace"] == "https://b.databricks.com" assert state["profile"] == "pb" - assert "codex" in state["available_tools"] + assert "claude" in state["available_tools"] def test_raises_when_picked_workspace_unconfigured(self): with ( @@ -485,6 +532,71 @@ def test_disable_with_none_enabled_still_calls_restore(self): assert captured["restored_to"] == "https://a.databricks.com" +class TestEnableTracingForWorkspaces: + """``configure --tracing`` enables tracing for explicit workspaces without + prompting, and skips workspaces with no tracing-capable agent.""" + + def _full(self) -> dict: + return { + "current_workspace": "https://a.databricks.com", + "workspaces": { + "https://a.databricks.com": {"available_tools": ["claude"], "profile": "pa"}, + # no tracing-capable agent → skipped, not an error + "https://b.databricks.com": {"available_tools": ["gemini"], "profile": "pb"}, + }, + } + + def test_enables_without_prompting(self): + enabled: list[str] = [] + with ( + patch.object(tracing, "load_full_state", return_value=self._full()), + patch.object(tracing, "prompt_for_workspace") as prompt, + patch.object(tracing, "set_current_workspace"), + patch.object( + tracing, + "_enable_tracing_for_state", + side_effect=lambda s: enabled.append(s["workspace"]) or s, + ), + ): + rc = tracing.configure_tracing_command(workspaces=[("https://a.databricks.com", None)]) + prompt.assert_not_called() + assert rc == 0 + assert enabled == ["https://a.databricks.com"] + + def test_skips_workspace_without_tracing_agent(self): + enabled: list[str] = [] + with ( + patch.object(tracing, "load_full_state", return_value=self._full()), + patch.object(tracing, "set_current_workspace"), + patch.object( + tracing, + "_enable_tracing_for_state", + side_effect=lambda s: enabled.append(s["workspace"]) or s, + ), + ): + rc = tracing.configure_tracing_command(workspaces=[("https://b.databricks.com", None)]) + assert enabled == [] + assert rc == 1 + + +class TestInstallAgentTracingDeps: + """Only Claude has a tracing runtime; it installs when Claude is configured + with tracing on, and is skipped otherwise.""" + + def test_installs_claude_runtime_when_configured(self): + state = {**_enabled_state(), "available_tools": ["claude"]} + with patch("ucode.agents.claude.ensure_tracing_runtime") as claude_dep: + tracing._install_agent_tracing_deps(state) + claude_dep.assert_called_once() + + def test_skips_when_claude_not_configured(self): + # Claude isn't configured on this workspace, so its runtime is skipped. + state = {**_enabled_state(), "available_tools": ["codex"]} + with patch("ucode.agents.claude.ensure_tracing_runtime") as claude_dep: + tracing._install_agent_tracing_deps(state) + claude_dep.assert_not_called() + + class TestDisableTracing: def test_sets_disabled_and_rewrites_configs(self): state = _enabled_state() @@ -497,6 +609,18 @@ def test_sets_disabled_and_rewrites_configs(self): rewrite.assert_called_once() +class TestStopHookCommand: + def test_builds_command_from_resolved_path(self, monkeypatch): + monkeypatch.setattr(claude, "_uv_tool_mlflow_path", lambda: "/uv/bin/mlflow") + assert ( + claude.claude_tracing_stop_hook_command() == "/uv/bin/mlflow autolog claude stop-hook" + ) + + def test_none_when_mlflow_missing(self, monkeypatch): + monkeypatch.setattr(claude, "_uv_tool_mlflow_path", lambda: None) + assert claude.claude_tracing_stop_hook_command() is None + + class TestParseMlflowVersion: def test_parses_full_version(self): assert claude._parse_mlflow_version("mlflow, version 3.12.0") == (3, 12) @@ -509,9 +633,9 @@ def test_returns_none_on_garbage(self): class TestEnsureMlflowCli: - def test_noop_when_already_satisfied(self): + def test_noop_when_already_in_range(self): with ( - patch.object(claude, "_installed_mlflow_version", return_value=(3, 5)), + patch.object(claude, "_installed_mlflow_version", return_value=(3, 11)), patch.object(claude.subprocess, "run") as run, ): assert claude._ensure_mlflow_cli() is True @@ -520,15 +644,27 @@ def test_noop_when_already_satisfied(self): def test_installs_when_missing(self, monkeypatch): monkeypatch.setattr(claude, "_installed_mlflow_version", lambda: None) monkeypatch.setattr(claude.shutil, "which", lambda binary: f"/bin/{binary}") + monkeypatch.setattr(claude, "_uv_tool_mlflow_path", lambda: "/bin/mlflow") with patch.object(claude.subprocess, "run") as run: assert claude._ensure_mlflow_cli() is True cmd = run.call_args[0][0] assert cmd[:3] == ["uv", "tool", "install"] + assert claude.MLFLOW_CLI_SPEC in cmd assert "--force" not in cmd - def test_force_upgrades_when_below_minimum(self, monkeypatch): - monkeypatch.setattr(claude, "_installed_mlflow_version", lambda: (3, 1)) + def test_force_replaces_when_below_minimum(self, monkeypatch): + monkeypatch.setattr(claude, "_installed_mlflow_version", lambda: (3, 4)) + monkeypatch.setattr(claude.shutil, "which", lambda binary: f"/bin/{binary}") + monkeypatch.setattr(claude, "_uv_tool_mlflow_path", lambda: "/bin/mlflow") + with patch.object(claude.subprocess, "run") as run: + assert claude._ensure_mlflow_cli() is True + assert "--force" in run.call_args[0][0] + + def test_force_replaces_when_above_maximum(self, monkeypatch): + # 3.12 dropped UC trace writes — it must be replaced, not left alone. + monkeypatch.setattr(claude, "_installed_mlflow_version", lambda: (3, 12)) monkeypatch.setattr(claude.shutil, "which", lambda binary: f"/bin/{binary}") + monkeypatch.setattr(claude, "_uv_tool_mlflow_path", lambda: "/bin/mlflow") with patch.object(claude.subprocess, "run") as run: assert claude._ensure_mlflow_cli() is True assert "--force" in run.call_args[0][0]