diff --git a/.env.example b/.env.example index 36fd1792..a682e955 100644 --- a/.env.example +++ b/.env.example @@ -341,16 +341,6 @@ COMMIT_VECTOR_SEARCH=0 STRICT_MEMORY_RESTORE=1 -# info_request() tool settings (simplified codebase retrieval) -# Default result limit for info_request queries -INFO_REQUEST_LIMIT=10 -# Default context lines in snippets (richer than repo_search default) -INFO_REQUEST_CONTEXT_LINES=5 -# Enable explanation mode by default (summary, primary_locations, related_concepts) -# INFO_REQUEST_EXPLAIN_DEFAULT=0 -# Enable relationship mapping by default (imports_from, calls, related_paths) -# INFO_REQUEST_RELATIONSHIPS=0 - # TOON output format (Token-Oriented Object Notation) # When enabled, search results use compact TOON encoding to reduce token usage # TOON_ENABLED=0 @@ -493,4 +483,3 @@ OPENLIT_ENVIRONMENT=development # End of Auth & Bridge Configuration # --------------------------------------------------------------------------- - diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ae82ed01..d9b08ecd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: services: qdrant: - image: qdrant/qdrant:latest + image: qdrant/qdrant:v1.15.4 ports: - 6333:6333 diff --git a/.github/workflows/claude.yaml b/.github/workflows/claude.yaml new file mode 100644 index 00000000..732de78c --- /dev/null +++ b/.github/workflows/claude.yaml @@ -0,0 +1,68 @@ +name: Claude Code + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened] + pull_request_review: + types: [submitted] + pull_request_target: + types: [opened, synchronize] + +jobs: + claude: + # This simplified condition is more robust and correctly checks permissions. 
+ if: > + (contains(github.event.comment.body, '@claude') || + contains(github.event.review.body, '@claude') || + contains(github.event.issue.body, '@claude') || + contains(github.event.pull_request.body, '@claude')) && + (github.event.sender.type == 'User' && ( + github.event.comment.author_association == 'OWNER' || + github.event.comment.author_association == 'MEMBER' || + github.event.comment.author_association == 'COLLABORATOR' + )) + runs-on: ubuntu-latest + permissions: + # CRITICAL: Write permissions are required for the action to push branches and update issues/PRs. + contents: write + pull-requests: write + issues: write + id-token: write # Required for OIDC token exchange + actions: read # Required for Claude to read CI results on PRs + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + # This correctly checks out the PR's head commit for pull_request_target events. + ref: ${{ github.event.pull_request.head.sha }} + + - name: Create Claude settings file + run: | + mkdir -p /home/runner/.claude + cat > /home/runner/.claude/settings.json << 'EOF' + { + "env": { + "ANTHROPIC_BASE_URL": "https://api.z.ai/api/anthropic", + "ANTHROPIC_AUTH_TOKEN": "${{ secrets.CUSTOM_ENDPOINT_API_KEY }}" + } + } + EOF + + - name: Run Claude Code + id: claude + uses: anthropics/claude-code-action@v1 + with: + # Still need this to satisfy the action's validation + anthropic_api_key: ${{ secrets.CUSTOM_ENDPOINT_API_KEY }} + + # Use the same variable names as your local setup + settings: '{"env": {"ANTHROPIC_BASE_URL": "https://api.z.ai/api/anthropic", "ANTHROPIC_AUTH_TOKEN": "${{ secrets.CUSTOM_ENDPOINT_API_KEY }}"}}' + + track_progress: true + claude_args: | + --allowedTools "Bash,Edit,Read,Write,Glob,Grep" diff --git a/.github/workflows/cosqa-benchmark.yml b/.github/workflows/cosqa-benchmark.yml new file mode 100644 index 00000000..c25a1769 --- /dev/null +++ b/.github/workflows/cosqa-benchmark.yml @@ -0,0 +1,147 @@ +name: CoSQA Search Benchmark + +on: + 
workflow_dispatch: + inputs: + enforce_hybrid_gate: + description: Fail run if best hybrid underperforms best dense past threshold + required: false + default: false + type: boolean + hybrid_min_delta: + description: Minimum accepted (hybrid_mrr - dense_mrr), e.g. -0.02 + required: false + default: "-0.02" + type: string + upload_full_artifacts: + description: Upload full logs/json bundle (higher storage usage) + required: false + default: false + type: boolean + + pull_request: + branches: [ test ] + paths: + - scripts/hybrid/** + - scripts/hybrid_search.py + - scripts/mcp_impl/search.py + - scripts/mcp_impl/context_search.py + - scripts/mcp_indexer_server.py + - scripts/benchmarks/cosqa/** + - .github/workflows/cosqa-benchmark.yml + + schedule: + - cron: "25 3 * * *" + +jobs: + cosqa-bench: + runs-on: ubuntu-latest + timeout-minutes: 360 + + services: + qdrant: + image: qdrant/qdrant:v1.15.1 + ports: + - 6333:6333 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt', '**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Cache HuggingFace datasets + uses: actions/cache@v4 + with: + path: | + ~/.cache/huggingface/datasets + ~/.cache/huggingface/hub + key: ${{ runner.os }}-hf-cosqa-${{ hashFiles('scripts/benchmarks/cosqa/dataset.py') }} + restore-keys: | + ${{ runner.os }}-hf-cosqa- + ${{ runner.os }}-hf- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install "datasets>=2.18.0" + + - name: Wait for Qdrant + run: | + timeout 90 bash -c 'until curl -fsS http://localhost:6333/readyz; do sleep 2; done' + curl -fsS http://localhost:6333/collections >/dev/null + + - name: Resolve run config + id: cfg + run: | + echo "profile=full" >> 
"$GITHUB_OUTPUT" + echo "run_set=full" >> "$GITHUB_OUTPUT" + if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ "${{ inputs.enforce_hybrid_gate }}" = "true" ]; then + echo "enforce_hybrid_gate=1" >> "$GITHUB_OUTPUT" + else + echo "enforce_hybrid_gate=0" >> "$GITHUB_OUTPUT" + fi + if [ "${{ github.event_name }}" = "workflow_dispatch" ] && [ "${{ inputs.hybrid_min_delta }}" != "" ]; then + echo "hybrid_min_delta=${{ inputs.hybrid_min_delta }}" >> "$GITHUB_OUTPUT" + else + echo "hybrid_min_delta=-0.02" >> "$GITHUB_OUTPUT" + fi + + - name: Run CoSQA search matrix + id: bench + env: + QDRANT_URL: http://localhost:6333 + PROFILE: ${{ steps.cfg.outputs.profile }} + RUN_SET: ${{ steps.cfg.outputs.run_set }} + ENFORCE_HYBRID_GATE: ${{ steps.cfg.outputs.enforce_hybrid_gate }} + HYBRID_MIN_DELTA: ${{ steps.cfg.outputs.hybrid_min_delta }} + PYTHONUNBUFFERED: "1" + run: | + RUN_TAG="gha-${{ github.run_id }}-${{ github.run_attempt }}" + OUT_DIR="bench_results/cosqa/${RUN_TAG}" + echo "out_dir=${OUT_DIR}" >> "$GITHUB_OUTPUT" + RUN_TAG="${RUN_TAG}" OUT_DIR="${OUT_DIR}" ./scripts/benchmarks/cosqa/run_search_matrix.sh + + - name: Publish benchmark summary + if: always() + run: | + SUMMARY="${{ steps.bench.outputs.out_dir }}/summary.md" + if [ -f "${SUMMARY}" ]; then + cat "${SUMMARY}" >> "$GITHUB_STEP_SUMMARY" + else + echo "No summary file generated" >> "$GITHUB_STEP_SUMMARY" + fi + + - name: Upload benchmark artifacts + if: always() && github.event_name == 'pull_request' + uses: actions/upload-artifact@v4 + with: + name: cosqa-search-summary-${{ github.run_id }}-${{ github.run_attempt }} + path: | + ${{ steps.bench.outputs.out_dir }}/summary.md + ${{ steps.bench.outputs.out_dir }}/summary.json + retention-days: 3 + + - name: Upload full benchmark artifacts + if: | + always() && ( + github.event_name == 'schedule' || + (github.event_name == 'workflow_dispatch' && inputs.upload_full_artifacts == true) + ) + uses: actions/upload-artifact@v4 + with: + name: 
cosqa-search-bench-${{ github.run_id }}-${{ github.run_attempt }} + path: ${{ steps.bench.outputs.out_dir }} + retention-days: 7 diff --git a/Dockerfile b/Dockerfile index c9635c20..15b69498 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,12 +19,13 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt # Copy scripts for all services COPY scripts /app/scripts +RUN chmod -R a+rX /app/scripts # Create directories -WORKDIR /work +WORKDIR /app # Expose all necessary ports EXPOSE 8000 8001 8002 8003 18000 18001 18002 18003 # Default to memory server -CMD ["python", "/app/scripts/mcp_memory_server.py"] \ No newline at end of file +CMD ["python", "-m", "scripts.mcp_memory_server"] diff --git a/Dockerfile.indexer b/Dockerfile.indexer index 18ce4586..07776b90 100644 --- a/Dockerfile.indexer +++ b/Dockerfile.indexer @@ -29,9 +29,10 @@ ENV RERANKER_ONNX_PATH=/app/models/reranker.onnx \ # Bake scripts into the image so we can mount arbitrary code at /work COPY scripts /app/scripts +RUN chmod -R a+rX /app/scripts -WORKDIR /work +WORKDIR /app # Default command shows help; Makefile/compose will override entrypoint -CMD ["python", "/app/scripts/ingest_code.py", "--help"] +CMD ["python", "-m", "scripts.ingest_code", "--help"] diff --git a/Dockerfile.mcp b/Dockerfile.mcp index a97142ed..0cfc482e 100644 --- a/Dockerfile.mcp +++ b/Dockerfile.mcp @@ -17,9 +17,12 @@ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt \ # Bake scripts into image so server can run even when /work points elsewhere COPY scripts /app/scripts +RUN chmod -R a+rX /app/scripts # Expose SSE port EXPOSE 8000 # Default command: run the server with SSE transport (env provides host/port) -CMD ["python", "/app/scripts/mcp_memory_server.py"] +WORKDIR /app + +CMD ["python", "-m", "scripts.mcp_memory_server"] diff --git a/Dockerfile.mcp-indexer b/Dockerfile.mcp-indexer index 064a9188..28f2a62a 100644 --- a/Dockerfile.mcp-indexer +++ b/Dockerfile.mcp-indexer @@ -32,13 +32,14 @@ ENV 
RERANKER_ONNX_PATH=/app/models/reranker.onnx \ # Bake scripts into the image so entrypoints don't rely on /work COPY scripts /app/scripts +RUN chmod -R a+rX /app/scripts COPY bench /app/bench # Expose SSE port for this companion server EXPOSE 8001 -WORKDIR /work +WORKDIR /app # Default command runs the companion MCP server -CMD ["python", "/app/scripts/mcp_indexer_server.py"] +CMD ["python", "-m", "scripts.mcp_indexer_server"] diff --git a/Dockerfile.upload-service b/Dockerfile.upload-service index 8c3f47c9..9e15fe81 100644 --- a/Dockerfile.upload-service +++ b/Dockerfile.upload-service @@ -27,6 +27,7 @@ RUN pip install --no-cache-dir --upgrade pip \ # Copy application code COPY scripts/ ./scripts/ COPY . . +RUN chmod -R a+rX /app/scripts # Create work dir, non-root user, and set ownership in single layer RUN mkdir -p /work && chmod 755 /work \ @@ -50,4 +51,6 @@ ENV UPLOAD_SERVICE_HOST=0.0.0.0 \ UPLOAD_TIMEOUT_SECS=300 # Run the upload service -CMD ["python", "scripts/upload_service.py"] \ No newline at end of file +WORKDIR /app + +CMD ["python", "-m", "scripts.upload_service"] diff --git a/Makefile b/Makefile index 068f765f..096fe8c9 100644 --- a/Makefile +++ b/Makefile @@ -4,12 +4,10 @@ SHELL := /bin/bash # An empty export forces docker to use its default context/socket. 
export DOCKER_HOST = -.PHONY: help up down logs ps restart rebuild index reindex watch watch-remote env hybrid bootstrap history rerank-local setup-reranker prune warm health test-e2e +.PHONY: help up down logs ps restart rebuild index reindex watch watch-remote env hybrid bootstrap history rerank-local setup-reranker prune warm health test test-full test-integration test-e2e .PHONY: venv venv-install dev-remote-up dev-remote-down dev-remote-logs dev-remote-restart dev-remote-bootstrap dev-remote-test dev-remote-client dev-remote-clean .PHONY: rerank-eval rerank-eval-ablations rerank-benchmark -.PHONY: qdrant-status qdrant-list qdrant-prune qdrant-index-root - venv: ## create local virtualenv .venv python3 -m venv .venv && . .venv/bin/activate && pip install -U pip @@ -73,7 +71,7 @@ index-path: ## index an arbitrary repo: make index-path REPO_PATH=/abs/path [REC @NAME=$${REPO_NAME:-$$(basename "$(REPO_PATH)")}; \ COLL=$${COLLECTION:-$$NAME}; \ HOST_INDEX_PATH="$(REPO_PATH)" COLLECTION_NAME="$$COLL" REPO_NAME="$$NAME" \ - docker compose run --rm -v "$$PWD":/app:ro --entrypoint python indexer /app/scripts/ingest_code.py --root /work $${RECREATE:+--recreate} + docker compose run --rm -v "$$PWD":/app:ro --workdir /app --entrypoint python indexer -m scripts.ingest_code --root /work $${RECREATE:+--recreate} # Index the current working directory quickly index-here: ## index the current directory: make index-here [RECREATE=1] [REPO_NAME=name] [COLLECTION=name] @@ -85,7 +83,7 @@ index-here: ## index the current directory: make index-here [RECREATE=1] [REPO_N watch: ## watch mode: reindex changed files on save (Ctrl+C to stop) - docker compose run --rm --entrypoint python indexer /work/scripts/watch_index.py + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.watch_index watch-remote: ## remote watch mode: upload delta bundles to remote server (Ctrl+C to stop) @echo "Starting remote watch mode..." 
@@ -97,24 +95,33 @@ watch-remote: ## remote watch mode: upload delta bundles to remote server (Ctrl+ @echo "Remote upload endpoint: $(REMOTE_UPLOAD_ENDPOINT)" @echo "Max retries: $${REMOTE_UPLOAD_MAX_RETRIES:-3}" @echo "Timeout: $${REMOTE_UPLOAD_TIMEOUT:-30} seconds" - docker compose run --rm --entrypoint python \ + docker compose run --rm --workdir /app --entrypoint python \ -e REMOTE_UPLOAD_ENABLED=1 \ -e REMOTE_UPLOAD_ENDPOINT=$(REMOTE_UPLOAD_ENDPOINT) \ -e REMOTE_UPLOAD_MAX_RETRIES=$${REMOTE_UPLOAD_MAX_RETRIES:-3} \ -e REMOTE_UPLOAD_TIMEOUT=$${REMOTE_UPLOAD_TIMEOUT:-30} \ - indexer /work/scripts/watch_index.py + indexer -m scripts.watch_index rerank: ## multi-query re-ranker helper example - docker compose run --rm --entrypoint python indexer /work/scripts/rerank_query.py \ + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.rerank_tools.query \ --query "chunk code by lines with overlap for indexing" \ --query "function to split code into overlapping line chunks" \ --language python --under /work/scripts --limit 5 warm: ## prime ANN/search caches with a few queries - docker compose run --rm --entrypoint python indexer /work/scripts/warm_start.py --ef 256 --limit 3 + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.warm_start --ef 256 --limit 3 health: ## run health checks for collection/model settings - docker compose run --rm --entrypoint python indexer /work/scripts/health_check.py + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.health_check + +test: ## run default fast tests (excludes integration) + pytest + +test-full: ## run all tests including integration + pytest --run-integration -m "" + +test-integration: ## run integration tests only + pytest --run-integration -m integration # Check llama.cpp decoder health on localhost:8080 (200 OK expected) @@ -128,7 +135,7 @@ env: ## create .env from example if missing [ -f .env ] || cp .env.example .env hybrid: ## hybrid 
search: dense + lexical RRF fuse (respects --language/--under/--kind) - docker compose run --rm --entrypoint python indexer /work/scripts/hybrid_search.py \ + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.hybrid_search \ --query "chunk code by lines" --query "overlapping line chunks" --limit 8 bootstrap: env up ## one-shot: up -> wait -> index -> warm -> health @@ -138,20 +145,20 @@ bootstrap: env up ## one-shot: up -> wait -> index -> warm -> health $(MAKE) health history: ## ingest Git history (messages + file lists) - docker compose run --rm --entrypoint python indexer /work/scripts/ingest_history.py --max-commits 200 + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.ingest_history --max-commits 200 prune-path: ## prune a repo by path: make prune-path REPO_PATH=/abs/path @if [ -z "$(REPO_PATH)" ]; then \ echo "Usage: make prune-path REPO_PATH=/abs/path"; exit 1; \ fi HOST_INDEX_PATH="$(REPO_PATH)" PRUNE_ROOT=/work \ - docker compose run --rm --entrypoint python indexer /work/scripts/prune.py + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.prune rerank-local: ## local cross-encoder reranker (requires RERANKER_ONNX_PATH, RERANKER_TOKENIZER_PATH) @if [ -z "$(RERANKER_ONNX_PATH)" ] || [ -z "$(RERANKER_TOKENIZER_PATH)" ]; then \ echo "RERANKER_ONNX_PATH and RERANKER_TOKENIZER_PATH must be set in .env"; exit 1; \ fi - docker compose run --rm --entrypoint python indexer /work/scripts/rerank_local.py --query "search symbols" --topk 50 --limit 12 + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.rerank_tools.local --query "search symbols" --topk 50 --limit 12 setup-reranker: ## download ONNX reranker + tokenizer, update .env, then smoke-test @if [ -z "$(ONNX_URL)" ] || [ -z "$(TOKENIZER_URL)" ]; then \ @@ -163,7 +170,7 @@ setup-reranker: ## download ONNX reranker + tokenizer, update .env, then smoke-t $(MAKE) rerank-local prune: ## remove points for 
missing files or mismatched file_hash - docker compose run --rm --entrypoint python indexer /work/scripts/prune.py + docker compose run --rm --workdir /app --entrypoint python indexer -m scripts.prune @@ -296,56 +303,6 @@ dev-remote-clean: ## clean up dev-remote volumes and containers rm -rf dev-workspace -# Router helpers -Q ?= what is hybrid search? -route-plan: ## plan-only route for a query: make route-plan Q="your question" - python3 scripts/mcp_router.py --plan "$(Q)" - -route-run: ## execute routed tool(s) over HTTP: make route-run Q="your question" - python3 scripts/mcp_router.py --run "$(Q)" -router-eval: ## run the mock-based router eval harness - python3 scripts/router_eval.py - - -# Live orchestration smoke test (no CI): bring up stack, reindex, run router -router-smoke: ## spin up compose, reindex, store a memory via router, then answer; exits nonzero on failure - set -e; \ - docker compose down || true; \ - docker compose up -d qdrant; \ - ./scripts/wait-for-qdrant.sh; \ - $(MAKE) llama-model; \ - docker compose up -d mcp_http mcp_indexer_http llamacpp; \ - echo "Waiting for MCP HTTP health..."; \ - for i in $$(seq 1 30); do \ - code1=$$(curl -s -o /dev/null -w "%{http_code}" http://localhost:$${FASTMCP_HTTP_HEALTH_PORT:-18002}/readyz || true); \ - code2=$$(curl -s -o /dev/null -w "%{http_code}" http://localhost:$${FASTMCP_INDEXER_HTTP_HEALTH_PORT:-18003}/readyz || true); \ - if [ "$$code1" = "200" ] && [ "$$code2" = "200" ]; then echo "MCP HTTP ready"; break; fi; \ - sleep 1; \ - if [ $$i -eq 30 ]; then echo "MCP HTTP health timeout"; exit 1; fi; \ - done; \ - $(MAKE) reindex; \ - echo "Storing a smoke memory via router..."; \ - python3 scripts/mcp_router.py --run "remember this: router smoke memory"; \ - echo "Running a router answer..."; \ - python3 scripts/mcp_router.py --run "recap our architecture decisions for the indexer"; \ - echo "router-smoke: PASS" - - - -# Qdrant via MCP router convenience targets -qdrant-status: - python3 
scripts/mcp_router.py --run "status" - -qdrant-list: - python3 scripts/mcp_router.py --run "list collections" - -qdrant-prune: - python3 scripts/mcp_router.py --run "prune" - -qdrant-index-root: - python3 scripts/mcp_router.py --run "reindex repo" - - # --- ctx CLI helper --- # Usage examples (default prints ONLY the improved prompt): # make ctx Q="how does hybrid search work?" @@ -362,10 +319,10 @@ ctx: ## enhance a prompt with repo context: make ctx Q="your question" [ARGS='-- # --- Reranker Evaluation --- rerank-eval: ## run offline reranker evaluation (fixed queries, MRR/Recall/latency) - python3 scripts/rerank_eval.py --output rerank_eval_results.json + python3 -m scripts.rerank_tools.eval --output rerank_eval_results.json rerank-eval-ablations: ## run full ablation study (baseline, recursive, learning, onnx) - python3 scripts/rerank_eval.py --ablations --output rerank_eval_ablations.json + python3 -m scripts.rerank_tools.eval --ablations --output rerank_eval_ablations.json rerank-benchmark: ## run production benchmark on real codebase - python3 scripts/rerank_real_benchmark.py + python3 -m scripts.rerank_tools.benchmark diff --git a/README.md b/README.md index 2a9cd730..19c166eb 100644 --- a/README.md +++ b/README.md @@ -144,10 +144,9 @@ See [docs/vscode-extension.md](docs/vscode-extension.md) for full documentation. 
## MCP Tools **Search** (Indexer MCP): -- `repo_search` — Hybrid code search with filters +- `repo_search` — Code search with filters and optional profiles - `context_search` — Blend code + memory results - `context_answer` — LLM-generated answers with citations -- `search_tests_for`, `search_config_for`, `search_callers_for` **Memory** (Memory MCP): - `store` — Save knowledge with metadata diff --git a/ctx-mcp-bridge/src/mcpServer.js b/ctx-mcp-bridge/src/mcpServer.js index 53cb05b7..ab4c3bf6 100644 --- a/ctx-mcp-bridge/src/mcpServer.js +++ b/ctx-mcp-bridge/src/mcpServer.js @@ -8,7 +8,13 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js" import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { CallToolRequestSchema, ListToolsRequestSchema } from "@modelcontextprotocol/sdk/types.js"; +import { + CallToolRequestSchema, + ListToolsRequestSchema, + ListResourcesRequestSchema, + ListResourceTemplatesRequestSchema, + ReadResourceRequestSchema, +} from "@modelcontextprotocol/sdk/types.js"; import { loadAnyAuthEntry, loadAuthEntry, readConfig, saveAuthEntry } from "./authConfig.js"; import { maybeRemapToolArgs, maybeRemapToolResult } from "./resultPathMapping.js"; import * as oauthHandler from "./oauthHandler.js"; @@ -27,16 +33,23 @@ function debugLog(message) { async function sendSessionDefaults(client, payload, label) { if (!client) { - return; + return false; } try { - await client.callTool({ - name: "set_session_defaults", - arguments: payload, - }); + const timeoutMs = getBridgeToolTimeoutMs(); + await withTimeout( + client.callTool({ + name: "set_session_defaults", + arguments: payload, + }), + timeoutMs, + `sendSessionDefaults(${label})` + ); + return true; } catch (err) { // eslint-disable-next-line no-console 
console.error(`[ctxce] Failed to call set_session_defaults on ${label}:`, err); + return false; } } function dedupeTools(tools) { @@ -58,15 +71,69 @@ function dedupeTools(tools) { return out; } -async function listMemoryTools(client) { - if (!client) { +function dedupeResources(resources) { + const seen = new Set(); + const out = []; + for (const resource of resources) { + const uri = resource && typeof resource.uri === "string" ? resource.uri : ""; + if (!uri || seen.has(uri)) { + continue; + } + seen.add(uri); + out.push(resource); + } + return out; +} + +function dedupeResourceTemplates(templates) { + const seen = new Set(); + const out = []; + for (const template of templates) { + const uri = + template && typeof template.uriTemplate === "string" + ? template.uriTemplate + : ""; + if (!uri || seen.has(uri)) { + continue; + } + seen.add(uri); + out.push(template); + } + return out; +} + +async function callListWithSessionRecovery(call, label, onSessionError) { + try { + return await withTransientRetry(call, label); + } catch (err) { + if (isSessionError(err) && typeof onSessionError === "function") { + try { + await onSessionError(); + return await withTransientRetry(call, `${label} (retry)`); + } catch (retryErr) { + debugLog(`[ctxce] ${label} failed after MCP session recovery: ` + String(retryErr)); + } + } + throw err; + } +} + +async function listMemoryTools(getClient, onSessionError) { + if (typeof getClient !== "function" || !getClient()) { return []; } try { - const remote = await withTimeout( - client.listTools(), - 5000, + const remote = await callListWithSessionRecovery( + () => { + const client = getClient(); + if (!client) { + throw new Error("Memory MCP client not initialized"); + } + const timeoutMs = getBridgeListTimeoutMs(); + return withTimeout(client.listTools(), timeoutMs, "memory tools/list"); + }, "memory tools/list", + onSessionError, ); return Array.isArray(remote?.tools) ? 
remote.tools.slice() : []; } catch (err) { @@ -75,6 +142,104 @@ async function listMemoryTools(client) { } } +function encodeCompositeCursor(cursorObj) { + try { + const payload = JSON.stringify(cursorObj || {}); + return Buffer.from(payload, "utf8").toString("base64"); + } catch { + return ""; + } +} + +function decodeCompositeCursor(raw) { + try { + const trimmed = (raw || "").trim(); + if (!trimmed) { + return null; + } + const decoded = Buffer.from(trimmed, "base64").toString("utf8"); + const parsed = JSON.parse(decoded); + if (!parsed || typeof parsed !== "object") { + return null; + } + return parsed; + } catch { + return null; + } +} + +async function listResourcesSafe(getClient, label, cursor, onSessionError) { + if (typeof getClient !== "function" || !getClient()) { + return { resources: [], nextCursor: null }; + } + try { + const params = cursor ? { cursor } : {}; + const remote = await callListWithSessionRecovery( + () => { + const client = getClient(); + if (!client) { + throw new Error(`${label} MCP client not initialized`); + } + const timeoutMs = getBridgeListTimeoutMs(); + return withTimeout( + client.listResources(params), + timeoutMs, + `${label} resources/list`, + ); + }, + `${label} resources/list`, + onSessionError, + ); + return { + resources: Array.isArray(remote?.resources) ? remote.resources.slice() : [], + nextCursor: + remote && typeof remote.nextCursor === "string" && remote.nextCursor + ? remote.nextCursor + : null, + }; + } catch (err) { + debugLog(`[ctxce] Error calling ${label} resources/list: ` + String(err)); + return { resources: [], nextCursor: null }; + } +} + +async function listResourceTemplatesSafe(getClient, label, cursor, onSessionError) { + if (typeof getClient !== "function" || !getClient()) { + return { resourceTemplates: [], nextCursor: null }; + } + try { + const params = cursor ? 
{ cursor } : {}; + const remote = await callListWithSessionRecovery( + () => { + const client = getClient(); + if (!client) { + throw new Error(`${label} MCP client not initialized`); + } + const timeoutMs = getBridgeListTimeoutMs(); + return withTimeout( + client.listResourceTemplates(params), + timeoutMs, + `${label} resources/templates/list`, + ); + }, + `${label} resources/templates/list`, + onSessionError, + ); + return { + resourceTemplates: Array.isArray(remote?.resourceTemplates) + ? remote.resourceTemplates.slice() + : [], + nextCursor: + remote && typeof remote.nextCursor === "string" && remote.nextCursor + ? remote.nextCursor + : null, + }; + } catch (err) { + debugLog(`[ctxce] Error calling ${label} resources/templates/list: ` + String(err)); + return { resourceTemplates: [], nextCursor: null }; + } +} + function withTimeout(promise, ms, label) { return new Promise((resolve, reject) => { let settled = false; @@ -125,6 +290,25 @@ function getBridgeToolTimeoutMs() { } } +function getBridgeListTimeoutMs() { + try { + // Keep list operations on a separate budget from tools/call. + // Some streamable-http clients (including Codex) probe tools/resources early, + // and a short timeout here can make the bridge appear unavailable. 
+ const raw = process.env.CTXCE_LIST_TIMEOUT_MSEC; + if (!raw) { + return 60000; + } + const parsed = Number.parseInt(String(raw), 10); + if (!Number.isFinite(parsed) || parsed <= 0) { + return 60000; + } + return parsed; + } catch { + return 60000; + } +} + function selectClientForTool(name, indexerClient, memoryClient) { if (!name) { return indexerClient; @@ -148,6 +332,7 @@ function isSessionError(error) { msg.includes("No valid session ID") || msg.includes("Mcp-Session-Id header is required") || msg.includes("Server not initialized") || + msg.includes("Received request before initialization was complete") || msg.includes("Session not found") ); } catch { @@ -270,6 +455,34 @@ function isTransientToolError(error) { return false; } } + +async function withTransientRetry(operation, label, maxAttempts, retryDelayMs) { + const attempts = Number.isFinite(maxAttempts) && maxAttempts > 0 + ? Math.floor(maxAttempts) + : getBridgeRetryAttempts(); + const delayMs = Number.isFinite(retryDelayMs) && retryDelayMs >= 0 + ? Math.floor(retryDelayMs) + : getBridgeRetryDelayMs(); + let lastError; + for (let attempt = 0; attempt < attempts; attempt += 1) { + if (attempt > 0 && delayMs > 0) { + await new Promise((resolve) => setTimeout(resolve, delayMs)); + } + try { + return await operation(); + } catch (err) { + lastError = err; + if (!isTransientToolError(err) || attempt === attempts - 1) { + throw err; + } + debugLog( + `[ctxce] ${label}: transient error (attempt ${attempt + 1}/${attempts}), retrying: ` + + String(err), + ); + } + } + throw lastError || new Error(`[ctxce] ${label}: unknown transient retry failure`); +} // MCP stdio server implemented using the official MCP TypeScript SDK. // Acts as a low-level proxy for tools, forwarding tools/list and tools/call // to the remote qdrant-indexer MCP server while adding a local `ping` tool. 
@@ -440,6 +653,7 @@ async function createBridgeServer(options) { let indexerClient = null; let memoryClient = null; + let lastDefaultsSyncedSessionId = ""; // Derive a simple session identifier for this bridge process. In the // future this can be made user-aware (e.g. from auth), but for now we @@ -568,6 +782,23 @@ async function createBridgeServer(options) { defaultsPayload.under = defaultUnder; } + async function ensureRemoteDefaults(force = false) { + defaultsPayload.session = sessionId; + if (!sessionId) { + return; + } + if (!force && lastDefaultsSyncedSessionId === sessionId) { + return; + } + const indexerOk = await sendSessionDefaults(indexerClient, defaultsPayload, "indexer"); + if (memoryClient) { + await sendSessionDefaults(memoryClient, defaultsPayload, "memory"); + } + if (indexerOk) { + lastDefaultsSyncedSessionId = sessionId; + } + } + async function initializeRemoteClients(forceRecreate = false) { if (!forceRecreate && indexerClient) { return; @@ -579,6 +810,22 @@ async function createBridgeServer(options) { } catch { // ignore logging failures } + try { + if (indexerClient && typeof indexerClient.close === "function") { + await indexerClient.close(); + } + } catch { + // ignore + } + try { + if (memoryClient && typeof memoryClient.close === "function") { + await memoryClient.close(); + } + } catch { + // ignore + } + indexerClient = null; + memoryClient = null; } let nextIndexerClient = null; @@ -633,15 +880,34 @@ async function createBridgeServer(options) { indexerClient = nextIndexerClient; memoryClient = nextMemoryClient; - if (Object.keys(defaultsPayload).length > 1 && indexerClient) { - await sendSessionDefaults(indexerClient, defaultsPayload, "indexer"); - if (memoryClient) { - await sendSessionDefaults(memoryClient, defaultsPayload, "memory"); - } + await ensureRemoteDefaults(true); + } + + async function refreshSessionAndSyncDefaults() { + const freshSession = resolveSessionId() || sessionId; + const changed = Boolean(freshSession && 
freshSession !== sessionId); + if (changed) { + sessionId = freshSession; + defaultsPayload.session = sessionId; + lastDefaultsSyncedSessionId = ""; } + await initializeRemoteClients(false); + await ensureRemoteDefaults(changed); } - await initializeRemoteClients(false); + async function recoverRemoteClientsAfterSessionError() { + const freshSession = resolveSessionId() || sessionId; + const changed = Boolean(freshSession && freshSession !== sessionId); + if (changed) { + sessionId = freshSession; + defaultsPayload.session = sessionId; + lastDefaultsSyncedSessionId = ""; + } + await initializeRemoteClients(true); + await ensureRemoteDefaults(true); + } + + await refreshSessionAndSyncDefaults(); const server = new Server( // TODO: marked as depreciated { @@ -651,6 +917,7 @@ async function createBridgeServer(options) { { capabilities: { tools: {}, + resources: {}, }, }, ); @@ -658,20 +925,36 @@ async function createBridgeServer(options) { // tools/list → fetch tools from remote indexer server.setRequestHandler(ListToolsRequestSchema, async () => { let remote; + let listError = null; try { - debugLog("[ctxce] tools/list: fetching tools from indexer"); await initializeRemoteClients(false); - if (!indexerClient) { - throw new Error("Indexer MCP client not initialized"); - } - remote = await withTimeout( - indexerClient.listTools(), - 10000, + await ensureRemoteDefaults(false); + debugLog("[ctxce] tools/list: fetching tools from indexer"); + remote = await callListWithSessionRecovery( + () => { + if (!indexerClient) { + throw new Error("Indexer MCP client not initialized"); + } + const timeoutMs = getBridgeListTimeoutMs(); + return withTimeout( + indexerClient.listTools(), + timeoutMs, + "indexer tools/list", + ); + }, "indexer tools/list", + recoverRemoteClientsAfterSessionError, ); } catch (err) { - debugLog("[ctxce] Error calling remote tools/list: " + String(err)); - const memoryToolsFallback = await listMemoryTools(memoryClient); + listError = err; + } + + if 
(!remote) { + debugLog("[ctxce] Error calling remote tools/list: " + String(listError)); + const memoryToolsFallback = await listMemoryTools( + () => memoryClient, + recoverRemoteClientsAfterSessionError, + ); const toolsFallback = dedupeTools([...memoryToolsFallback]); return { tools: toolsFallback }; } @@ -687,12 +970,130 @@ async function createBridgeServer(options) { } const indexerTools = Array.isArray(remote?.tools) ? remote.tools.slice() : []; - const memoryTools = await listMemoryTools(memoryClient); + const memoryTools = await listMemoryTools( + () => memoryClient, + recoverRemoteClientsAfterSessionError, + ); const tools = dedupeTools([...indexerTools, ...memoryTools]); debugLog(`[ctxce] tools/list: returning ${tools.length} tools`); return { tools }; }); + server.setRequestHandler(ListResourcesRequestSchema, async (request) => { + // Proxy resource discovery/read-through so clients that use MCP resources + // (not only tools) can access upstream indexer/memory resources directly. + await initializeRemoteClients(false); + await ensureRemoteDefaults(false); + const cursor = + request && request.params && typeof request.params.cursor === "string" + ? request.params.cursor + : null; + const decoded = decodeCompositeCursor(cursor); + const indexerCursor = + decoded && typeof decoded.i === "string" ? decoded.i : cursor; + const memoryCursor = + decoded && typeof decoded.m === "string" ? 
decoded.m : cursor; + if (cursor && decoded === null) { + debugLog("[ctxce] resources/list: received non-composite cursor; forwarding to both upstreams."); + } + const indexerRes = await listResourcesSafe( + () => indexerClient, + "indexer", + indexerCursor, + recoverRemoteClientsAfterSessionError, + ); + const memoryRes = await listResourcesSafe( + () => memoryClient, + "memory", + memoryCursor, + recoverRemoteClientsAfterSessionError, + ); + const resources = dedupeResources([ + ...indexerRes.resources, + ...memoryRes.resources, + ]); + const nextCursorObj = { + i: indexerRes.nextCursor || "", + m: memoryRes.nextCursor || "", + }; + const nextCursor = + nextCursorObj.i || nextCursorObj.m ? encodeCompositeCursor(nextCursorObj) : ""; + debugLog(`[ctxce] resources/list: returning ${resources.length} resources`); + return nextCursor ? { resources, nextCursor } : { resources }; + }); + + server.setRequestHandler(ListResourceTemplatesRequestSchema, async (request) => { + await initializeRemoteClients(false); + await ensureRemoteDefaults(false); + const cursor = + request && request.params && typeof request.params.cursor === "string" + ? request.params.cursor + : null; + const decoded = decodeCompositeCursor(cursor); + const indexerCursor = + decoded && typeof decoded.i === "string" ? decoded.i : cursor; + const memoryCursor = + decoded && typeof decoded.m === "string" ? 
decoded.m : cursor; + if (cursor && decoded === null) { + debugLog("[ctxce] resources/templates/list: received non-composite cursor; forwarding to both upstreams."); + } + const indexerRes = await listResourceTemplatesSafe( + () => indexerClient, + "indexer", + indexerCursor, + recoverRemoteClientsAfterSessionError, + ); + const memoryRes = await listResourceTemplatesSafe( + () => memoryClient, + "memory", + memoryCursor, + recoverRemoteClientsAfterSessionError, + ); + const resourceTemplates = dedupeResourceTemplates([ + ...indexerRes.resourceTemplates, + ...memoryRes.resourceTemplates, + ]); + const nextCursorObj = { + i: indexerRes.nextCursor || "", + m: memoryRes.nextCursor || "", + }; + const nextCursor = + nextCursorObj.i || nextCursorObj.m ? encodeCompositeCursor(nextCursorObj) : ""; + debugLog(`[ctxce] resources/templates/list: returning ${resourceTemplates.length} templates`); + return nextCursor ? { resourceTemplates, nextCursor } : { resourceTemplates }; + }); + + server.setRequestHandler(ReadResourceRequestSchema, async (request) => { + await refreshSessionAndSyncDefaults(); + const params = request.params || {}; + const timeoutMs = getBridgeToolTimeoutMs(); + const uri = + params && typeof params.uri === "string" ? 
params.uri : ""; + debugLog(`[ctxce] resources/read: ${uri}`); + + const tryRead = async (client, label) => { + if (!client) { + return null; + } + try { + return await client.readResource(params, { timeout: timeoutMs }); + } catch (err) { + debugLog(`[ctxce] resources/read failed on ${label}: ` + String(err)); + return null; + } + }; + + const indexerResult = await tryRead(indexerClient, "indexer"); + if (indexerResult) { + return indexerResult; + } + const memoryResult = await tryRead(memoryClient, "memory"); + if (memoryResult) { + return memoryResult; + } + throw new Error(`Resource ${uri} not available on any configured MCP server`); + }); + // tools/call → proxied to indexer or memory server server.setRequestHandler(CallToolRequestSchema, async (request) => { const params = request.params || {}; @@ -701,16 +1102,8 @@ async function createBridgeServer(options) { debugLog(`[ctxce] tools/call: ${name || ""}`); - // Refresh session before each call; re-init clients if session changes. - const freshSession = resolveSessionId() || sessionId; - if (freshSession && freshSession !== sessionId) { - sessionId = freshSession; - try { - await initializeRemoteClients(true); - } catch (err) { - debugLog("[ctxce] Failed to reinitialize clients after session refresh: " + String(err)); - } - } + await refreshSessionAndSyncDefaults(); + if (sessionId && (args === undefined || args === null || typeof args === "object")) { const obj = args && typeof args === "object" ? 
{ ...args } : {}; if (!Object.prototype.hasOwnProperty.call(obj, "session")) { @@ -733,8 +1126,6 @@ async function createBridgeServer(options) { return indexerResult; } - await initializeRemoteClients(false); - const timeoutMs = getBridgeToolTimeoutMs(); const maxAttempts = getBridgeRetryAttempts(); const retryDelayMs = getBridgeRetryDelayMs(); @@ -770,6 +1161,7 @@ async function createBridgeServer(options) { String(err), ); await initializeRemoteClients(true); + await ensureRemoteDefaults(true); sessionRetried = true; continue; } @@ -843,6 +1235,13 @@ export async function runHttpMcpServer(options) { typeof options.port === "number" ? options.port : Number.parseInt(process.env.CTXCE_HTTP_PORT || "30810", 10) || 30810; + // TODO(auth): replace this boolean toggle with explicit auth modes (none|required). + // In required mode, enforce Bearer auth on /mcp with consistent 401 challenges and + // only advertise OAuth metadata/endpoints when authentication is mandatory. + // In local/dev mode, leaving OAuth discovery off avoids clients entering an + // unnecessary OAuth path for otherwise unauthenticated bridge usage. 
+ const oauthEnabled = String(process.env.CTXCE_ENABLE_OAUTH || "").trim().toLowerCase(); + const oauthEndpointsEnabled = oauthEnabled === "1" || oauthEnabled === "true" || oauthEnabled === "yes"; const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: undefined, @@ -865,34 +1264,36 @@ export async function runHttpMcpServer(options) { // OAuth 2.0 Endpoints (RFC9728 Protected Resource Metadata + RFC7591) // ================================================================ - // OAuth metadata endpoint (RFC9728) - if (parsedUrl.pathname === "/.well-known/oauth-authorization-server") { - oauthHandler.handleOAuthMetadata(req, res, issuerUrl); - return; - } + if (oauthEndpointsEnabled) { + // OAuth metadata endpoint (RFC9728) + if (parsedUrl.pathname === "/.well-known/oauth-authorization-server") { + oauthHandler.handleOAuthMetadata(req, res, issuerUrl); + return; + } - // OAuth Dynamic Client Registration endpoint (RFC7591) - if (parsedUrl.pathname === "/oauth/register" && req.method === "POST") { - oauthHandler.handleOAuthRegister(req, res); - return; - } + // OAuth Dynamic Client Registration endpoint (RFC7591) + if (parsedUrl.pathname === "/oauth/register" && req.method === "POST") { + oauthHandler.handleOAuthRegister(req, res); + return; + } - // OAuth authorize endpoint - if (parsedUrl.pathname === "/oauth/authorize") { - oauthHandler.handleOAuthAuthorize(req, res, parsedUrl.searchParams); - return; - } + // OAuth authorize endpoint + if (parsedUrl.pathname === "/oauth/authorize") { + oauthHandler.handleOAuthAuthorize(req, res, parsedUrl.searchParams); + return; + } - // Store session endpoint (helper for login page) - if (parsedUrl.pathname === "/oauth/store-session" && req.method === "POST") { - oauthHandler.handleOAuthStoreSession(req, res); - return; - } + // Store session endpoint (helper for login page) + if (parsedUrl.pathname === "/oauth/store-session" && req.method === "POST") { + oauthHandler.handleOAuthStoreSession(req, res); + return; 
+ } - // OAuth token endpoint - if (parsedUrl.pathname === "/oauth/token" && req.method === "POST") { - oauthHandler.handleOAuthToken(req, res); - return; + // OAuth token endpoint + if (parsedUrl.pathname === "/oauth/token" && req.method === "POST") { + oauthHandler.handleOAuthToken(req, res); + return; + } } // ================================================================ @@ -1058,4 +1459,3 @@ function detectRepoName(workspace, config) { const leaf = workspace ? path.basename(workspace) : ""; return leaf && SLUGGED_REPO_RE.test(leaf) ? leaf : null; } - diff --git a/ctx-mcp-bridge/src/resultPathMapping.js b/ctx-mcp-bridge/src/resultPathMapping.js index a3309876..46e9663d 100644 --- a/ctx-mcp-bridge/src/resultPathMapping.js +++ b/ctx-mcp-bridge/src/resultPathMapping.js @@ -442,11 +442,7 @@ export function maybeRemapToolResult(name, result, workspaceRoot) { const shouldMap = ( lower === "repo_search" || lower === "context_search" || - lower === "context_answer" || - lower.endsWith("search_tests_for") || - lower.endsWith("search_config_for") || - lower.endsWith("search_callers_for") || - lower.endsWith("search_importers_for") + lower === "context_answer" ); if (!shouldMap) { return result; diff --git a/deploy/kubernetes/configmap.yaml b/deploy/kubernetes/configmap.yaml index 9c91a893..7ace2ea0 100644 --- a/deploy/kubernetes/configmap.yaml +++ b/deploy/kubernetes/configmap.yaml @@ -45,10 +45,6 @@ data: INDEX_UPSERT_BATCH: '128' INDEX_UPSERT_RETRIES: '5' INDEX_USE_ENHANCED_AST: '1' - INFO_REQUEST_CONTEXT_LINES: '5' - INFO_REQUEST_EXPLAIN_DEFAULT: '0' - INFO_REQUEST_LIMIT: '10' - INFO_REQUEST_RELATIONSHIPS: '0' LLAMACPP_EXTRA_ARGS: '' LLAMACPP_GPU_LAYERS: '32' LLAMACPP_GPU_SPLIT: '' @@ -143,3 +139,7 @@ data: USE_GPU_DECODER: '0' USE_TREE_SITTER: '1' WATCH_DEBOUNCE_SECS: '4' + WATCH_USE_POLLING: '1' + WATCH_INIT_MAINTENANCE_ENABLED: '1' + WATCH_INIT_MAINTENANCE_INTERVAL_MINUTES: '120' + WATCH_INIT_MAINTENANCE_RUN_ON_START: '0' diff --git 
a/deploy/kubernetes/indexer-services.yaml b/deploy/kubernetes/indexer-services.yaml index 123582c6..60d98cef 100644 --- a/deploy/kubernetes/indexer-services.yaml +++ b/deploy/kubernetes/indexer-services.yaml @@ -53,8 +53,9 @@ spec: imagePullPolicy: IfNotPresent command: - python - - /app/scripts/watch_index.py - workingDir: /work + - -m + - scripts.watch_index + workingDir: /app env: - name: QDRANT_URL valueFrom: @@ -102,6 +103,11 @@ spec: configMapKeyRef: name: context-engine-config key: WATCH_DEBOUNCE_SECS + - name: WATCH_USE_POLLING + valueFrom: + configMapKeyRef: + name: context-engine-config + key: WATCH_USE_POLLING - name: HF_HOME value: /work/models/hf-cache - name: XDG_CACHE_HOME @@ -191,8 +197,11 @@ spec: imagePullPolicy: IfNotPresent command: - python - - /app/scripts/ingest_code.py - workingDir: /work + - -m + - scripts.ingest_code + - --root + - /work + workingDir: /app env: - name: QDRANT_URL valueFrom: @@ -299,8 +308,8 @@ spec: command: - /bin/sh - -c - - PYTHONPATH=/app python /app/scripts/create_indexes.py && PYTHONPATH=/app python /app/scripts/warm_all_collections.py && PYTHONPATH=/app python /app/scripts/health_check.py - workingDir: /work + - python -m scripts.run_init_maintenance + workingDir: /app env: - name: QDRANT_URL valueFrom: diff --git a/deploy/kubernetes/learning-reranker-worker.yaml b/deploy/kubernetes/learning-reranker-worker.yaml index c1431c22..301a3a0a 100644 --- a/deploy/kubernetes/learning-reranker-worker.yaml +++ b/deploy/kubernetes/learning-reranker-worker.yaml @@ -40,8 +40,10 @@ spec: imagePullPolicy: IfNotPresent command: - python - - /app/scripts/learning_reranker_worker.py + - -m + - scripts.learning_reranker_worker - --daemon + workingDir: /app resources: requests: memory: 512Mi diff --git a/deploy/kubernetes/mcp-http.yaml b/deploy/kubernetes/mcp-http.yaml index c3c71fe2..34d6f791 100644 --- a/deploy/kubernetes/mcp-http.yaml +++ b/deploy/kubernetes/mcp-http.yaml @@ -40,7 +40,9 @@ spec: imagePullPolicy: IfNotPresent 
command: - python - - /app/scripts/mcp_memory_server.py + - -m + - scripts.mcp_memory_server + workingDir: /app ports: - name: http containerPort: 8000 @@ -59,6 +61,11 @@ spec: configMapKeyRef: name: context-engine-config key: COLLECTION_NAME + - name: MULTI_REPO_MODE + valueFrom: + configMapKeyRef: + name: context-engine-config + key: MULTI_REPO_MODE - name: EMBEDDING_MODEL valueFrom: configMapKeyRef: @@ -227,7 +234,9 @@ spec: imagePullPolicy: IfNotPresent command: - python - - /app/scripts/mcp_indexer_server.py + - -m + - scripts.mcp_indexer_server + workingDir: /app ports: - name: http containerPort: 8001 @@ -246,6 +255,11 @@ spec: configMapKeyRef: name: context-engine-config key: COLLECTION_NAME + - name: MULTI_REPO_MODE + valueFrom: + configMapKeyRef: + name: context-engine-config + key: MULTI_REPO_MODE - name: EMBEDDING_MODEL valueFrom: configMapKeyRef: diff --git a/deploy/kubernetes/mcp-indexer.yaml b/deploy/kubernetes/mcp-indexer.yaml index 2fbec1a1..a56f190d 100644 --- a/deploy/kubernetes/mcp-indexer.yaml +++ b/deploy/kubernetes/mcp-indexer.yaml @@ -40,7 +40,9 @@ spec: imagePullPolicy: IfNotPresent command: - python - - /app/scripts/mcp_indexer_server.py + - -m + - scripts.mcp_indexer_server + workingDir: /app ports: - name: sse containerPort: 8001 diff --git a/deploy/kubernetes/mcp-memory.yaml b/deploy/kubernetes/mcp-memory.yaml index 165076db..0000f2ec 100644 --- a/deploy/kubernetes/mcp-memory.yaml +++ b/deploy/kubernetes/mcp-memory.yaml @@ -36,7 +36,9 @@ spec: imagePullPolicy: IfNotPresent command: - python - - /app/scripts/mcp_memory_server.py + - -m + - scripts.mcp_memory_server + workingDir: /app ports: - name: sse containerPort: 8000 @@ -69,6 +71,11 @@ spec: configMapKeyRef: name: context-engine-config key: COLLECTION_NAME + - name: MULTI_REPO_MODE + valueFrom: + configMapKeyRef: + name: context-engine-config + key: MULTI_REPO_MODE - name: EMBEDDING_MODEL valueFrom: configMapKeyRef: diff --git a/deploy/kubernetes/qdrant.yaml 
b/deploy/kubernetes/qdrant.yaml index ba645364..3e72352c 100644 --- a/deploy/kubernetes/qdrant.yaml +++ b/deploy/kubernetes/qdrant.yaml @@ -23,7 +23,7 @@ spec: spec: containers: - name: qdrant - image: qdrant/qdrant:latest + image: qdrant/qdrant:v1.15.4 imagePullPolicy: Always ports: - name: http diff --git a/docker-compose-bindmount-checkout.yml b/docker-compose-bindmount-checkout.yml index b4e6361c..19f65c60 100644 --- a/docker-compose-bindmount-checkout.yml +++ b/docker-compose-bindmount-checkout.yml @@ -1,6 +1,6 @@ services: qdrant: - image: qdrant/qdrant:latest + image: qdrant/qdrant:v1.15.4 container_name: qdrant-db # Expose Qdrant database APIs to the host # 6333 = HTTP API, 6334 = gRPC @@ -211,7 +211,7 @@ services: - ${HOST_INDEX_PATH:-.}:/work:ro - ${HOST_INDEX_PATH:-.}/.codebase:/work/.codebase:rw - entrypoint: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && exec python /app/scripts/ingest_code.py"] + entrypoint: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && cd /app && exec python -m scripts.ingest_code --root /work"] watcher: build: @@ -235,6 +235,7 @@ services: - QWEN3_QUERY_INSTRUCTION=${QWEN3_QUERY_INSTRUCTION:-1} - QWEN3_INSTRUCTION_TEXT=${QWEN3_INSTRUCTION_TEXT} - WATCH_ROOT=/work + - CTXCE_METADATA_ROOT=${CTXCE_METADATA_ROOT:-/work} # Watcher-specific backpressure & timeouts (safer defaults) - QDRANT_TIMEOUT=60 - MAX_MICRO_CHUNKS_PER_FILE=${MAX_MICRO_CHUNKS_PER_FILE:-200} @@ -245,7 +246,7 @@ services: volumes: - ${HOST_INDEX_PATH:-.}:/work:ro - ${HOST_INDEX_PATH:-.}/.codebase:/work/.codebase:rw - entrypoint: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && exec python /app/scripts/watch_index.py"] + entrypoint: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && cd /app && exec python -m scripts.watch_index"] upload_service: @@ 
-262,6 +263,7 @@ services: - UPLOAD_SERVICE_PORT=8002 - QDRANT_URL=${QDRANT_URL} - WORK_DIR=/work + - CTXCE_METADATA_ROOT=${CTXCE_METADATA_ROOT:-/work} - CTXCE_ADMIN_COLLECTION_DELETE_ENABLED=0 - CTXCE_COLLECTION_REGISTRY_UNDELETE_ON_DISCOVERY=${CTXCE_COLLECTION_REGISTRY_UNDELETE_ON_DISCOVERY:-0} - COLLECTION_NAME=${COLLECTION_NAME:-codebase} @@ -300,7 +302,7 @@ services: - ${HOST_INDEX_PATH:-.}:/work:ro - ${HOST_INDEX_PATH:-.}/.codebase:/work/.codebase:rw - entrypoint: ["python", "/app/scripts/create_indexes.py"] + entrypoint: ["python", "-m", "scripts.create_indexes"] volumes: qdrant_storage: diff --git a/docker-compose.yml b/docker-compose.yml index cb5403a2..890002ce 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,7 @@ version: '3.8' services: # Qdrant vector database - same as base compose qdrant: - image: qdrant/qdrant:latest + image: qdrant/qdrant:v1.15.4 container_name: qdrant-db-dev-remote ports: - "6333:6333" @@ -88,11 +88,11 @@ services: dockerfile: Dockerfile.mcp-indexer container_name: mcp-indexer-dev-remote user: "1000:1000" - # In K8s, scripts would be accessed directly at /app/scripts/ or via proper initContainer + # In K8s, scripts run as package modules from /app or via proper initContainer # For Docker Compose dev-remote simulation, create symlink so /work/scripts/ works # Use /tmp/huggingface for cache to avoid permission issues (universally writable) # Set CORRECT environment variables for HuggingFace and FastEmbed - command: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && exec python /app/scripts/mcp_indexer_server.py"] + command: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && cd /app && exec python -m scripts.mcp_indexer_server"] depends_on: - qdrant env_file: @@ -177,7 +177,7 @@ services: dockerfile: Dockerfile.mcp-indexer container_name: learning-worker-dev-remote user: "1000:1000" - command: ["sh", "-c", 
"mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && exec python /app/scripts/learning_reranker_worker.py --daemon"] + command: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && cd /app && exec python -m scripts.learning_reranker_worker --daemon"] depends_on: - qdrant - mcp_indexer @@ -289,11 +289,11 @@ services: dockerfile: Dockerfile.mcp-indexer container_name: mcp-indexer-http-dev-remote user: "1000:1000" - # In K8s, scripts would be accessed directly at /app/scripts/ or via proper initContainer + # In K8s, scripts run as package modules from /app or via proper initContainer # For Docker Compose dev-remote simulation, create symlink so /work/scripts/ works # Use /tmp/huggingface for cache to avoid permission issues (universally writable) # Set CORRECT environment variables for HuggingFace and FastEmbed - command: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && exec python /app/scripts/mcp_indexer_server.py"] + command: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && cd /app && exec python -m scripts.mcp_indexer_server"] depends_on: - qdrant env_file: @@ -436,7 +436,7 @@ services: volumes: - workspace_pvc:/work:rw - codebase_pvc:/work/.codebase:rw - entrypoint: ["sh", "-c", "mkdir -p /tmp/logs /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && /app/scripts/wait-for-qdrant.sh && cd /app && python /app/scripts/ingest_code.py --root /work"] + entrypoint: ["sh", "-c", "mkdir -p /tmp/logs /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && /app/scripts/wait-for-qdrant.sh && cd /app && python -m scripts.ingest_code --root /work"] restart: "no" # Run once on startup, do not restart after completion cpus: 2.0 networks: @@ -469,8 +469,12 @@ services: - QWEN3_QUERY_INSTRUCTION=${QWEN3_QUERY_INSTRUCTION:-1} - 
QWEN3_INSTRUCTION_TEXT=${QWEN3_INSTRUCTION_TEXT} - WATCH_ROOT=${WATCH_ROOT:-/work} + - CTXCE_METADATA_ROOT=${CTXCE_METADATA_ROOT:-/work} - HOST_INDEX_PATH=/work - QDRANT_TIMEOUT=${QDRANT_TIMEOUT:-60} + - WATCH_INIT_MAINTENANCE_ENABLED=${WATCH_INIT_MAINTENANCE_ENABLED:-1} + - WATCH_INIT_MAINTENANCE_INTERVAL_MINUTES=${WATCH_INIT_MAINTENANCE_INTERVAL_MINUTES:-120} + - WATCH_INIT_MAINTENANCE_RUN_ON_START=${WATCH_INIT_MAINTENANCE_RUN_ON_START:-0} # Chunking config - use ${VAR:-} to properly inherit from .env (not host shell) - INDEX_SEMANTIC_CHUNKS=${INDEX_SEMANTIC_CHUNKS:-} - INDEX_MICRO_CHUNKS=${INDEX_MICRO_CHUNKS:-} @@ -493,7 +497,7 @@ services: volumes: - workspace_pvc:/work:rw - codebase_pvc:/work/.codebase:rw - command: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && exec python /app/scripts/watch_index.py"] + command: ["sh", "-c", "mkdir -p /tmp/huggingface/hub /tmp/huggingface/transformers /tmp/huggingface/fastembed && cd /app && exec python -m scripts.watch_index"] cpus: 2 networks: - dev-remote-network @@ -532,7 +536,7 @@ services: command: [ "sh", "-c", - "mkdir -p /tmp/logs /work/.codebase && (chgrp -R 1000 /work/.codebase 2>/dev/null || true) && (chmod -R g+rwX /work/.codebase 2>/dev/null || true) && (find /work/.codebase -type d -exec chmod g+s {} + 2>/dev/null || true) && echo 'Starting initialization sequence...' && /app/scripts/wait-for-qdrant.sh && PYTHONPATH=/app python /app/scripts/create_indexes.py && echo 'Collections and metadata created' && python /app/scripts/warm_all_collections.py && echo 'Search caches warmed for all collections' && python /app/scripts/health_check.py && echo 'Initialization completed successfully!'" + "mkdir -p /tmp/logs /work/.codebase && (chgrp -R 1000 /work/.codebase 2>/dev/null || true) && (chmod -R g+rwX /work/.codebase 2>/dev/null || true) && (find /work/.codebase -type d -exec chmod g+s {} + 2>/dev/null || true) && echo 'Starting initialization sequence...' 
&& cd /app && python -m scripts.run_init_maintenance && echo 'Initialization completed successfully!'" ] restart: "no" # Run once on startup networks: @@ -555,6 +559,7 @@ services: - UPLOAD_SERVICE_PORT=8002 - QDRANT_URL=${QDRANT_URL} - WORKDIR=/work + - CTXCE_METADATA_ROOT=${CTXCE_METADATA_ROOT:-/work} - MAX_BUNDLE_SIZE_MB=100 - UPLOAD_TIMEOUT_SECS=300 # Optional auth configuration (fully opt-in via .env) @@ -613,7 +618,7 @@ services: command: [ "sh", "-c", - "mkdir -p /work/.codebase && (chgrp -R 1000 /work/.codebase 2>/dev/null || true) && (chmod -R g+rwX /work/.codebase 2>/dev/null || true) && (find /work/.codebase -type d -exec chmod g+s {} + 2>/dev/null || true) && exec python scripts/upload_service.py" + "mkdir -p /work/.codebase && (chgrp -R 1000 /work/.codebase 2>/dev/null || true) && (chmod -R g+rwX /work/.codebase 2>/dev/null || true) && (find /work/.codebase -type d -exec chmod g+s {} + 2>/dev/null || true) && cd /app && exec python -m scripts.upload_service" ] healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8002/health"] @@ -674,4 +679,4 @@ networks: driver: bridge ipam: config: - - subnet: 172.20.0.0/16 \ No newline at end of file + - subnet: 172.20.0.0/16 diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 00798685..0e20fe81 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -164,13 +164,13 @@ The Learning Reranker is an **optional** self-improving ranking system that lear #### Components -**TinyScorer** (`scripts/rerank_recursive.py`) +**TinyScorer** (`scripts/rerank_recursive/recursive.py`) - 2-layer MLP neural network (~3MB per collection) - Scores query-document pairs based on learned patterns - Hot-reloads weights every 60 seconds from disk - Per-collection weights (each repo learns independently) -**Event Logger** (`scripts/rerank_events.py`) +**Event Logger** (`scripts/rerank_tools/events.py`) - Logs every search to NDJSON files at `/tmp/rerank_events/` - Records: query, candidates, initial scores, timestamps 
- Hourly file rotation with configurable retention @@ -233,12 +233,6 @@ Worker logs show training progress: - **Continuous Improvement**: Rankings improve over time - **Offline Capable**: Teacher runs locally, no external API calls -#### MCP Router (`scripts/mcp_router.py`) -- **Intent Classification**: Determines which MCP tool to call based on query -- **Tool Orchestration**: Routes to search, answer, memory, or index tools -- **HTTP Execution**: Executes tools via RMCP/HTTP without extra dependencies -- **Plan Mode**: Preview tool selection without execution - ## Data Flow Architecture ### Search Request Flow diff --git a/docs/BENCHMARKS.md b/docs/BENCHMARKS.md index 152348ea..91d59176 100644 --- a/docs/BENCHMARKS.md +++ b/docs/BENCHMARKS.md @@ -360,7 +360,7 @@ rerank_return_m=10 # Rerank and return top 10 The reranker (ONNX cross-encoder) scores `(query, document)` pairs. The document text is constructed from: ```python -# From scripts/rerank_local.py:prepare_pairs() +# From scripts/rerank_tools/local.py:prepare_pairs() header = f"[{language}/{kind}] {symbol_path} — {path}" doc = header + "\n" + metadata.code[:600] ``` @@ -379,7 +379,7 @@ doc = header + "\n" + metadata.code[:600] To debug what the reranker sees: ```python -from scripts.rerank_local import prepare_pairs +from scripts.rerank_tools.local import prepare_pairs from qdrant_client import QdrantClient client = QdrantClient() diff --git a/docs/CLAUDE.example.md b/docs/CLAUDE.example.md index c46fdf4e..a8466b29 100644 --- a/docs/CLAUDE.example.md +++ b/docs/CLAUDE.example.md @@ -81,7 +81,7 @@ These rules are NOT optional - favor qdrant-indexer tooling at all costs over ex Tool Roles Cheat Sheet: - - repo_search / code_search: + - repo_search: - Use for: finding relevant files/spans and inspecting raw code. - Think: "where is X implemented?", "show me usages of Y". 
- context_search: @@ -133,10 +133,9 @@ These rules are NOT optional - favor qdrant-indexer tooling at all costs over ex - workspace_info, list_workspaces, collection_map - set_session_defaults - Search / QA tools: - - repo_search, code_search, context_search, context_answer + - repo_search, context_search, context_answer - pattern_search (optional; structural code pattern matching, cross-language) - - search_tests_for, search_config_for, search_callers_for, search_importers_for - - change_history_for_path, expand_query + - symbol_graph, change_history_for_path, expand_query - Memory tools: - memory.set_session_defaults, memory.memory_store, memory.memory_find @@ -148,4 +147,4 @@ These rules are NOT optional - favor qdrant-indexer tooling at all costs over ex blended code + memory results instead of calling repo_search and memory.memory_find separately. - Treat expand_query and the expand flag on context_answer as expensive options: - only use them after a normal search/answer attempt failed to find good context. \ No newline at end of file + only use them after a normal search/answer attempt failed to find good context. diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index c5a86453..cd2d59ac 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -21,7 +21,6 @@ Complete environment variable reference for Context Engine. 
- [Lexical Vector Settings](#lexical-vector-settings) - [Ports](#ports) - [Search & Expansion](#search--expansion) -- [info_request Tool](#info_request-tool) - [Memory Blending](#memory-blending) --- @@ -135,6 +134,10 @@ Dynamic HNSW_EF tuning and intelligent query routing for 2x faster simple querie | Name | Description | Default | |------|-------------|---------| | WATCH_DEBOUNCE_SECS | Debounce between FS events | 1.5 | +| WATCH_INIT_MAINTENANCE_ENABLED | Run periodic init maintenance from watcher | 1 (enabled) | +| WATCH_INIT_MAINTENANCE_INTERVAL_MINUTES | Minutes between init maintenance passes | 120 | +| WATCH_INIT_MAINTENANCE_RUN_ON_START | Run immediately on watcher startup instead of waiting one interval | 0 (disabled) | +| WATCH_INIT_MAINTENANCE_COMMAND_TIMEOUT_SECS | Per-command timeout for init maintenance scripts | 1800 | | INDEX_UPSERT_BATCH | Upsert batch size (watcher) | 128 | | INDEX_UPSERT_RETRIES | Retry count | 5 | | INDEX_UPSERT_BACKOFF | Seconds between retries | 0.5 | @@ -332,7 +335,11 @@ Deferred pseudo/tag generation runs asynchronously after initial indexing. | Name | Description | Default | |------|-------------|---------| | PSEUDO_BACKFILL_ENABLED | Enable async pseudo/tag backfill worker | 0 (disabled) | -| PSEUDO_DEFER_TO_WORKER | Skip inline pseudo, defer to backfill worker | 0 (disabled) | +| PSEUDO_DEFER_TO_WORKER | Foreground/background semantics: only disables inline pseudo when backfill worker is enabled | 0 (disabled) | + +Notes: +- `PSEUDO_BACKFILL_ENABLED=0` is a hard disable for the worker. +- `PSEUDO_DEFER_TO_WORKER=1` has no effect unless `PSEUDO_BACKFILL_ENABLED=1` (we keep inline pseudo enabled to avoid silently dropping pseudo/tags). ### Adaptive Span Sizing @@ -506,20 +513,13 @@ The search engine can boost files whose paths match query terms—production-gra Set `FNAME_BOOST=0` to disable, or increase (e.g., `0.25`) for stronger path weighting. 
-## info_request Tool - -Simplified codebase retrieval with optional explanation mode. - -| Name | Description | Default | -|------|-------------|---------| -| INFO_REQUEST_LIMIT | Default result limit for info_request queries | 10 | -| INFO_REQUEST_CONTEXT_LINES | Context lines in snippets (richer than repo_search) | 5 | - ## Output Formatting ### TOON (Token-Oriented Object Notation) -Compact output format that reduces token usage by 40-60%. +Compact display format for search results. In practice this is usually about +20-25% smaller than compact JSON for search-shaped payloads, with larger savings +only when comparing against pretty-printed JSON. | Name | Description | Default | |------|-------------|---------| @@ -577,4 +577,3 @@ docker compose run --rm indexer --root /work --no-default-excludes --exclude '/v | Large (1k+ files) | 120 (default) | 20 | 128+ | For large monorepos, set `INDEX_PROGRESS_EVERY=200` for visibility. - diff --git a/docs/DEVELOPMENT.md b/docs/DEVELOPMENT.md index e2f55b92..aa28f4db 100644 --- a/docs/DEVELOPMENT.md +++ b/docs/DEVELOPMENT.md @@ -97,7 +97,6 @@ Context-Engine/ ├── scripts/ # Core application code │ ├── mcp_memory_server.py # Memory MCP server implementation │ ├── mcp_indexer_server.py # Indexer MCP server implementation -│ ├── mcp_router.py # Intent-based tool routing │ ├── hybrid_search.py # Search algorithm implementation │ ├── ctx.py # CLI prompt enhancer │ ├── cache_manager.py # Unified caching system @@ -403,7 +402,7 @@ class TestSearchIntegration: @pytest.fixture(scope="module") def qdrant_container(self): """Set up real Qdrant container for integration tests.""" - container = DockerContainer("qdrant/qdrant:latest").with_exposed_ports(6333) + container = DockerContainer("qdrant/qdrant:v1.15.4").with_exposed_ports(6333) container.start() yield f"http://{container.get_container_host_ip()}:{container.get_exposed_port(6333)}" container.stop() @@ -574,4 +573,4 @@ curl http://localhost:18001/tools - [ ] Error handling is 
appropriate - [ ] Performance impact is considered -This development guide should help you get started with contributing to Context Engine. For more specific questions, refer to the code documentation or create an issue in the repository. \ No newline at end of file +This development guide should help you get started with contributing to Context Engine. For more specific questions, refer to the code documentation or create an issue in the repository. diff --git a/docs/IDE_CLIENTS.md b/docs/IDE_CLIENTS.md index 88249340..e86395d0 100644 --- a/docs/IDE_CLIENTS.md +++ b/docs/IDE_CLIENTS.md @@ -229,7 +229,7 @@ url = "http://127.0.0.1:8003/mcp" "args": ["mcp-server-qdrant"], "env": { "QDRANT_URL": "http://localhost:6333", - "COLLECTION_NAME": "my-collection", + "COLLECTION_NAME": "codebase", "EMBEDDING_MODEL": "BAAI/bge-base-en-v1.5" }, "disabled": false @@ -282,7 +282,7 @@ scripts/remote_upload_client.py --server http://context.yourcompany.com:9090 --p - **Do not send null values** to MCP tools. Omit the field or pass an empty string "" instead. - **qdrant-index examples:** - - `{"subdir":"","recreate":false,"collection":"my-collection","repo_name":"workspace"}` + - `{"subdir":"","recreate":false,"collection":"codebase","repo_name":"workspace"}` - `{"subdir":"scripts","recreate":true}` - For indexing repo root with no params, use `qdrant_index_root` (zero-arg) or call `qdrant-index` with `subdir:""`. 
@@ -292,7 +292,7 @@ scripts/remote_upload_client.py --server http://context.yourcompany.com:9090 --p After configuring, you should see tools from both servers: - `store`, `find` (Memory) -- `repo_search`, `code_search`, `context_search`, `context_answer` (Indexer) +- `repo_search`, `context_search`, `context_answer` (Indexer) - `qdrant_list`, `qdrant_index`, `qdrant_prune`, `qdrant_status` (Indexer) Test connectivity: @@ -327,4 +327,3 @@ When using `@context-engine-bridge/context-engine-mcp-bridge`, ensure you set `C ``` The default collection name is `codebase` unless you've configured a different one during indexing. - diff --git a/docs/MCP_API.md b/docs/MCP_API.md index 59dbb8cf..b1ba9fa6 100644 --- a/docs/MCP_API.md +++ b/docs/MCP_API.md @@ -9,7 +9,7 @@ This document provides comprehensive API documentation for all MCP (Model Contex **On this page:** - [Overview](#overview) - [Memory Server API](#memory-server-api) - `memory_store()`, `memory_find()` -- [Indexer Server API](#indexer-server-api) - `repo_search()`, `context_search()`, `context_answer()`, `info_request()`, etc. +- [Indexer Server API](#indexer-server-api) - `repo_search()`, `context_search()`, `context_answer()`, etc. - [Response Schemas](#response-schemas) - [Error Handling](#error-handling) @@ -153,7 +153,7 @@ Search stored memories using hybrid retrieval (semantic + lexical search). ### repo_search() -Perform hybrid code search combining dense semantic, lexical BM25, and optional neural reranking. +Perform code search using the configured retrieval mode. Dense semantic search can be used on its own, or combined with lexical fusion and optional neural reranking when those features are enabled. 
**Core Parameters:** - `query` (str or list[str], required): Search query or list of queries for query fusion @@ -172,6 +172,10 @@ Perform hybrid code search combining dense semantic, lexical BM25, and optional - `path_glob` (str or list[str], optional): Glob patterns for path filtering - `under` (str, optional): Limit search to specific directory path - `not_glob` (str or list[str], optional): Exclude paths matching these patterns +- `profile` (str, optional): Apply a focused path profile before search: + - `"tests"`: Prefer common test file paths + - `"config"`: Prefer common configuration files + - `"code"`: Prefer source-code files **Code Structure Filters:** - `symbol` (str, optional): Search for specific function, class, or variable names @@ -386,132 +390,6 @@ All `repo_search` parameters supported for context retrieval. } ``` -### info_request() - -Simplified codebase retrieval with optional explanation mode. Drop-in replacement for basic codebase retrieval tools with human-readable result descriptions. 
- -**Primary Parameters:** -- `info_request` (str, required): Natural language description of the code you're looking for -- `information_request` (str): Alias for `info_request` - -**Explanation Mode:** -- `include_explanation` (bool, default false): Add summary, primary_locations, related_concepts, grouped_results, and confidence metrics -- `include_relationships` (bool, default false): Add imports_from, calls, related_paths to each result - -**Filter Parameters:** -- `limit` (int): Maximum results (smart defaults: 15 for short queries, 8 for questions, 10 otherwise) -- `language` (str, optional): Filter by programming language -- `under` (str, optional): Limit search to specific directory -- `repo` (str or list[str], optional): Filter by repository name(s) -- `path_glob` (str or list[str], optional): Glob patterns for file paths - -**Snippet Options:** -- `include_snippet` (bool, default true): Include code snippets -- `context_lines` (int, default 5): Lines of context around matches - -**Returns (basic mode):** -```json -{ - "ok": true, - "results": [ - { - "score": 0.85, - "path": "/work/src/hooks/useAuth.tsx", - "symbol": "useAuth", - "start_line": 15, - "end_line": 45, - "information": "Found 'useAuth' in useAuth.tsx (lines 15-45)", - "relevance_score": 0.85, - "snippet": "export function useAuth() { ... 
}" - } - ], - "total": 10, - "search_strategy": "hybrid+rerank" -} -``` - -**Returns (with `include_explanation: true`):** -```json -{ - "ok": true, - "results": [...], - "total": 10, - "search_strategy": "hybrid+rerank+lang:typescript", - "summary": "Found 10 results related to 'authentication hook' across 5 files", - "primary_locations": [ - "/work/src/hooks/useAuth.tsx", - "/work/src/context/AuthContext.tsx" - ], - "related_concepts": ["auth", "hook", "context", "session", "token"], - "grouped_results": { - "by_file": { - "/work/src/hooks/useAuth.tsx": { - "count": 3, - "top_symbols": ["useAuth", "AuthProvider", "useSession"] - } - } - }, - "confidence": { - "level": "high", - "score": 0.78, - "top_score": 0.85, - "symbol_matches": 2 - }, - "query_understanding": { - "intent": "search_for_code", - "detected_language": "typescript", - "detected_symbols": ["useAuth"], - "search_strategy": "hybrid+rerank+lang:typescript" - } -} -``` - -**Returns (with `include_relationships: true`):** -```json -{ - "results": [ - { - "information": "Found 'useAuth' in useAuth.tsx (lines 15-45)", - "relationships": { - "imports_from": ["react", "@/context/AuthContext"], - "calls": ["useState", "useContext", "fetchUser"], - "symbol_path": "useAuth", - "related_paths": ["/work/src/context/AuthContext.tsx"] - } - } - ] -} -``` - -**Smart Limits:** -- Short queries (1-2 words): 15 results for broader coverage -- Question queries ("how does", "what is"): 8 results for focused answers -- Default: 10 results - -**Search Strategy Labels:** -- `hybrid` - Base hybrid search (dense + lexical) -- `+rerank` - Neural reranker applied -- `+repo_filtered` - Filtered to specific repo(s) -- `+lang:python` - Filtered by language -- `+path_filtered` - Filtered by directory - -**Environment Variables:** -- `INFO_REQUEST_LIMIT=10` - Default result limit -- `INFO_REQUEST_CONTEXT_LINES=5` - Default context lines -- `INFO_REQUEST_EXPLAIN_DEFAULT=0` - Enable explanation mode by default -- 
`INFO_REQUEST_RELATIONSHIPS=0` - Enable relationships by default - -**Example:** -```json -{ - "info_request": "authentication middleware", - "include_explanation": true, - "include_relationships": true, - "language": "python", - "limit": 5 -} -``` - ### qdrant_index() Index or reindex code from the mounted workspace. @@ -700,10 +578,6 @@ Supports three runtime backends via `REFRAG_RUNTIME`: On decoder error, falls back to suffix-based expansion with `"decoder_used": "fallback"`. If expansion fails entirely, returns `"ok": false` with an error message. -### code_search() - -Exact alias of `repo_search()` for discoverability. Same parameters and return shape. - ### qdrant_index_root() Index the entire workspace root (`/work`). @@ -714,45 +588,6 @@ Index the entire workspace root (`/work`). **Returns:** Subprocess result with indexing status. -### search_tests_for() - -Find test files related to a query. Presets common test file globs. - -**Parameters:** -- `query` (str or list[str], required): Search query -- `limit` (int, optional): Max results -- `include_snippet` (bool, optional): Include code snippets -- `language` (str, optional): Filter by language - -**Returns:** Same shape as `repo_search()`. - -### search_config_for() - -Find configuration files related to a query. Presets config file globs (yaml/json/toml/etc). - -**Parameters:** Same as `search_tests_for()`. - -**Returns:** Same shape as `repo_search()`. - -### search_callers_for() - -Heuristic search for callers/usages of a symbol. - -**Parameters:** -- `query` (str, required): Symbol name to find callers for -- `limit` (int, optional): Max results -- `language` (str, optional): Filter by language - -**Returns:** Same shape as `repo_search()`. - -### search_importers_for() - -Find files likely importing or referencing a module/symbol. - -**Parameters:** Same as `search_callers_for()`. - -**Returns:** Same shape as `repo_search()`. 
- ### pattern_search() Find structurally similar code patterns across languages. Requires `PATTERN_VECTORS=1`. @@ -932,4 +767,4 @@ Both SSE and HTTP RMCP transports expose the **same tools, arguments, and respon When in doubt, prefer the HTTP `/mcp` endpoints described in the Overview. -This API reference should enable developers to effectively integrate Context Engine's MCP tools into their applications and workflows. \ No newline at end of file +This API reference should enable developers to effectively integrate Context Engine's MCP tools into their applications and workflows. diff --git a/docs/MULTI_REPO_COLLECTIONS.md b/docs/MULTI_REPO_COLLECTIONS.md index df11f623..eed1ba9b 100644 --- a/docs/MULTI_REPO_COLLECTIONS.md +++ b/docs/MULTI_REPO_COLLECTIONS.md @@ -341,11 +341,11 @@ results = client.search( ### 4. Monitor Collection Health ```bash -# Check collection status -make qdrant-status +# Check indexer health +curl http://localhost:${FASTMCP_INDEXER_HTTP_HEALTH_PORT:-18003}/readyz # List all collections -make qdrant-list +# Use the qdrant_list MCP tool from your MCP client. # Prune stale points make prune @@ -407,4 +407,3 @@ The architecture supports future enhancements: - [MCP API Reference](MCP_API.md) - [Architecture Overview](ARCHITECTURE.md) - [Development Guide](DEVELOPMENT.md) - diff --git a/docs/OBSERVABILITY.md b/docs/OBSERVABILITY.md index 985d2b93..e9502165 100644 --- a/docs/OBSERVABILITY.md +++ b/docs/OBSERVABILITY.md @@ -101,7 +101,7 @@ from qdrant_client import QdrantClient ### Qdrant client version -Use `qdrant-client>=1.15.0,<1.16.0`. Version 1.16+ changed to `.query_points()` which breaks OpenLit's instrumentation hooks. +Use `qdrant-client==1.15.1` with `qdrant/qdrant:v1.15.4`. Version 1.16+ removed the legacy `.search()` path used by OpenLit's Qdrant instrumentation hooks. 
## Disabling diff --git a/docs/TROUBLESHOOTING.md b/docs/TROUBLESHOOTING.md index b3ea5c43..763b161c 100644 --- a/docs/TROUBLESHOOTING.md +++ b/docs/TROUBLESHOOTING.md @@ -151,5 +151,4 @@ docker-compose restart 1. Check this troubleshooting guide 2. Review logs: `docker compose logs mcp_indexer` 3. Verify health: `make health` -4. Check Qdrant status: `make qdrant-status` - +4. Check indexer health: `curl http://localhost:${FASTMCP_INDEXER_HTTP_HEALTH_PORT:-18003}/readyz`; use the `qdrant_status` MCP tool for collection details diff --git a/pytest.ini b/pytest.ini index 64bf831b..6a89673e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,6 @@ [pytest] pythonpath = . -addopts = -ra -vv --color=yes --durations=10 +addopts = -ra -vv --color=yes --durations=10 -m "not integration" asyncio_mode = auto testpaths = tests markers = @@ -9,4 +9,3 @@ markers = unit: fast unit tests that do not require services filterwarnings = ignore:The @wait_container_is_ready decorator is deprecated and will be removed in a future version:DeprecationWarning:testcontainers.core.waiting_utils - diff --git a/requirements.txt b/requirements.txt index f999c6ad..98155ef7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ # Runtime dependencies (mirrors Dockerfiles) -# Pin to 1.15.x - version 1.16+ removed .search() method which breaks OpenLit instrumentation -qdrant-client>=1.15.0,<1.16.0 +# Pinned with Qdrant server v1.15.4. qdrant-client 1.16+ removed .search(), +# which breaks OpenLit's Qdrant instrumentation hooks. 
+qdrant-client==1.15.1 fastembed watchdog onnxruntime @@ -42,4 +43,3 @@ python-toon>=0.1.3 # # Benchmark-only (external suite) # coir-eval - diff --git a/scripts/admin_ui.py b/scripts/admin_ui.py index 32a27d55..131a9065 100644 --- a/scripts/admin_ui.py +++ b/scripts/admin_ui.py @@ -9,10 +9,7 @@ from starlette.templating import Jinja2Templates from jinja2 import select_autoescape -try: - from scripts.workspace_state import is_staging_enabled -except Exception: - is_staging_enabled = None # type: ignore +from scripts.workspace_state import is_staging_enabled _TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" _templates = Jinja2Templates(directory=str(_TEMPLATES_DIR)) @@ -25,8 +22,9 @@ def render_admin_login( status_code: int = 200, ) -> Any: return _templates.TemplateResponse( - "admin/login.html", - {"request": request, "title": "CTXCE Admin Login", "error": error}, + request=request, + name="admin/login.html", + context={"title": "CTXCE Admin Login", "error": error}, status_code=status_code, ) @@ -37,8 +35,9 @@ def render_admin_bootstrap( status_code: int = 200, ) -> Any: return _templates.TemplateResponse( - "admin/bootstrap.html", - {"request": request, "title": "CTXCE Admin Bootstrap", "error": error}, + request=request, + name="admin/bootstrap.html", + context={"title": "CTXCE Admin Bootstrap", "error": error}, status_code=status_code, ) @@ -54,16 +53,16 @@ def render_admin_acl( status_code: int = 200, ) -> Any: return _templates.TemplateResponse( - "admin/acl.html", - { - "request": request, + request=request, + name="admin/acl.html", + context={ "title": "CTXCE Admin ACL", "users": users, "collections": collections, "grants": grants, "deletion_enabled": bool(deletion_enabled), "work_dir": work_dir, - "staging_enabled": bool(is_staging_enabled() if callable(is_staging_enabled) else False), + "staging_enabled": bool(is_staging_enabled()), "refresh_ms": int(refresh_ms) if refresh_ms is not None else 5000, }, status_code=status_code, @@ -78,9 
+77,9 @@ def render_admin_error( status_code: int = 400, ) -> Any: return _templates.TemplateResponse( - "admin/error.html", - { - "request": request, + request=request, + name="admin/error.html", + context={ "title": title, "message": message, "back_href": back_href, diff --git a/scripts/benchmarks/__init__.py b/scripts/benchmarks/__init__.py index 4d601d5d..91224475 100644 --- a/scripts/benchmarks/__init__.py +++ b/scripts/benchmarks/__init__.py @@ -28,7 +28,6 @@ # Component benchmarks (import on demand) # - eval_harness # - trm_bench - # - router_bench # - refrag_bench # - expand_bench # - run_all diff --git a/scripts/benchmarks/auto_tuner.py b/scripts/benchmarks/auto_tuner.py index fe5f69f6..f6063e8f 100644 --- a/scripts/benchmarks/auto_tuner.py +++ b/scripts/benchmarks/auto_tuner.py @@ -16,14 +16,11 @@ import math import os import statistics -import sys import time from dataclasses import dataclass, field, asdict from datetime import datetime -from pathlib import Path from typing import Any, Dict, List, Optional, Tuple -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # --------------------------------------------------------------------------- # Statistical Utilities diff --git a/scripts/benchmarks/coir/indexer.py b/scripts/benchmarks/coir/indexer.py index 60a84903..113222d5 100644 --- a/scripts/benchmarks/coir/indexer.py +++ b/scripts/benchmarks/coir/indexer.py @@ -7,15 +7,8 @@ """ from __future__ import annotations -import sys -from pathlib import Path from typing import Any, Dict, List -# Ensure project root is in path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - -from qdrant_client import QdrantClient - from scripts.benchmarks.core_indexer import ( BenchmarkDoc, index_benchmark_corpus, @@ -172,4 +165,4 @@ def cleanup_coir_collections(task_names: List[str] | None = None) -> int: except Exception as e: print(f"Failed to list collections: {e}") - return deleted \ No newline at end of file + return deleted diff --git 
a/scripts/benchmarks/coir/retriever.py b/scripts/benchmarks/coir/retriever.py index 9d3bbfd2..dfae835c 100644 --- a/scripts/benchmarks/coir/retriever.py +++ b/scripts/benchmarks/coir/retriever.py @@ -25,9 +25,7 @@ import asyncio import os -import sys from concurrent.futures import ThreadPoolExecutor -from pathlib import Path from typing import Any, Dict, List, Optional import numpy as np @@ -35,9 +33,6 @@ # Shared utilities from scripts.benchmarks.qdrant_utils import probe_pseudo_tags, verify_config_compatibility, get_qdrant_client -# Ensure project root is in path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - # Read .env settings EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") RERANKER_ENABLED = os.environ.get("RERANKER_ENABLED", "true").lower() in ("true", "1", "yes") diff --git a/scripts/benchmarks/cosqa/indexer.py b/scripts/benchmarks/cosqa/indexer.py index 9117fbc9..698d9699 100644 --- a/scripts/benchmarks/cosqa/indexer.py +++ b/scripts/benchmarks/cosqa/indexer.py @@ -8,13 +8,8 @@ from __future__ import annotations import os -import sys -from pathlib import Path from typing import Any, Dict, List -# Ensure project root is in path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - from scripts.benchmarks.core_indexer import ( BenchmarkDoc, index_benchmark_corpus, diff --git a/scripts/benchmarks/cosqa/pca_init.py b/scripts/benchmarks/cosqa/pca_init.py index f66cc423..686ce57d 100644 --- a/scripts/benchmarks/cosqa/pca_init.py +++ b/scripts/benchmarks/cosqa/pca_init.py @@ -7,14 +7,10 @@ """ import os import sys -from pathlib import Path from typing import List, Dict, Any import numpy as np -# Ensure project root is in path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - def compute_pca_init_for_collection( collection: str, @@ -141,4 +137,3 @@ def compute_pca_init_for_collection( ) sys.exit(0 if success else 1) - diff --git 
a/scripts/benchmarks/cosqa/run_search_matrix.sh b/scripts/benchmarks/cosqa/run_search_matrix.sh new file mode 100755 index 00000000..d2eece9a --- /dev/null +++ b/scripts/benchmarks/cosqa/run_search_matrix.sh @@ -0,0 +1,315 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../.." && pwd)" +cd "${ROOT_DIR}" + +PYTHON_BIN="${PYTHON_BIN:-}" +if [ -z "${PYTHON_BIN}" ]; then + if command -v python3.11 >/dev/null 2>&1; then + PYTHON_BIN="python3.11" + elif command -v python3 >/dev/null 2>&1; then + PYTHON_BIN="python3" + elif command -v python >/dev/null 2>&1; then + PYTHON_BIN="python" + else + echo "No Python interpreter found (looked for python3.11/python3/python)." >&2 + exit 127 + fi +fi +RUN_TAG="${RUN_TAG:-$(date +%Y%m%d-%H%M%S)}" +PROFILE="${PROFILE:-full}" # smoke | quick | full +RUN_SET="${RUN_SET:-full}" # pr | knobs | nightly | full +OUT_DIR="${OUT_DIR:-bench_results/cosqa/${RUN_TAG}}" +LOG_DIR="${LOG_DIR:-${OUT_DIR}}" +SPLIT="${SPLIT:-test}" +COLLECTION="${COLLECTION:-cosqa-search-${RUN_TAG}}" +LIMIT="${LIMIT:-10}" +RECREATE_INDEX="${RECREATE_INDEX:-1}" +ENFORCE_HYBRID_GATE="${ENFORCE_HYBRID_GATE:-0}" +HYBRID_MIN_DELTA="${HYBRID_MIN_DELTA:--0.020}" + +case "${PROFILE}" in + smoke) + : "${CORPUS_LIMIT:=150}" + : "${QUERY_LIMIT:=30}" + ;; + quick) + : "${CORPUS_LIMIT:=500}" + : "${QUERY_LIMIT:=100}" + ;; + full) + : "${CORPUS_LIMIT:=0}" + : "${QUERY_LIMIT:=0}" + ;; + *) + echo "Unknown PROFILE='${PROFILE}'. 
Use smoke|quick|full" >&2 + exit 2 + ;; +esac + +mkdir -p "${OUT_DIR}" "${LOG_DIR}" + +BASE_ENV=( + "LOG_LEVEL=${LOG_LEVEL:-INFO}" + "DEBUG_HYBRID_SEARCH=${DEBUG_HYBRID_SEARCH:-0}" + "QDRANT_URL=${QDRANT_URL:-http://localhost:6333}" + "HYBRID_IN_PROCESS=${HYBRID_IN_PROCESS:-1}" + "RERANK_IN_PROCESS=${RERANK_IN_PROCESS:-1}" + "LEX_VECTOR_DIM=${LEX_VECTOR_DIM:-4096}" + "COSQA_QUERY_CONCURRENCY=${COSQA_QUERY_CONCURRENCY:-8}" + "LLM_EXPAND_MAX=0" + "REFRAG_DECODER=0" + "RERANK_LEARNING=0" + "RERANK_EVENTS_ENABLED=0" +) + +run_index_once() { + local log="${LOG_DIR}/cosqa_index.log" + local args=( + "-m" "scripts.benchmarks.cosqa.runner" + "--split" "${SPLIT}" + "--collection" "${COLLECTION}" + "--limit" "${LIMIT}" + "--index-only" + ) + + if [ "${CORPUS_LIMIT}" -gt 0 ]; then + args+=("--corpus-limit" "${CORPUS_LIMIT}") + fi + if [ "${QUERY_LIMIT}" -gt 0 ]; then + args+=("--query-limit" "${QUERY_LIMIT}") + fi + if [ "${RECREATE_INDEX}" = "1" ]; then + args+=("--recreate") + fi + + echo "[index] collection=${COLLECTION} corpus_limit=${CORPUS_LIMIT} query_limit=${QUERY_LIMIT}" | tee "${log}" + ( + export "${BASE_ENV[@]}" + "${PYTHON_BIN}" "${args[@]}" + ) >> "${log}" 2>&1 +} + +preflight_python_deps() { + "${PYTHON_BIN}" - <<'PY' +import importlib.util + +required = ["qdrant_client", "datasets"] +missing = [m for m in required if importlib.util.find_spec(m) is None] +if missing: + raise SystemExit( + "Missing Python deps for CoSQA benchmark: " + + ", ".join(missing) + + ". Install them before running." 
+ ) +PY +} + +verify_collection_ready() { + "${PYTHON_BIN}" - "${COLLECTION}" <<'PY' +import os +import sys +from qdrant_client import QdrantClient + +collection = sys.argv[1] +url = os.environ.get("QDRANT_URL", "http://localhost:6333") +client = QdrantClient(url=url, timeout=60) +info = client.get_collection(collection) +points = int(info.points_count or 0) +if points <= 0: + raise RuntimeError(f"Collection '{collection}' has no points after indexing") +print(f"[verify] collection={collection} points={points}") +PY +} + +run_case() { + local label="$1" + local mode="$2" + local rerank="$3" + local expand="$4" + local lex_mode="$5" + shift 5 + + local output="${OUT_DIR}/cosqa_${label}.json" + local log="${LOG_DIR}/cosqa_${label}.log" + + local args=( + "-m" "scripts.benchmarks.cosqa.runner" + "--split" "${SPLIT}" + "--collection" "${COLLECTION}" + "--limit" "${LIMIT}" + "--skip-index" + "--mode" "${mode}" + "--output" "${output}" + ) + + if [ "${CORPUS_LIMIT}" -gt 0 ]; then + args+=("--corpus-limit" "${CORPUS_LIMIT}") + fi + if [ "${QUERY_LIMIT}" -gt 0 ]; then + args+=("--query-limit" "${QUERY_LIMIT}") + fi + if [ "${rerank}" = "0" ]; then + args+=("--no-rerank") + fi + if [ "${expand}" = "0" ]; then + args+=("--no-expand") + fi + + local case_env=("HYBRID_LEXICAL_TEXT_MODE=${lex_mode}") + for kv in "$@"; do + case_env+=("${kv}") + done + + echo "[run] ${label} mode=${mode} rerank=${rerank} expand=${expand} lex_mode=${lex_mode}" | tee "${log}" + ( + export "${BASE_ENV[@]}" + export "${case_env[@]}" + "${PYTHON_BIN}" "${args[@]}" + ) >> "${log}" 2>&1 + + echo "[ok] ${output}" +} + +CASES=() +case "${RUN_SET}" in + pr) + CASES=( + "dense_norerank|dense|0|0|raw" + "hybrid_rerank_lexrrf|hybrid|1|0|rrf" + "hybrid_rerank_expand_lexrrf|hybrid|1|1|rrf" + ) + ;; + knobs) + CASES=( + "dense_norerank|dense|0|0|raw" + "dense_rerank|dense|1|0|raw" + "hybrid_norerank_lexraw|hybrid|0|0|raw" + "hybrid_norerank_lexrrf|hybrid|0|0|rrf" + "hybrid_rerank_lexraw|hybrid|1|0|raw" + 
"hybrid_rerank_lexrrf|hybrid|1|0|rrf" + "hybrid_rerank_expand_lexrrf|hybrid|1|1|rrf" + "lexical_norerank|lexical|0|0|raw" + ) + ;; + nightly) + CASES=( + "dense_norerank|dense|0|0|raw" + "dense_rerank|dense|1|0|raw" + "hybrid_norerank_lexraw|hybrid|0|0|raw" + "hybrid_norerank_lexrrf|hybrid|0|0|rrf" + "hybrid_rerank_lexraw|hybrid|1|0|raw" + "hybrid_rerank_lexrrf|hybrid|1|0|rrf" + "hybrid_rerank_expand_lexrrf|hybrid|1|1|rrf" + "lexical_norerank|lexical|0|0|raw" + ) + ;; + full) + CASES=( + "dense_norerank|dense|0|0|raw" + "dense_rerank|dense|1|0|raw" + "hybrid_norerank_lexraw|hybrid|0|0|raw" + "hybrid_norerank_lexrrf|hybrid|0|0|rrf" + "hybrid_rerank_lexraw|hybrid|1|0|raw" + "hybrid_rerank_lexrrf|hybrid|1|0|rrf" + "hybrid_rerank_expand_lexrrf|hybrid|1|1|rrf" + "lexical_norerank|lexical|0|0|raw" + ) + ;; + *) + echo "Unknown RUN_SET='${RUN_SET}'. Use pr|knobs|nightly|full" >&2 + exit 2 + ;; +esac + +echo "[config] run_tag=${RUN_TAG} profile=${PROFILE} run_set=${RUN_SET} out_dir=${OUT_DIR}" +preflight_python_deps +run_index_once +verify_collection_ready + +for spec in "${CASES[@]}"; do + IFS='|' read -r label mode rerank expand lex_mode <<< "${spec}" + run_case "${label}" "${mode}" "${rerank}" "${expand}" "${lex_mode}" +done + +"${PYTHON_BIN}" - "${OUT_DIR}" "${ENFORCE_HYBRID_GATE}" "${HYBRID_MIN_DELTA}" <<'PY' +import json +import sys +from pathlib import Path + +out_dir = Path(sys.argv[1]) +enforce_gate = str(sys.argv[2]).strip() in {"1", "true", "yes"} +min_delta = float(sys.argv[3]) + +rows = [] +for path in sorted(out_dir.glob("cosqa_*.json")): + if path.name.startswith("cosqa_index") or path.name.endswith("_meta.json") or path.name.startswith("summary"): + continue + with path.open("r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict) or "metrics" not in data or "config" not in data: + continue + metrics = data.get("metrics") or {} + config = data.get("config") or {} + env = (config.get("env") or {}) if isinstance(config, dict) else {} 
+ rows.append({ + "label": path.stem.replace("cosqa_", ""), + "mode": config.get("mode", ""), + "rerank": bool(config.get("rerank_enabled", False)), + "expand": env.get("HYBRID_EXPAND", ""), + "lex_mode": env.get("HYBRID_LEXICAL_TEXT_MODE", ""), + "mrr": float(metrics.get("mrr", 0.0) or 0.0), + "recall_10": float(metrics.get("recall@10", 0.0) or 0.0), + "ndcg_10": float(metrics.get("ndcg@10", 0.0) or 0.0), + "lat_ms": float((data.get("latency") or {}).get("avg_ms", 0.0) or 0.0), + "file": path.name, + }) + +if not rows: + print("No CoSQA result JSON files found.", file=sys.stderr) + sys.exit(3) + +rows.sort(key=lambda r: (-r["mrr"], -r["recall_10"])) + +summary = { + "ranked": rows, + "best": rows[0], +} +(out_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8") + +lines = [ + "# CoSQA Search Matrix Summary", + "", + "| Rank | Label | Mode | Rerank | Expand | LexMode | MRR | R@10 | NDCG@10 | Avg Lat (ms) |", + "|---:|---|---|---:|---:|---|---:|---:|---:|---:|", +] +for i, r in enumerate(rows, start=1): + lines.append( + f"| {i} | {r['label']} | {r['mode']} | {int(r['rerank'])} | {r['expand']} | {r['lex_mode']} | " + f"{r['mrr']:.4f} | {r['recall_10']:.4f} | {r['ndcg_10']:.4f} | {r['lat_ms']:.2f} |" + ) + +best_dense = max((r for r in rows if r["mode"] == "dense"), key=lambda r: r["mrr"], default=None) +best_hybrid = max((r for r in rows if r["mode"] == "hybrid"), key=lambda r: r["mrr"], default=None) +if best_dense and best_hybrid: + delta = best_hybrid["mrr"] - best_dense["mrr"] + lines.append("") + lines.append( + f"Best hybrid ({best_hybrid['label']}) vs best dense ({best_dense['label']}): " + f"delta MRR = {delta:+.4f}" + ) + if enforce_gate and delta < min_delta: + lines.append( + f"Gate failed: hybrid delta {delta:+.4f} is below required minimum {min_delta:+.4f}" + ) + (out_dir / "summary.md").write_text("\n".join(lines) + "\n", encoding="utf-8") + print("\n".join(lines)) + sys.exit(4) + +(out_dir / 
"summary.md").write_text("\n".join(lines) + "\n", encoding="utf-8") +print("\n".join(lines)) +PY + +echo "[done] results=${OUT_DIR}" +echo "[done] summary=${OUT_DIR}/summary.md" diff --git a/scripts/benchmarks/cosqa/runner.py b/scripts/benchmarks/cosqa/runner.py index 8e0c32b1..770c0843 100644 --- a/scripts/benchmarks/cosqa/runner.py +++ b/scripts/benchmarks/cosqa/runner.py @@ -61,12 +61,6 @@ from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Tuple -from dotenv import load_dotenv - -# Load .env immediately to ensure all subsequent imports (like scripts.ingest.config) -# see the correct environment variables. -load_dotenv(override=True) - from scripts.benchmarks.qdrant_utils import ( get_qdrant_client, probe_pseudo_tags, @@ -77,9 +71,6 @@ os.environ["OPENLIT_ENABLED"] = "0" os.environ["OTEL_SDK_DISABLED"] = "true" -# Ensure project root is in path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - # NOTE: .env loading moved to _load_benchmark_env() to avoid polluting # environment when this module is imported (e.g., by tests or __init__.py). # Call _load_benchmark_env() explicitly before running benchmarks. @@ -442,7 +433,16 @@ def _cosqa_id_from_path(p: str) -> Optional[str]: name = s.rsplit("/", 1)[-1] if name.endswith(".py"): name = name[: -3] - return name.strip() or None + name = name.strip() + if not name: + return None + # CoSQA synthetic filenames are often "__". + # Recover canonical code_id so relevance matching aligns with qrels. + if "__" in name: + tail = name.rsplit("__", 1)[-1].strip() + if tail.startswith("cosqa-"): + return tail + return name # Extract stable code_ids for evaluation. 
# NOTE: rerank paths may not include payload; for CoSQA we can fall back to parsing @@ -915,8 +915,8 @@ async def run_full_benchmark( print(f" Limited corpus to {len(corpus)} entries") if skip_index: - print(" [skip-index] Skipping indexing...") - result = {"reused": True, "indexed": len(corpus), "skipped": 0, "errors": 0} + print(" [skip-index] Skipping indexing (using existing collection as-is)...") + result = {"reused": False, "indexed": 0, "skipped": len(corpus), "errors": 0} else: # Check if already indexed (use fingerprint matching, not just points_count) # The indexer handles fingerprint checking internally and will recreate if needed @@ -1018,6 +1018,12 @@ def main(): help="Search mode: 'hybrid' (default), 'dense' (pure semantic), or 'lexical' (pure BM25-style)") args = parser.parse_args() + # Benchmarks must not require MCP auth sessions. + # .env values (now loaded via _load_benchmark_env(), not at import time) may enable auth, so force these off + # after args parsing to guarantee process-local benchmark behavior. + os.environ["CTXCE_AUTH_ENABLED"] = "0" + os.environ["CTXCE_MCP_ACL_ENFORCE"] = "0" + # Enable Context-Engine features for accurate benchmarking. # Semantic expansion is always enabled (it may still be a no-op if query expansion is disabled).
os.environ["SEMANTIC_EXPANSION_ENABLED"] = "1" diff --git a/scripts/benchmarks/efficiency_benchmark.py b/scripts/benchmarks/efficiency_benchmark.py index af9dbf55..b6f08e54 100644 --- a/scripts/benchmarks/efficiency_benchmark.py +++ b/scripts/benchmarks/efficiency_benchmark.py @@ -10,17 +10,11 @@ import asyncio import json import os -import sys import time import hashlib from dataclasses import dataclass, field, asdict -from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple -# Add project root to path -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -# Shared stats helpers (after sys.path setup) from scripts.benchmarks.common import percentile, extract_result_paths, resolve_collection_auto # Ensure correct collection is used (read from workspace state or env) @@ -228,7 +222,7 @@ def compute_recall_at_k(expected_paths: List[str], result_paths: List[str], k: i "queries": [ {"tool": "repo_search", "query": "init_openlit error handling"}, {"tool": "symbol_graph", "symbol": "init_openlit", "query_type": "callers"}, - {"tool": "search_tests_for", "query": "openlit initialization"}, + {"tool": "repo_search", "query": "openlit initialization", "profile": "tests"}, ], "expected_paths": ["openlit_init.py", "test_openlit"], }, @@ -237,7 +231,7 @@ def compute_recall_at_k(expected_paths: List[str], result_paths: List[str], k: i "queries": [ {"tool": "repo_search", "query": "memory store implementation pattern"}, {"tool": "context_answer", "query": "How does memory_store work?"}, - {"tool": "search_config_for", "query": "memory collection settings"}, + {"tool": "repo_search", "query": "memory collection settings", "profile": "config"}, ], "expected_paths": ["mcp_impl/memory.py", "memory_store"], }, @@ -245,7 +239,7 @@ def compute_recall_at_k(expected_paths: List[str], result_paths: List[str], k: i "description": "Trace dependencies across multiple files", "queries": [ {"tool": "symbol_graph", "symbol": "get_embedding_model", 
"query_type": "callers"}, - {"tool": "search_importers_for", "query": "embedder"}, + {"tool": "repo_search", "query": "embedder", "profile": "code"}, {"tool": "repo_search", "query": "embedding dimension vector size"}, ], "expected_paths": ["embedder.py", "rerank_recursive"], @@ -460,18 +454,12 @@ async def run_benchmark( repo_search, context_answer, symbol_graph, - search_tests_for, - search_config_for, - search_importers_for, memory_find, ) tool_registry = { "repo_search": repo_search, "context_answer": context_answer, "symbol_graph": symbol_graph, - "search_tests_for": search_tests_for, - "search_config_for": search_config_for, - "search_importers_for": search_importers_for, "memory_find": memory_find, } except ImportError as e: diff --git a/scripts/benchmarks/eval_harness.py b/scripts/benchmarks/eval_harness.py index efb7611f..97691fea 100644 --- a/scripts/benchmarks/eval_harness.py +++ b/scripts/benchmarks/eval_harness.py @@ -10,14 +10,12 @@ import asyncio import json import os -import sys import time from dataclasses import dataclass, field, asdict from pathlib import Path from typing import Any, Dict, List, Optional import statistics -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # Shared stats helpers from scripts.benchmarks.common import ( diff --git a/scripts/benchmarks/expand_bench.py b/scripts/benchmarks/expand_bench.py index 86330572..82da2ae8 100644 --- a/scripts/benchmarks/expand_bench.py +++ b/scripts/benchmarks/expand_bench.py @@ -9,14 +9,11 @@ import asyncio import json import os -import sys import time from dataclasses import dataclass, field, asdict -from pathlib import Path from typing import Any, Dict, List, Optional import statistics -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # Load environment (optional) and fix Docker hostname try: diff --git a/scripts/benchmarks/grounding_scorer.py b/scripts/benchmarks/grounding_scorer.py index 669d85a8..03d8d690 100644 --- a/scripts/benchmarks/grounding_scorer.py +++ 
b/scripts/benchmarks/grounding_scorer.py @@ -13,15 +13,12 @@ import json import os import re -import sys from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple -# Add project root to path PROJECT_ROOT = Path(__file__).parent.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) # Load environment variables from .env try: diff --git a/scripts/benchmarks/recommendations.py b/scripts/benchmarks/recommendations.py index db1968eb..3a973847 100644 --- a/scripts/benchmarks/recommendations.py +++ b/scripts/benchmarks/recommendations.py @@ -208,19 +208,6 @@ "impacts": ["token_usage", "recall"], "category": "limits", }, - "INFO_REQUEST_LIMIT": { - "description": "Default limit for info_request tool", - "default": 10, "type": "int", "range": (3, 20), - "impacts": ["token_usage"], - "category": "limits", - }, - "INFO_REQUEST_CONTEXT_LINES": { - "description": "Context lines for info_request", - "default": 5, "type": "int", "range": (0, 15), - "impacts": ["context_density"], - "category": "limits", - }, - # === Lexical Search === "LEX_MULTI_HASH": { "description": "Multi-hash buckets per token (reduces collisions)", diff --git a/scripts/benchmarks/refrag_bench.py b/scripts/benchmarks/refrag_bench.py index 534f7f76..943bc6a4 100644 --- a/scripts/benchmarks/refrag_bench.py +++ b/scripts/benchmarks/refrag_bench.py @@ -9,14 +9,11 @@ import asyncio import json import os -import sys import time from dataclasses import dataclass, field, asdict -from pathlib import Path from typing import Any, Dict, List import statistics -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # Load environment (optional) and fix Docker hostname try: diff --git a/scripts/benchmarks/router_bench.py b/scripts/benchmarks/router_bench.py deleted file mode 100644 index 0625f600..00000000 --- a/scripts/benchmarks/router_bench.py +++ /dev/null @@ -1,236 +0,0 @@ -#!/usr/bin/env python3 -""" -Router Benchmark 
for Context-Engine - -Measures tool selection accuracy, routing latency, and decision quality. -""" - -import argparse -import asyncio -import json -import os -import sys -import time -from dataclasses import dataclass, field, asdict -from pathlib import Path -from typing import Any, Dict, List, Optional -import statistics - -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -# Load environment (optional) and fix Docker hostname -try: - from dotenv import load_dotenv # type: ignore - load_dotenv() -except Exception: - pass -if "qdrant:" in os.environ.get("QDRANT_URL", ""): - os.environ["QDRANT_URL"] = "http://localhost:6333" - -from scripts.benchmarks.common import ( - percentile, - create_report, - QueryResult as CommonQueryResult, - resolve_collection_auto, -) - -# Ensure correct collection is used (read from workspace state or env) -if not os.environ.get("COLLECTION_NAME"): - try: - from scripts.workspace_state import get_collection_name - os.environ["COLLECTION_NAME"] = get_collection_name() or "codebase" - except Exception: - os.environ["COLLECTION_NAME"] = "codebase" -else: - # If COLLECTION_NAME is set but empty/unindexed, pick a non-empty collection for benchmarks. 
- try: - os.environ["COLLECTION_NAME"] = resolve_collection_auto(os.environ.get("COLLECTION_NAME")) - except Exception: - pass - -print( - f"[bench] Using QDRANT_URL={os.environ.get('QDRANT_URL', '')} " - f"COLLECTION_NAME={os.environ.get('COLLECTION_NAME', '')}" -) - - -@dataclass -class RouterResult: - """Result from router evaluation.""" - query: str - expected_tool: str - selected_tool: str - confidence: float - latency_ms: float - correct: bool - - -@dataclass -class RouterReport: - """Router benchmark report.""" - name: str - total_queries: int - accuracy: float - avg_confidence: float - avg_latency_ms: float - p90_latency_ms: float - results: List[RouterResult] = field(default_factory=list) - - def to_dict(self) -> Dict[str, Any]: - base = { - "name": self.name, - "total_queries": self.total_queries, - "metrics": { - "accuracy": round(self.accuracy, 4), - "avg_confidence": round(self.avg_confidence, 4), - "avg_latency_ms": round(self.avg_latency_ms, 2), - "p90_latency_ms": round(self.p90_latency_ms, 2), - }, - "results": [asdict(r) for r in self.results], - } - # Also emit the unified BenchmarkReport shape for downstream tooling. 
- rep = create_report("router_bench", config={"name": self.name}) - for r in self.results: - rep.per_query.append( - CommonQueryResult( - query=r.query, - latency_ms=r.latency_ms, - metrics={ - "correct": 1.0 if r.correct else 0.0, - "confidence": float(r.confidence or 0.0), - }, - retrieved_paths=[], - metadata={ - "expected_tool": r.expected_tool, - "selected_tool": r.selected_tool, - }, - ) - ) - rep.compute_aggregates() - base["unified"] = rep.to_dict() - return base - - -ROUTER_TEST_CASES = [ - {"query": "find files that import embedder module", "expected": "search_importers_for"}, - {"query": "who calls the init_openlit function", "expected": "symbol_graph"}, - {"query": "explain how the hybrid search works", "expected": "context_answer"}, - {"query": "search for memory store implementation", "expected": "repo_search"}, - {"query": "find tests for the reranker", "expected": "search_tests_for"}, - {"query": "what configs exist for qdrant", "expected": "search_config_for"}, - {"query": "store a note about this finding", "expected": "memory_store"}, - {"query": "recall notes about authentication", "expected": "memory_find"}, -] - -TOOL_ALIASES: Dict[str, str] = { - # Search - "code_search": "repo_search", - "repo_search_compat": "repo_search", - # Answer - "context_answer_compat": "context_answer", - # Memory - "find": "memory_find", - "store": "memory_store", -} - - -def _canonical_tool_name(name: Any) -> str: - if not name: - return "unknown" - s = str(name).strip() - if not s: - return "unknown" - return TOOL_ALIASES.get(s, s) - - -async def run_router_benchmark(name: str = "default") -> RouterReport: - """Run router benchmark.""" - try: - from scripts.mcp_router import route_query - except ImportError: - try: - from scripts.mcp_router.router import route_query - except ImportError as e: - print(f"Import error: {e}") - return RouterReport(name=name, total_queries=0, accuracy=0, - avg_confidence=0, avg_latency_ms=0, p90_latency_ms=0) - - results: 
List[RouterResult] = [] - latencies: List[float] = [] - - for case in ROUTER_TEST_CASES: - query = case["query"] - expected = case["expected"] - - start = time.perf_counter() - try: - # Try to route the query - route_result = await route_query(query) - if isinstance(route_result, dict): - selected = route_result.get("tool", "unknown") - confidence = route_result.get("confidence", 0.0) - else: - selected = str(route_result) - confidence = 0.5 - except Exception as e: - selected = "error" - confidence = 0.0 - elapsed_ms = (time.perf_counter() - start) * 1000 - latencies.append(elapsed_ms) - - selected_c = _canonical_tool_name(selected) - expected_c = _canonical_tool_name(expected) - correct = selected_c == expected_c - results.append(RouterResult( - query=query, - expected_tool=expected, - selected_tool=selected_c, - confidence=confidence, - latency_ms=elapsed_ms, - correct=correct, - )) - status = "✓" if correct else "✗" - print(f" {status} {query[:35]:35} → {selected_c} (exp: {expected_c})") - - correct_count = sum(1 for r in results if r.correct) - - return RouterReport( - name=name, - total_queries=len(results), - accuracy=correct_count / len(results) if results else 0, - avg_confidence=statistics.mean(r.confidence for r in results) if results else 0, - avg_latency_ms=statistics.mean(latencies) if latencies else 0, - p90_latency_ms=percentile(latencies, 0.90), - results=results, - ) - - -def print_report(report: RouterReport): - print("\n" + "=" * 60) - print(f"ROUTER BENCHMARK: {report.name}") - print("=" * 60) - print(f"Queries: {report.total_queries}") - print(f"Accuracy: {report.accuracy:.1%}") - print(f"Avg Confidence: {report.avg_confidence:.4f}") - print(f"Avg Latency: {report.avg_latency_ms:.2f}ms") - print(f"P90 Latency: {report.p90_latency_ms:.2f}ms") - print("=" * 60) - - -def main(): - parser = argparse.ArgumentParser(description="Router Benchmark") - parser.add_argument("--name", default="default", help="Benchmark name") - 
parser.add_argument("--output", type=str, help="Output JSON file") - args = parser.parse_args() - - print(f"Running router benchmark: {args.name}") - report = asyncio.run(run_router_benchmark(name=args.name)) - print_report(report) - - if args.output: - with open(args.output, "w") as f: - json.dump(report.to_dict(), f, indent=2) - - -if __name__ == "__main__": - main() diff --git a/scripts/benchmarks/rrf_quality.py b/scripts/benchmarks/rrf_quality.py index f962392b..42ab7444 100644 --- a/scripts/benchmarks/rrf_quality.py +++ b/scripts/benchmarks/rrf_quality.py @@ -20,9 +20,7 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional -# Add project root to path PROJECT_ROOT = Path(__file__).parent.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) # Load environment variables from .env try: @@ -161,7 +159,7 @@ def _match_expected_file(path: str, expected_files: Iterable[str]) -> Optional[s }, { "query": "MCP tool registration fastmcp", - "expected_files": ["mcp_indexer_server.py", "mcp_router/__init__.py"], + "expected_files": ["mcp_indexer_server.py", "mcp_impl/search.py"], }, { "query": "memory store find operations", diff --git a/scripts/benchmarks/run_all.py b/scripts/benchmarks/run_all.py index 7921c477..d1fb31f9 100644 --- a/scripts/benchmarks/run_all.py +++ b/scripts/benchmarks/run_all.py @@ -14,7 +14,6 @@ from pathlib import Path from typing import Any, Dict, List -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # Import unified metadata utilities from scripts.benchmarks.common import BenchmarkMetadata @@ -102,15 +101,6 @@ async def run_all_benchmarks(components: List[str]) -> Dict[str, Any]: except Exception as e: print(f" Expansion benchmark failed: {e}") - if "router" in components or "all" in components: - try: - from scripts.benchmarks.router_bench import run_router_benchmark - print("\n▶ Running Router Benchmark...") - report = await run_router_benchmark(name="router") - results["components"]["router"] = 
report.to_dict() - except Exception as e: - print(f" Router benchmark failed: {e}") - if "rrf" in components or "all" in components: try: from scripts.benchmarks.rrf_quality import run_rrf_benchmark diff --git a/scripts/benchmarks/swe/runner.py b/scripts/benchmarks/swe/runner.py index 6edd39ee..ad9fcf49 100644 --- a/scripts/benchmarks/swe/runner.py +++ b/scripts/benchmarks/swe/runner.py @@ -109,7 +109,6 @@ from __future__ import annotations import os -import sys from pathlib import Path # --------------------------------------------------------------------------- @@ -135,9 +134,6 @@ # Silence tokenizers parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" -# Ensure project root is in path -sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) - import argparse import asyncio import json diff --git a/scripts/benchmarks/test_env_snapshot.py b/scripts/benchmarks/test_env_snapshot.py index 089dfe9d..a0566859 100644 --- a/scripts/benchmarks/test_env_snapshot.py +++ b/scripts/benchmarks/test_env_snapshot.py @@ -13,9 +13,7 @@ import tempfile from pathlib import Path -# Ensure project root is in path ROOT = Path(__file__).resolve().parent.parent.parent -sys.path.insert(0, str(ROOT)) from scripts.benchmarks.common import ( get_env_snapshot, diff --git a/scripts/benchmarks/trace_optimizer.py b/scripts/benchmarks/trace_optimizer.py index 075eb3e9..acb42f89 100644 --- a/scripts/benchmarks/trace_optimizer.py +++ b/scripts/benchmarks/trace_optimizer.py @@ -11,13 +11,10 @@ import asyncio import json import os -import sys from dataclasses import dataclass, field -from datetime import datetime, timedelta -from pathlib import Path +from datetime import datetime from typing import Any, Dict, List, Optional, Tuple -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) # --------------------------------------------------------------------------- # Time Window Configuration diff --git a/scripts/benchmarks/trm_bench.py b/scripts/benchmarks/trm_bench.py 
index 3511e582..23352644 100644 --- a/scripts/benchmarks/trm_bench.py +++ b/scripts/benchmarks/trm_bench.py @@ -9,14 +9,11 @@ import asyncio import json import os -import sys import time from dataclasses import dataclass, field, asdict -from pathlib import Path from typing import Any, Dict, List, Optional import statistics -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from scripts.benchmarks.common import percentile, extract_result_paths, resolve_collection_auto @@ -123,7 +120,7 @@ async def run_trm_benchmark(name: str = "default") -> TRMReport: """Run TRM/reranker benchmark.""" try: from scripts.embedder import get_embedding_model - from scripts.rerank_recursive import rerank_with_learning + from scripts.rerank_recursive.recursive import rerank_with_learning from scripts.mcp_indexer_server import repo_search except ImportError as e: print(f"Import error: {e}") diff --git a/scripts/benchmarks/validation_loop.py b/scripts/benchmarks/validation_loop.py index b491bb2c..3a91af72 100644 --- a/scripts/benchmarks/validation_loop.py +++ b/scripts/benchmarks/validation_loop.py @@ -9,15 +9,12 @@ import asyncio import json import os -import sys from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any, Dict, List -# Add project root to path PROJECT_ROOT = Path(__file__).parent.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) # Load environment variables from .env try: diff --git a/scripts/codex_phase3_probe.py b/scripts/codex_phase3_probe.py new file mode 100644 index 00000000..628b09d5 --- /dev/null +++ b/scripts/codex_phase3_probe.py @@ -0,0 +1,2 @@ +MARK = 'v3' +# codex phase3 probe v3 diff --git a/scripts/collection_admin.py b/scripts/collection_admin.py index d970e941..58bde79f 100644 --- a/scripts/collection_admin.py +++ b/scripts/collection_admin.py @@ -1,30 +1,22 @@ +import logging import os import json import re import shutil import time -from pathlib import Path from datetime import datetime 
+from pathlib import Path from typing import Any, Dict, Optional, List -from scripts.auth_backend import mark_collection_deleted +logger = logging.getLogger(__name__) -try: - from qdrant_client import QdrantClient - from qdrant_client import models as qmodels -except Exception: - QdrantClient = None # type: ignore - qmodels = None # type: ignore +from scripts.auth_backend import mark_collection_deleted -try: - from scripts.qdrant_client_manager import pooled_qdrant_client -except Exception: - pooled_qdrant_client = None +from qdrant_client import QdrantClient +from qdrant_client import models as qmodels -try: - from scripts.workspace_state import get_collection_mappings -except Exception: - get_collection_mappings = None +from scripts.qdrant_client_manager import pooled_qdrant_client +from scripts.workspace_state import get_collection_mappings _SLUGGED_REPO_RE = re.compile(r"^.+-[0-9a-f]{16}(?:_old)?$") @@ -97,7 +89,6 @@ def _managed_upload_marker_path( slug_name: str, marker_root: Optional[Path] = None, ) -> Path: - # Marker is stored with per-repo metadata, not inside the repo workspace tree. 
base = marker_root or work_root return base / ".codebase" / "repos" / slug_name / _MARKER_NAME @@ -115,11 +106,12 @@ def _is_managed_upload_workspace_dir( return False if not _SLUGGED_REPO_RE.match(p.name or ""): return False - return _managed_upload_marker_path( + marker = _managed_upload_marker_path( work_root=work_root, marker_root=marker_root, slug_name=p.name, - ).exists() + ) + return marker.exists() except Exception: return False @@ -193,6 +185,7 @@ def delete_collection_everywhere( out: Dict[str, Any] = { "collection": name, "qdrant_deleted": False, + "qdrant_graph_deleted": False, "registry_marked_deleted": False, "deleted_state_files": 0, "deleted_managed_workspaces": 0, @@ -209,6 +202,14 @@ def delete_collection_everywhere( out["qdrant_deleted"] = True except Exception: out["qdrant_deleted"] = False + # Best-effort: also delete companion graph edges collection when present. + # This branch stores file-level edges in `_graph`. + if not name.endswith("_graph"): + try: + cli.delete_collection(collection_name=f"{name}_graph") + out["qdrant_graph_deleted"] = True + except Exception: + out["qdrant_graph_deleted"] = False except Exception: out["qdrant_deleted"] = False @@ -226,7 +227,7 @@ def delete_collection_everywhere( mappings = [] try: if get_collection_mappings is not None: - mappings = get_collection_mappings(search_root=str(codebase_root)) or [] + mappings = get_collection_mappings(search_root=str(work_root)) or [] except Exception: mappings = [] @@ -333,8 +334,6 @@ def _copy_client_timeout_seconds() -> Optional[float]: copied = False def _manual_copy_points() -> None: - if QdrantClient is None or qmodels is None: - raise RuntimeError("QdrantClient unavailable for manual collection copy") cli = QdrantClient(url=base_url, api_key=api_key or None, timeout=_copy_client_timeout_seconds()) try: if overwrite: @@ -359,8 +358,10 @@ def _manual_copy_points() -> None: vectors_config = None sparse_vectors_config = None + # Support vector-less collections (e.g. 
payload-only graph edge collections). if vectors_config is None: - raise RuntimeError(f"Cannot determine vectors config for source collection {src}") + vectors_config = {} + vectorless = isinstance(vectors_config, dict) and not vectors_config try: cli.create_collection( @@ -401,7 +402,7 @@ def _manual_copy_points() -> None: limit=batch_limit, offset=offset, with_payload=True, - with_vectors=True, + with_vectors=(not vectorless), ) except Exception as exc: raise RuntimeError(f"Failed to scroll points from {src}: {exc}") from exc @@ -414,7 +415,9 @@ def _manual_copy_points() -> None: point_id = getattr(record, "id", None) payload = getattr(record, "payload", None) vector = None - if hasattr(record, "vector") and getattr(record, "vector") is not None: + if vectorless: + vector = {} + elif hasattr(record, "vector") and getattr(record, "vector") is not None: vector = getattr(record, "vector") elif hasattr(record, "vectors") and getattr(record, "vectors") is not None: vector = getattr(record, "vectors") @@ -437,8 +440,6 @@ def _manual_copy_points() -> None: pass def _count_points(name: str) -> Optional[int]: - if QdrantClient is None: - return None cli = QdrantClient(url=base_url, api_key=api_key or None, timeout=_copy_client_timeout_seconds()) try: res = cli.count(collection_name=name, exact=True) @@ -477,4 +478,23 @@ def _count_points(name: str) -> Optional[int]: # The manual path guarantees the destination gets the exact same points/payloads/vectors. _manual_copy_points() + # Best-effort: copy the companion graph collection when copying a base collection. + # Graph edges are derived data and can be rebuilt, but copying avoids a cold-start window + # during staging cutovers where the clone has no graph. 
+ if not src.endswith("_graph") and not dest.endswith("_graph"): + try: + copy_collection_qdrant( + source=f"{src}_graph", + target=f"{dest}_graph", + qdrant_url=base_url, + overwrite=overwrite, + ) + except Exception as exc: + logger.debug( + "Best-effort graph collection copy %s_graph -> %s_graph failed: %s", + src, + dest, + exc, + ) + return dest diff --git a/scripts/collection_health.py b/scripts/collection_health.py index 53ee089b..a4c72f3d 100644 --- a/scripts/collection_health.py +++ b/scripts/collection_health.py @@ -6,16 +6,9 @@ and triggers corrective actions (cache clear + reindex). """ import os -import sys -from pathlib import Path from typing import Optional, Dict, Any import logging -# Ensure project root is on sys.path -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - from scripts.workspace_state import ( _read_cache, _write_cache, @@ -432,4 +425,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/scripts/create_indexes.py b/scripts/create_indexes.py index 8f310aeb..9421b154 100644 --- a/scripts/create_indexes.py +++ b/scripts/create_indexes.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import os -import sys from pathlib import Path from qdrant_client import QdrantClient, models @@ -8,21 +7,8 @@ QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333") from datetime import datetime ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) -# Import critical functions first -try: - from scripts.workspace_state import get_collection_name, is_multi_repo_mode -except Exception: - get_collection_name = None # type: ignore - is_multi_repo_mode = None # type: ignore - -# Import other optional functions -try: - from scripts.workspace_state import log_activity -except Exception: - log_activity = None # type: ignore +from scripts.workspace_state import get_collection_name, is_multi_repo_mode, log_activity 
COLLECTION = os.environ.get("COLLECTION_NAME", "codebase") # Discover workspace path for state updates (allows subdir indexing) diff --git a/scripts/ctx.py b/scripts/ctx.py index 2c570a28..ee77f576 100755 --- a/scripts/ctx.py +++ b/scripts/ctx.py @@ -95,10 +95,7 @@ def _load_env_file(): _load_env_file() -try: - from scripts.mcp_router import call_tool_http # type: ignore -except ModuleNotFoundError: # pragma: no cover - local execution fallback - from mcp_router import call_tool_http # type: ignore +from scripts.mcp_http_client import call_tool_http # Configuration from environment MCP_URL = os.environ.get("MCP_INDEXER_URL", "http://localhost:8003/mcp") @@ -250,6 +247,12 @@ def parse_mcp_response(result: Dict[str, Any]) -> Optional[Dict[str, Any]]: # FastMCP typically wraps results in a content array res = result.get("result", {}) + structured = res.get("structuredContent") if isinstance(res, dict) else None + if isinstance(structured, dict): + structured_result = structured.get("result") + if isinstance(structured_result, dict): + return structured_result + content = res.get("content", []) # Some servers may return a dict directly (no content array) @@ -271,7 +274,10 @@ def parse_mcp_response(result: Dict[str, Any]) -> Optional[Dict[str, Any]]: return None try: - return json.loads(text) + parsed = json.loads(text) + if isinstance(parsed, dict) and isinstance(parsed.get("result"), dict): + return parsed["result"] + return parsed except json.JSONDecodeError: return {"raw": text} @@ -736,19 +742,14 @@ def _generate_plan(enhanced_prompt: str, context: str, note: str) -> str: from refrag_glm import GLMRefragClient # type: ignore client = GLMRefragClient() - response = client.client.chat.completions.create( - model=os.environ.get("GLM_MODEL", "glm-4.6"), - messages=[ - {"role": "system", "content": system_msg}, - {"role": "user", "content": user_msg}, - ], + plan = client.generate_with_soft_embeddings( + f"{system_msg}\n\n{user_msg}", max_tokens=200, + 
model=os.environ.get("GLM_MODEL", "glm-4.6"), temperature=0.3, stream=False, - ) - plan = ( - (response.choices[0].message.content if response and response.choices else "") - or "" + no_thinking=os.environ.get("CTX_GLM_DISABLE_THINKING", "1").strip().lower() + not in {"0", "false", "no", "off"}, ).strip() if not plan: # Fall through to llama.cpp path @@ -1030,6 +1031,7 @@ def fetch_context(query: str, **filters) -> Tuple[str, str]: params = { "query": query, "limit": filters.get("limit", DEFAULT_LIMIT), + "per_path": filters.get("per_path", DEFAULT_PER_PATH), "include_snippet": with_snippets, "context_lines": filters.get("context_lines", DEFAULT_CONTEXT_LINES), "collection": collection_name, @@ -1244,33 +1246,16 @@ def rewrite_prompt(original_prompt: str, context: str, note: str, max_tokens: Op "For questions: expand into related conceptual questions. For commands/instructions: provide general guidance about the task. " ) - # GLM API call - response = client.client.chat.completions.create( - model=os.environ.get("GLM_MODEL", "glm-4.6"), - messages=[ - {"role": "system", "content": system_msg}, - {"role": "user", "content": user_msg} - ], + enhanced = client.generate_with_soft_embeddings( + f"{system_msg}\n\n{user_msg}", max_tokens=int(max_tokens or DEFAULT_REWRITE_TOKENS), + model=os.environ.get("GLM_MODEL", "glm-4.6"), temperature=0.45, - stream=stream + stream=stream, + no_thinking=os.environ.get("CTX_GLM_DISABLE_THINKING", "1").strip().lower() + not in {"0", "false", "no", "off"}, ) - enhanced = "" - if stream: - # Streaming mode for GLM - for chunk in response: - if chunk.choices[0].delta.content: - token = chunk.choices[0].delta.content - sys.stdout.write(token) - sys.stdout.flush() - enhanced += token - sys.stdout.write("\n") - sys.stdout.flush() - else: - # Non-streaming mode for GLM - enhanced = response.choices[0].message.content - else: # Use local decoder (llama.cpp by default; Ollama supported when DECODER_URL points to /api/chat) meta_prompt = ( @@ 
-1586,6 +1571,8 @@ def main(): else: rewritten = rewrite_prompt(args.query, context_text, context_note, max_tokens=args.rewrite_max_tokens) output = sanitize_citations(rewritten.strip(), allowed_paths) + if args.with_context and context_text.strip(): + output = output.rstrip() + "\n\n---\nSupporting context:\n" + context_text.strip() if args.cmd: subprocess.run(args.cmd, input=output.encode("utf-8"), shell=True, check=False) diff --git a/scripts/health_check.py b/scripts/health_check.py index 32b9167d..3e2bf057 100644 --- a/scripts/health_check.py +++ b/scripts/health_check.py @@ -1,25 +1,11 @@ #!/usr/bin/env python3 import os import sys -from pathlib import Path from typing import Dict, Any from qdrant_client import QdrantClient, models -# Ensure /work (repo root) is on sys.path when run from /work/scripts -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -# Use embedder factory for Qwen3 support; fallback to direct fastembed -try: - from scripts.embedder import get_embedding_model, get_model_dimension - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False - from fastembed import TextEmbedding - - +from scripts.embedder import get_embedding_model, get_model_dimension from scripts.utils import sanitize_vector_name from scripts.auth_backend import ensure_collections, AuthDisabledError @@ -42,18 +28,14 @@ def assert_true(cond: bool, msg: str, *, critical: bool = False, failures: list[ def main(): qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333") api_key = os.environ.get("QDRANT_API_KEY") - collection = os.environ.get("COLLECTION_NAME", "codebase") + collection = (os.environ.get("COLLECTION_NAME") or "codebase").strip() model_name = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") print(f"Health check -> {qdrant_url} collection={collection} model={model_name}") # Init embedding to derive dimension and test embedding - if _EMBEDDER_FACTORY: - model 
= get_embedding_model(model_name) - dim = get_model_dimension(model_name) - else: - model = TextEmbedding(model_name=model_name) - dim = len(next(model.embed(["health dim probe"]))) + model = get_embedding_model(model_name) + dim = get_model_dimension(model_name) vec_name_expect = sanitize_vector_name(model_name) client = QdrantClient(url=qdrant_url, api_key=api_key or None) @@ -82,7 +64,6 @@ def main(): print("No collections found - nothing to health check") return - # Check each collection for collection_name in collections: print(f"Checking collection: {collection_name}") @@ -92,10 +73,20 @@ def main(): if isinstance(cfg, dict): present_names = list(cfg.keys()) assert_true(len(present_names) >= 1, "Collection has at least one named vector") + has_expected_vector = vec_name_expect in present_names assert_true( - vec_name_expect in present_names, + has_expected_vector, f"Expected vector name present: {vec_name_expect} in {present_names}", ) + if not has_expected_vector: + failures.append( + f"Collection {collection_name} is missing expected vector {vec_name_expect}" + ) + print( + f"[WARN] Skipping vector query for {collection_name}; " + f"expected vector {vec_name_expect!r} not present" + ) + continue got_dim = cfg[vec_name_expect].size else: present_names = [""] @@ -172,4 +163,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/hybrid/__init__.py b/scripts/hybrid/__init__.py index f893feb1..c37c2729 100644 --- a/scripts/hybrid/__init__.py +++ b/scripts/hybrid/__init__.py @@ -5,11 +5,4 @@ from scripts.hybrid import config, qdrant, embed, filters, expand, ranking from scripts.hybrid.config import QDRANT_URL """ -from scripts.hybrid import config -from scripts.hybrid import qdrant -from scripts.hybrid import embed -from scripts.hybrid import filters -from scripts.hybrid import expand -from scripts.hybrid import ranking - __all__ = ["config", "qdrant", "embed", "filters", "expand", "ranking"] diff --git 
a/scripts/hybrid/config.py b/scripts/hybrid/config.py index a1c8854a..0edd00f8 100644 --- a/scripts/hybrid/config.py +++ b/scripts/hybrid/config.py @@ -152,11 +152,9 @@ def _get_micro_defaults() -> tuple[int, int, int, int]: """ micro_enabled = os.environ.get("INDEX_MICRO_CHUNKS", "1").strip().lower() in {"1", "true", "yes", "on"} - try: - from scripts.refrag_glm import detect_glm_runtime - is_glm = detect_glm_runtime() - except ImportError: - is_glm = False + from scripts.refrag_glm import detect_glm_runtime + + is_glm = detect_glm_runtime() if is_glm: if micro_enabled: diff --git a/scripts/hybrid/embed.py b/scripts/hybrid/embed.py index 23b7ee82..e24d49a3 100644 --- a/scripts/hybrid/embed.py +++ b/scripts/hybrid/embed.py @@ -25,24 +25,14 @@ from pathlib import Path from typing import Any, List, Optional, TYPE_CHECKING -# --------------------------------------------------------------------------- -# Embedder factory setup -# --------------------------------------------------------------------------- -try: - from scripts.embedder import get_embedding_model as _get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False - _get_embedding_model = None # type: ignore - -# Always try to import TextEmbedding for backward compatibility with tests -try: - from fastembed import TextEmbedding -except ImportError: - TextEmbedding = None # type: ignore +from scripts.embedder import get_embedding_model as _get_embedding_model + +_EMBEDDER_FACTORY = True + +TextEmbedding = None # Type alias for embedding model (TextEmbedding or compatible) -EmbeddingModel = Any if TextEmbedding is None else TextEmbedding +EmbeddingModel = Any # --------------------------------------------------------------------------- # Configuration constants @@ -52,12 +42,9 @@ # --------------------------------------------------------------------------- # Unified cache system # --------------------------------------------------------------------------- -try: - from 
scripts.cache_manager import get_embedding_cache - UNIFIED_CACHE_AVAILABLE = True -except ImportError: - UNIFIED_CACHE_AVAILABLE = False - get_embedding_cache = None # type: ignore +from scripts.cache_manager import get_embedding_cache + +UNIFIED_CACHE_AVAILABLE = True # Legacy cache fallback structures _EMBED_QUERY_CACHE: OrderedDict[tuple[str, str], List[float]] = OrderedDict() @@ -109,14 +96,12 @@ def get_embedding_model(model_name: Optional[str] = None) -> EmbeddingModel: if _EMBEDDER_FACTORY and _get_embedding_model is not None: return _get_embedding_model(model_name) - if TextEmbedding is None: - raise ImportError( - "No embedding backend available. Install fastembed or ensure " - "scripts.embedder is importable." - ) - name = model_name or MODEL_NAME - return TextEmbedding(model_name=name) + text_embedding_cls = TextEmbedding + if text_embedding_cls is None: + from fastembed import TextEmbedding as text_embedding_cls + + return text_embedding_cls(model_name=name) # --------------------------------------------------------------------------- @@ -169,11 +154,9 @@ def embed_queries_cached( name = os.environ.get("EMBEDDING_MODEL", MODEL_NAME) # Apply Qwen3 instruction prefix if enabled (queries only, not documents) - try: - from scripts.embedder import prefix_queries - sanitized = prefix_queries(sanitized, name) - except ImportError: - pass + from scripts.embedder import prefix_queries + + sanitized = prefix_queries(sanitized, name) cache = _get_embed_cache() diff --git a/scripts/hybrid/expand.py b/scripts/hybrid/expand.py index 20a43834..b787db2b 100644 --- a/scripts/hybrid/expand.py +++ b/scripts/hybrid/expand.py @@ -29,27 +29,25 @@ from typing import List, Dict, Any, TYPE_CHECKING from pathlib import Path +from scripts.path_scope import ( + normalize_under as _normalize_under_scope, + metadata_matches_under as _metadata_matches_under, +) + logger = logging.getLogger("hybrid_expand") # Import QdrantClient type for annotations if TYPE_CHECKING: from 
qdrant_client import QdrantClient -# Import semantic expansion functionality (optional) -try: - from scripts.semantic_expansion import ( - expand_queries_semantically, - expand_queries_with_prf, - get_expansion_stats, - clear_expansion_cache, - ) - SEMANTIC_EXPANSION_AVAILABLE = True -except ImportError: - SEMANTIC_EXPANSION_AVAILABLE = False - expand_queries_semantically = None - expand_queries_with_prf = None - get_expansion_stats = None - clear_expansion_cache = None +from scripts.semantic_expansion import ( + expand_queries_semantically, + expand_queries_with_prf, + get_expansion_stats, + clear_expansion_cache, +) + +SEMANTIC_EXPANSION_AVAILABLE = True # Feature flag for embedding-based dynamic expansion @@ -542,20 +540,8 @@ def expand_via_embeddings( except Exception: vec_name = None - def _norm_under(u: str | None) -> str | None: - if not u: - return None - u = str(u).strip().replace("\\", "/") - u = "/".join([p for p in u.split("/") if p]) - if not u: - return None - if u.startswith("/work/"): - return u - if not u.startswith("/"): - return "/work/" + u - return "/work/" + u.lstrip("/") - flt = None + eff_under = _normalize_under_scope(under) try: from qdrant_client import models @@ -567,15 +553,6 @@ def _norm_under(u: str | None) -> str | None: match=models.MatchValue(value=language), ) ) - if under: - eff_under = _norm_under(under) - if eff_under: - must.append( - models.FieldCondition( - key="metadata.path_prefix", - match=models.MatchValue(value=eff_under), - ) - ) if kind: must.append( models.FieldCondition( @@ -621,10 +598,11 @@ def _norm_under(u: str | None) -> str | None: # Search for soft matches (we want semantically similar docs, not exact matches) try: + initial_limit = 8 if not eff_under else max(32, int(max_terms) * 8) search_kwargs = { "collection_name": collection, "query_vector": (vec_name, query_vector) if vec_name else query_vector, - "limit": 8, # Get top 8 neighbors + "limit": initial_limit, # Over-fetch when `under` is set (we 
post-filter). "with_payload": True, "score_threshold": 0.3, # Lower threshold to get more diverse results } @@ -637,6 +615,17 @@ def _norm_under(u: str | None) -> str | None: if not results: return [] + if eff_under: + _scoped = [] + for hit in results: + payload = getattr(hit, "payload", None) or {} + md = payload.get("metadata") or {} + if _metadata_matches_under(md, eff_under): + _scoped.append(hit) + results = _scoped + if not results: + return [] + # Extract unique terms from neighbors extracted_terms: set[str] = set() query_tokens = set(combined_query.lower().split()) diff --git a/scripts/hybrid/qdrant.py b/scripts/hybrid/qdrant.py index ef0f936d..67b754dc 100644 --- a/scripts/hybrid/qdrant.py +++ b/scripts/hybrid/qdrant.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + """ Qdrant client and query logic extracted from hybrid_search.py. @@ -25,16 +27,22 @@ import logging import threading import re -from typing import List, Dict, Any, Tuple +from typing import List, Dict, Any, Tuple, TYPE_CHECKING from pathlib import Path from concurrent.futures import ThreadPoolExecutor -# Core Qdrant imports -try: - from qdrant_client import QdrantClient, models -except ImportError: - QdrantClient = None # type: ignore - models = None # type: ignore +if TYPE_CHECKING: + from qdrant_client import QdrantClient, models as models +else: + QdrantClient = Any + + class _LazyQdrantModels: + def __getattr__(self, name: str) -> Any: + from qdrant_client import models as _models + + return getattr(_models, name) + + models = _LazyQdrantModels() logger = logging.getLogger("hybrid_qdrant") @@ -75,6 +83,9 @@ def _safe_float(val: Any, default: float) -> float: LEX_SPARSE_NAME, LEX_SPARSE_MODE, ) +from scripts.query_optimizer import optimize_query +from scripts.utils import lex_hash_vector_queries as _lex_hash_vector_queries +from scripts.utils import lex_sparse_vector_queries as _lex_sparse_vector_queries EF_SEARCH = 
_safe_int(os.environ.get("QDRANT_EF_SEARCH", "128"), 128) @@ -101,40 +112,13 @@ def _get_search_params(ef: int) -> models.SearchParams: # Connection pooling setup # --------------------------------------------------------------------------- -try: - from scripts.qdrant_client_manager import get_qdrant_client, return_qdrant_client, pooled_qdrant_client - _POOL_AVAILABLE = True -except ImportError: - _POOL_AVAILABLE = False - - def get_qdrant_client(url=None, api_key=None, force_new=False, use_pool=True): - """Fallback client creation when pooling is unavailable.""" - if QdrantClient is None: - raise ImportError( - "qdrant_client is not installed. Install with: pip install qdrant-client" - ) - return QdrantClient( - url=url or os.environ.get("QDRANT_URL", "http://localhost:6333"), - api_key=api_key or os.environ.get("QDRANT_API_KEY") - ) - - def return_qdrant_client(client): - """No-op when pooling is unavailable.""" - pass - - class pooled_qdrant_client: - """Fallback context manager when pooling is unavailable.""" - def __init__(self, url=None, api_key=None): - self.url = url - self.api_key = api_key - self.client = None - - def __enter__(self): - self.client = get_qdrant_client(self.url, self.api_key) - return self.client +from scripts.qdrant_client_manager import ( + get_qdrant_client, + return_qdrant_client, + pooled_qdrant_client, +) - def __exit__(self, exc_type, exc_val, exc_tb): - return_qdrant_client(self.client) +_POOL_AVAILABLE = True # --------------------------------------------------------------------------- @@ -276,12 +260,10 @@ def _ensure_collection(client, collection: str, dim: int, vec_name: str): _ENSURED_COLLECTIONS.add(cache_key) return - # Collection doesn't exist - only then call ensure_collection to create it - try: - from scripts.ingest_code import ensure_collection as _ensure_collection_raw - _ensure_collection_raw(client, collection, dim, vec_name) - except ImportError: - pass + # Collection doesn't exist - only then call the ingest Qdrant 
adapter to create it. + from scripts.ingest.qdrant import ensure_collection as _ensure_collection_raw + + _ensure_collection_raw(client, collection, dim, vec_name) try: _cache_collection_vectors(client, collection) @@ -403,51 +385,12 @@ def lex_hash_vector(phrases: List[str], dim: int | None = None) -> List[float]: """Generate dense lexical hash vector for query phrases.""" if dim is None: dim = LEX_VECTOR_DIM - try: - from scripts.utils import lex_hash_vector_queries as _lex_hash_vector_queries - return _lex_hash_vector_queries(phrases, dim) - except ImportError: - return _fallback_lex_hash_vector(phrases, dim) - - -def _fallback_lex_hash_vector(phrases: List[str], dim: int) -> List[float]: - """Fallback implementation when utils is unavailable.""" - import hashlib - vec = [0.0] * dim - for phrase in phrases: - for tok in _split_ident_lex(phrase): - h = int(hashlib.md5(tok.encode()).hexdigest(), 16) - idx = h % dim - vec[idx] += 1.0 - norm = sum(v * v for v in vec) ** 0.5 - if norm > 0: - vec = [v / norm for v in vec] - return vec + return _lex_hash_vector_queries(phrases, dim) def lex_sparse_vector(phrases: List[str]) -> Dict[str, Any]: """Generate sparse vector for query phrases (lossless exact matching).""" - try: - from scripts.utils import lex_sparse_vector_queries as _lex_sparse_vector_queries - return _lex_sparse_vector_queries(phrases) - except ImportError: - return _fallback_lex_sparse_vector(phrases) - - -def _fallback_lex_sparse_vector(phrases: List[str]) -> Dict[str, Any]: - """Fallback implementation when utils is unavailable.""" - import hashlib - indices = [] - values = [] - seen = set() - for phrase in phrases: - for tok in _split_ident_lex(phrase): - h = int(hashlib.md5(tok.encode()).hexdigest(), 16) % (2**31) - if h not in seen: - indices.append(h) - values.append(1.0) - seen.add(h) - return {"indices": indices, "values": values} + return _lex_sparse_vector_queries(phrases) # 
--------------------------------------------------------------------------- @@ -605,15 +548,12 @@ def dense_query( # Apply dynamic EF optimization if query text provided if query_text: try: - from scripts.query_optimizer import optimize_query result = optimize_query(query_text) # Only override EF when adaptive optimization is enabled if result.get("adaptive_enabled", False): ef = result["recommended_ef"] if os.environ.get("DEBUG_HYBRID_SEARCH"): logger.debug(f"Dynamic EF: {ef} (complexity={result['complexity']}, type={result['query_type']})") - except ImportError: - pass except Exception as e: if os.environ.get("DEBUG_HYBRID_SEARCH"): logger.debug(f"Query optimizer failed, using default EF: {e}") diff --git a/scripts/hybrid/ranking.py b/scripts/hybrid/ranking.py index 6a436272..8a96bc83 100644 --- a/scripts/hybrid/ranking.py +++ b/scripts/hybrid/ranking.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + """ Ranking and scoring logic for hybrid search. @@ -21,7 +23,20 @@ import re import math import logging -from typing import List, Dict, Any, Tuple +from typing import List, Dict, Any, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from qdrant_client import QdrantClient, models as models +else: + QdrantClient = Any + + class _LazyQdrantModels: + def __getattr__(self, name: str) -> Any: + from qdrant_client import models as _models + + return getattr(_models, name) + + models = _LazyQdrantModels() logger = logging.getLogger("hybrid_ranking") @@ -81,11 +96,8 @@ def _get_micro_defaults() -> Tuple[int, int, int, int]: Budget tokens floor is 5000 to ensure context_answer has enough context for quality answers. 
""" micro_enabled = os.environ.get("INDEX_MICRO_CHUNKS", "1").strip().lower() in {"1", "true", "yes", "on"} - try: - from scripts.refrag_glm import detect_glm_runtime - is_glm = detect_glm_runtime() - except ImportError: - is_glm = False + from scripts.refrag_glm import detect_glm_runtime + is_glm = detect_glm_runtime() if is_glm: if micro_enabled: return (24, 6, 8192, 32) @@ -658,12 +670,6 @@ def _get_symbol_extent( if cache_key in _SYMBOL_EXTENT_CACHE: return _SYMBOL_EXTENT_CACHE[cache_key] - # Lazy import to avoid circular dependencies - try: - from qdrant_client import QdrantClient, models - except ImportError: - return (0, 0) - if not collection: collection = os.environ.get("COLLECTION_NAME", "") if not collection: @@ -679,7 +685,9 @@ def _get_symbol_extent( timeout_s = float(os.environ.get("ADAPTIVE_SPAN_QDRANT_TIMEOUT", "1.0") or 1.0) except Exception: timeout_s = 1.0 - _SYMBOL_EXTENT_CLIENT = QdrantClient( + from qdrant_client import QdrantClient as _QdrantClient + + _SYMBOL_EXTENT_CLIENT = _QdrantClient( url=qdrant_url, api_key=os.environ.get("QDRANT_API_KEY"), timeout=timeout_s, diff --git a/scripts/hybrid_config.py b/scripts/hybrid_config.py deleted file mode 100644 index aebff322..00000000 --- a/scripts/hybrid_config.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. See scripts/hybrid/config.py""" -from scripts.hybrid.config import * diff --git a/scripts/hybrid_embed.py b/scripts/hybrid_embed.py deleted file mode 100644 index eb331621..00000000 --- a/scripts/hybrid_embed.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. See scripts/hybrid/embed.py""" -from scripts.hybrid.embed import * diff --git a/scripts/hybrid_expand.py b/scripts/hybrid_expand.py deleted file mode 100644 index d8131c0c..00000000 --- a/scripts/hybrid_expand.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. 
See scripts/hybrid/expand.py""" -from scripts.hybrid.expand import * diff --git a/scripts/hybrid_filters.py b/scripts/hybrid_filters.py deleted file mode 100644 index 60ffca18..00000000 --- a/scripts/hybrid_filters.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. See scripts/hybrid/filters.py""" -from scripts.hybrid.filters import * diff --git a/scripts/hybrid_qdrant.py b/scripts/hybrid_qdrant.py deleted file mode 100644 index 2498ef8e..00000000 --- a/scripts/hybrid_qdrant.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. See scripts/hybrid/qdrant.py""" -from scripts.hybrid.qdrant import * diff --git a/scripts/hybrid_ranking.py b/scripts/hybrid_ranking.py deleted file mode 100644 index e20f7b32..00000000 --- a/scripts/hybrid_ranking.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. See scripts/hybrid/ranking.py""" -from scripts.hybrid.ranking import * diff --git a/scripts/hybrid_search.py b/scripts/hybrid_search.py index 70b5b094..bdcba933 100644 --- a/scripts/hybrid_search.py +++ b/scripts/hybrid_search.py @@ -4,12 +4,12 @@ This is the stable public entrypoint for the hybrid search subsystem. 
All internal logic has been refactored into smaller, focused modules: -- hybrid_config.py: Environment-based configuration and constants -- hybrid_qdrant.py: Qdrant client management, queries, and vector functions -- hybrid_embed.py: Embedding model factory and cached embedding -- hybrid_filters.py: File classification and query DSL parsing -- hybrid_ranking.py: RRF, scoring, diversification, and micro-span budgeting -- hybrid_expand.py: Query expansion (synonyms, semantic, LLM-assisted) +- hybrid/config.py: Environment-based configuration and constants +- hybrid/qdrant.py: Qdrant client management, queries, and vector functions +- hybrid/embed.py: Embedding model factory and cached embedding +- hybrid/filters.py: File classification and query DSL parsing +- hybrid/ranking.py: RRF, scoring, diversification, and micro-span budgeting +- hybrid/expand.py: Query expansion (synonyms, semantic, LLM-assisted) This façade: 1. Re-exports all public APIs for backwards compatibility @@ -19,32 +19,36 @@ from __future__ import annotations import os -import sys import argparse import re import json import math import logging import threading -from pathlib import Path from typing import List, Dict, Any, Tuple, TYPE_CHECKING from functools import lru_cache from concurrent.futures import ThreadPoolExecutor -# Ensure /work or repo root is in sys.path for scripts imports -_ROOT_DIR = Path(__file__).resolve().parent.parent -if str(_ROOT_DIR) not in sys.path: - sys.path.insert(0, str(_ROOT_DIR)) - # --------------------------------------------------------------------------- -# Core Qdrant imports +# Lazy Qdrant model namespace # --------------------------------------------------------------------------- -from qdrant_client import QdrantClient, models +if TYPE_CHECKING: + from qdrant_client import QdrantClient, models as models +else: + QdrantClient = Any + + class _LazyQdrantModels: + def __getattr__(self, name: str) -> Any: + from qdrant_client import models as _models + + return 
getattr(_models, name) + + models = _LazyQdrantModels() # --------------------------------------------------------------------------- # Re-exports from hybrid_config # --------------------------------------------------------------------------- -from scripts.hybrid_config import ( +from scripts.hybrid.config import ( # Helper functions _safe_int, _safe_float, @@ -109,7 +113,7 @@ # --------------------------------------------------------------------------- # Re-exports from hybrid_qdrant # --------------------------------------------------------------------------- -from scripts.hybrid_qdrant import ( +from scripts.hybrid.qdrant import ( # Pool availability _POOL_AVAILABLE, # Connection pooling @@ -142,7 +146,7 @@ # --------------------------------------------------------------------------- # Re-exports from hybrid_embed # --------------------------------------------------------------------------- -from scripts.hybrid_embed import ( +from scripts.hybrid.embed import ( # Embedder factory _EMBEDDER_FACTORY, EmbeddingModel, @@ -156,34 +160,21 @@ UNIFIED_CACHE_AVAILABLE, ) -# Import unified cache objects from cache_manager when available -if UNIFIED_CACHE_AVAILABLE: - try: - from scripts.cache_manager import get_search_cache, get_embedding_cache, get_expansion_cache - _EMBED_CACHE = get_embedding_cache() - _RESULTS_CACHE = get_search_cache() - _EXPANSION_CACHE = get_expansion_cache() - except ImportError: - _EMBED_CACHE = None - _RESULTS_CACHE = {} - _EXPANSION_CACHE = None -else: - _EMBED_CACHE = None - _RESULTS_CACHE = {} - _EXPANSION_CACHE = None - -# Lightweight local fallback cache for deterministic test hits -try: - from collections import OrderedDict as _OD -except Exception: - _OD = dict # pragma: no cover +from scripts.cache_manager import get_search_cache, get_embedding_cache, get_expansion_cache + +_EMBED_CACHE = get_embedding_cache() +_RESULTS_CACHE = get_search_cache() +_EXPANSION_CACHE = get_expansion_cache() + +from collections import OrderedDict as _OD + 
_RESULTS_CACHE_OD = _OD() _RESULTS_LOCK = threading.RLock() # --------------------------------------------------------------------------- # Re-exports from hybrid_filters # --------------------------------------------------------------------------- -from scripts.hybrid_filters import ( +from scripts.hybrid.filters import ( # File patterns CORE_FILE_PATTERNS, NON_CORE_PATTERNS, @@ -206,7 +197,7 @@ # --------------------------------------------------------------------------- # Re-exports from hybrid_ranking # --------------------------------------------------------------------------- -from scripts.hybrid_ranking import ( +from scripts.hybrid.ranking import ( # RRF rrf, _scale_rrf_k, @@ -238,7 +229,7 @@ # --------------------------------------------------------------------------- # Re-exports from hybrid_expand # --------------------------------------------------------------------------- -from scripts.hybrid_expand import ( +from scripts.hybrid.expand import ( # Synonyms CODE_SYNONYMS, # Expansion functions @@ -253,7 +244,7 @@ # Conditionally re-export semantic expansion functions if SEMANTIC_EXPANSION_AVAILABLE: - from scripts.hybrid_expand import ( + from scripts.hybrid.expand import ( expand_queries_semantically, expand_queries_with_prf, get_expansion_stats, @@ -268,34 +259,27 @@ # --------------------------------------------------------------------------- # Additional imports for backward compatibility # --------------------------------------------------------------------------- -try: - from fastembed import TextEmbedding -except ImportError: - TextEmbedding = None # type: ignore +TextEmbedding = None # Tests may monkeypatch this; production imports lazily if needed. 
-try: - from scripts.embedder import get_embedding_model as _get_embedding_model -except ImportError: - _get_embedding_model = None +from scripts.embedder import get_embedding_model as _get_embedding_model # Import request deduplication system -try: - from scripts.deduplication import get_deduplicator, is_duplicate_request - DEDUPLICATION_AVAILABLE = True -except ImportError: - DEDUPLICATION_AVAILABLE = False +from scripts.deduplication import get_deduplicator, is_duplicate_request + +DEDUPLICATION_AVAILABLE = True # Import query optimizer for dynamic EF tuning -try: - from scripts.query_optimizer import get_query_optimizer, optimize_query - QUERY_OPTIMIZER_AVAILABLE = True -except ImportError: - QUERY_OPTIMIZER_AVAILABLE = False +from scripts.query_optimizer import get_query_optimizer, optimize_query + +QUERY_OPTIMIZER_AVAILABLE = True # Import ingest helpers from scripts.utils import sanitize_vector_name as _sanitize_vector_name -from scripts.ingest_code import ensure_collection as _ensure_collection_raw -from scripts.ingest_code import project_mini as _project_mini +from scripts.ingest.vectors import project_mini as _project_mini +from scripts.path_scope import ( + normalize_under as _normalize_under_scope, + metadata_matches_under as _metadata_matches_under, +) # --------------------------------------------------------------------------- # Module logger @@ -421,13 +405,64 @@ def _generate_code_query_variants(query: str) -> List[str]: return result[:5] # Max 5 variants to balance coverage vs compute +def _shape_dense_points( + ranked_points: List[Any], + *, + limit: int, + per_path: int | None = 1, + under: str | None = None, +) -> List[Dict[str, Any]]: + eff_under = _normalize_under_scope(under) + eff_per_path = int(per_path or 0) + + results: List[Dict[str, Any]] = [] + path_counts: dict[str, int] = {} + for p in ranked_points: + payload = p.payload or {} + md = payload.get("metadata") or {} + if eff_under and not _metadata_matches_under(md, eff_under): + 
continue + + # Prefer host_path when available (consistent with hybrid search). + path = md.get("host_path") or payload.get("path") or md.get("path") or "" + if eff_per_path > 0: + current = path_counts.get(path, 0) + if current >= eff_per_path: + continue + else: + current = 0 + + results.append( + { + "score": float(getattr(p, "score", 0) or 0), + "path": path, + "symbol": payload.get("symbol") or md.get("symbol") or "", + "start_line": int(md.get("start_line") or 0), + "end_line": int(md.get("end_line") or 0), + "code_id": payload.get("code_id") or payload.get("_id") or "", + "doc_id": payload.get("code_id") or payload.get("_id") or "", + "payload": payload, + } + ) + if eff_per_path > 0: + path_counts[path] = current + 1 + if len(results) >= int(limit): + break + + return results + + def run_pure_dense_search( query: str, limit: int = 10, + per_path: int | None = 1, model: Any = None, collection: str | None = None, language: str | None = None, under: str | None = None, + kind: str | None = None, + symbol: str | None = None, + ext: str | None = None, repo: str | list[str] | None = None, ) -> List[Dict[str, Any]]: """Pure dense search - single query embedding, single vector search. 
@@ -437,45 +472,51 @@ def run_pure_dense_search( Args: query: Natural language query limit: Max results to return + per_path: Optional max results per file path; <= 0 disables the cap model: Embedding model (will load default if None) collection: Qdrant collection name language: Optional language filter - under: Optional path prefix filter + under: Optional recursive workspace subtree filter + kind: Optional kind filter (exact match) + symbol: Optional symbol filter (exact match) + ext: Optional file extension filter (without dot) repo: Optional repo filter Returns: List of search results with raw cosine similarity scores """ - from scripts.hybrid_qdrant import get_qdrant_client, return_qdrant_client, dense_query + from scripts.hybrid.qdrant import get_qdrant_client, return_qdrant_client, dense_query from scripts.utils import sanitize_vector_name from qdrant_client import models # Get model if model is None: model_name = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") - try: - from scripts.embedder import get_embedding_model - model = get_embedding_model(model_name) - except ImportError: - from fastembed import TextEmbedding - model = TextEmbedding(model_name=model_name) + from scripts.embedder import get_embedding_model + model = get_embedding_model(model_name) else: model_name = getattr(model, "model_name", os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5")) vec_name = sanitize_vector_name(model_name) coll = collection or _collection() - # Build filter + # Build server-side filter (exclude `under` here; recursive under is post-filtered) must = [] if language: must.append(models.FieldCondition(key="metadata.language", match=models.MatchValue(value=language))) - if under: - must.append(models.FieldCondition(key="metadata.path_prefix", match=models.MatchValue(value=under))) if repo and repo != "*": if isinstance(repo, list): must.append(models.FieldCondition(key="metadata.repo", match=models.MatchAny(any=repo))) else: 
must.append(models.FieldCondition(key="metadata.repo", match=models.MatchValue(value=repo))) + if kind: + must.append(models.FieldCondition(key="metadata.kind", match=models.MatchValue(value=kind))) + if symbol: + must.append(models.FieldCondition(key="metadata.symbol", match=models.MatchValue(value=symbol))) + if ext: + ext_clean = str(ext).lower().lstrip(".") + if ext_clean: + must.append(models.FieldCondition(key="metadata.ext", match=models.MatchValue(value=ext_clean))) flt = models.Filter(must=must) if must else None # Single query embedding - no variants, no expansion @@ -500,30 +541,22 @@ def run_pure_dense_search( ) try: - # Single dense query - no pooling, no re-scoring - ranked_points = dense_query(client, vec_name, vec_list, flt, limit, coll, query_text=query) - - # Build output - results = [] - for p in ranked_points: - payload = p.payload or {} - md = payload.get("metadata") or {} - - # Prefer host_path when available (consistent with hybrid search) - _path = md.get("host_path") or payload.get("path") or md.get("path") or "" - - results.append({ - "score": float(getattr(p, "score", 0) or 0), - "path": _path, - "symbol": payload.get("symbol") or md.get("symbol") or "", - "start_line": int(md.get("start_line") or 0), - "end_line": int(md.get("end_line") or 0), - "code_id": payload.get("code_id") or payload.get("_id") or "", - "doc_id": payload.get("code_id") or payload.get("_id") or "", - "payload": payload, - }) - - return results + # Single dense query - no pooling, no re-scoring. + # When `under` or `per_path` is set, we may need to over-fetch so post-filters + # can still fill up to `limit` results. 
+ eff_under = _normalize_under_scope(under) + fetch_limit = int(limit) + eff_per_path = int(per_path or 0) + if eff_under or eff_per_path > 0: + fetch_limit = min(max(fetch_limit * 4, fetch_limit + 16), 2000) + ranked_points = dense_query(client, vec_name, vec_list, flt, fetch_limit, coll, query_text=query) + + return _shape_dense_points( + ranked_points, + limit=limit, + per_path=per_path, + under=under, + ) finally: return_qdrant_client(client) @@ -532,7 +565,7 @@ def run_pure_dense_search( # --------------------------------------------------------------------------- # Backward compatibility: _embed_queries_cached alias # --------------------------------------------------------------------------- -# The function is now in hybrid_embed.py as embed_queries_cached +# The function is now in hybrid/embed.py as embed_queries_cached # Keep the underscore-prefixed alias for any legacy callers @@ -612,7 +645,11 @@ def _dt(label: str): elif _EMBEDDER_FACTORY: _model = _get_embedding_model(model_name) else: - _model = TextEmbedding(model_name=model_name) + text_embedding_cls = TextEmbedding + if text_embedding_cls is None: + from fastembed import TextEmbedding as text_embedding_cls + + _model = text_embedding_cls(model_name=model_name) vec_name = _sanitize_vector_name(model_name) # Parse Query DSL and merge with explicit args @@ -690,21 +727,8 @@ def _normalize_globs(globs: list[str]) -> list[str]: eff_path_globs_norm = _normalize_globs(eff_path_globs) eff_not_globs_norm = _normalize_globs(eff_not_globs) - # Normalize under - def _norm_under(u: str | None) -> str | None: - if not u: - return None - u = str(u).strip().replace("\\", "/") - u = "/".join([p for p in u.split("/") if p]) - if not u: - return None - if not u.startswith("/"): - v = "/work/" + u - else: - v = "/work/" + u.lstrip("/") if not u.startswith("/work/") else u - return v - - eff_under = _norm_under(eff_under) + # Normalize under as a user-facing recursive subtree scope. 
+ eff_under = _normalize_under_scope(eff_under) # Expansion knobs that affect query construction/results (must be part of cache key) try: @@ -810,12 +834,8 @@ def _norm_under(u: str | None) -> str | None: key="metadata.repo", match=models.MatchValue(value=eff_repo) ) ) - if eff_under: - must.append( - models.FieldCondition( - key="metadata.path_prefix", match=models.MatchValue(value=eff_under) - ) - ) + # NOTE: `under` is recursive and user-facing; we enforce it in client-side + # filtering via normalized metadata paths instead of exact path_prefix equality. if eff_kind: must.append( models.FieldCondition( @@ -2105,7 +2125,7 @@ def _match_glob(pat: str, path: str) -> bool: return _fnm.fnmatchcase(path, pat) return _fnm.fnmatchcase(path.lower(), pat.lower()) - if eff_not or eff_path_regex or eff_ext or eff_path_globs or eff_not_globs: + if eff_under or eff_not or eff_path_regex or eff_ext or eff_path_globs or eff_not_globs: def _pass_filters(m: Dict[str, Any]) -> bool: md = (m["pt"].payload or {}).get("metadata") or {} @@ -2118,6 +2138,8 @@ def _pass_filters(m: Dict[str, Any]) -> bool: nn = eff_not if case_sensitive else eff_not.lower() if nn in p_for_sub or nn in pp_for_sub: return False + if eff_under and not _metadata_matches_under(md, eff_under): + return False if eff_not_globs_norm and any(_match_glob(g, path) or _match_glob(g, rel) for g in eff_not_globs_norm): return False if eff_ext: diff --git a/scripts/indexing_admin.py b/scripts/indexing_admin.py index f3bb69d8..ac93b6f6 100644 --- a/scripts/indexing_admin.py +++ b/scripts/indexing_admin.py @@ -10,70 +10,36 @@ from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Set, Tuple -try: - from qdrant_client import QdrantClient -except Exception: - QdrantClient = None # type: ignore - -try: - from scripts.embedder import get_model_dimension -except Exception: - get_model_dimension = None # type: ignore - -try: - from scripts.collection_admin import copy_collection_qdrant -except 
Exception: - copy_collection_qdrant = None # type: ignore - -try: - from scripts.ingest_code import ( - ensure_collection_and_indexes_once, - ensure_payload_indexes, - _sanitize_vector_name, - MINI_VECTOR_NAME as _MINI_VECTOR_NAME, - LEX_SPARSE_NAME as _LEX_SPARSE_NAME, - ) -except Exception: - ensure_collection_and_indexes_once = None # type: ignore - ensure_payload_indexes = None # type: ignore - _sanitize_vector_name = None # type: ignore - _MINI_VECTOR_NAME = os.environ.get("MINI_VECTOR_NAME", "mini") - _LEX_SPARSE_NAME = os.environ.get("LEX_SPARSE_NAME", "lex_sparse") - -try: - from scripts.workspace_state import ( - get_collection_mappings, - get_workspace_state, - update_workspace_state, - update_indexing_status, - get_indexing_config_snapshot, - compute_indexing_config_hash, - is_staging_enabled, - set_staging_state, - update_staging_status, - clear_staging_collection, - activate_staging_collection, - promote_pending_indexing_config, - persist_indexing_config, - ) -except Exception: - get_collection_mappings = None # type: ignore - get_workspace_state = None # type: ignore - update_workspace_state = None # type: ignore - update_indexing_status = None # type: ignore - get_indexing_config_snapshot = None # type: ignore - compute_indexing_config_hash = None # type: ignore - is_staging_enabled = None # type: ignore - set_staging_state = None # type: ignore - update_staging_status = None # type: ignore - clear_staging_collection = None # type: ignore - activate_staging_collection = None # type: ignore - promote_pending_indexing_config = None # type: ignore - persist_indexing_config = None # type: ignore +from qdrant_client import QdrantClient + +from scripts.embedder import get_model_dimension +from scripts.collection_admin import copy_collection_qdrant +from scripts.ingest_code import ( + ensure_collection_and_indexes_once, + ensure_payload_indexes, + _sanitize_vector_name, + MINI_VECTOR_NAME as _MINI_VECTOR_NAME, + LEX_SPARSE_NAME as _LEX_SPARSE_NAME, +) +from 
scripts.workspace_state import ( + get_collection_mappings, + get_workspace_state, + update_workspace_state, + update_indexing_status, + get_indexing_config_snapshot, + compute_indexing_config_hash, + is_staging_enabled, + set_staging_state, + update_staging_status, + clear_staging_collection, + activate_staging_collection, + promote_pending_indexing_config, + persist_indexing_config, +) def _staging_enabled() -> bool: - return bool(is_staging_enabled() if callable(is_staging_enabled) else False) + return bool(is_staging_enabled()) def _workspace_base_dir() -> Path: @@ -141,7 +107,7 @@ def _copy_repo_state_for_clone( def _probe_collection_schema(collection: str) -> Optional[Dict[str, Any]]: - if not collection or QdrantClient is None: + if not collection: return None cached = _COLLECTION_SCHEMA_CACHE.get(collection) if cached: @@ -916,8 +882,6 @@ def build_admin_collections_view(*, collections: Any, work_dir: str) -> List[Dic def delete_collection_qdrant(*, qdrant_url: str, api_key: Optional[str], collection: str) -> None: - if QdrantClient is None: - return name = (collection or "").strip() if not name: return @@ -927,6 +891,17 @@ def delete_collection_qdrant(*, qdrant_url: str, api_key: Optional[str], collect return try: cli.delete_collection(collection_name=name) + # Best-effort: also delete companion graph edges collection when present. 
+ if not name.endswith("_graph"): + try: + cli.delete_collection(collection_name=f"{name}_graph") + except Exception as exc: + try: + print( + f"[indexing_admin] best-effort graph collection delete failed for {name}_graph: {exc}" + ) + except Exception: + pass except Exception: pass finally: @@ -937,8 +912,6 @@ def delete_collection_qdrant(*, qdrant_url: str, api_key: Optional[str], collect def recreate_collection_qdrant(*, qdrant_url: str, api_key: Optional[str], collection: str) -> None: - if QdrantClient is None: - return name = (collection or "").strip() if not name: return @@ -951,6 +924,17 @@ def recreate_collection_qdrant(*, qdrant_url: str, api_key: Optional[str], colle cli.delete_collection(collection_name=name) except Exception as delete_error: raise RuntimeError(f"Failed to delete existing collection '{name}' in Qdrant: {delete_error}") from delete_error + # Best-effort: also delete companion graph edges collection when present. + if not name.endswith("_graph"): + try: + cli.delete_collection(collection_name=f"{name}_graph") + except Exception as exc: + try: + print( + f"[indexing_admin] best-effort graph collection delete failed for {name}_graph: {exc}" + ) + except Exception: + pass finally: try: cli.close() @@ -984,12 +968,9 @@ def spawn_ingest_code( env.pop(k, None) else: env[str(k)] = str(v) - # When we provide env overrides for a run (e.g. staging rebuild), we also want to - # force ingest_code to honor the explicit COLLECTION_NAME instead of routing based - # on per-repo state/serving_collection in multi-repo mode. - # CTXCE_FORCE_COLLECTION_NAME is only used for these subprocess runs; normal watcher - # and indexer flows do not set it. - env["CTXCE_FORCE_COLLECTION_NAME"] = "1" # Force ingest_code to use COLLECTION_NAME for staging/pending env overrides + # For admin-triggered subprocess runs (recreate/reindex/staging), force ingest_code to + # honor explicit COLLECTION_NAME and avoid multi-repo enumeration. 
+ env["CTXCE_FORCE_COLLECTION_NAME"] = "1" env["COLLECTION_NAME"] = collection env["WATCH_ROOT"] = work_dir env["WORKSPACE_PATH"] = work_dir @@ -1022,11 +1003,10 @@ def spawn_ingest_code( def _determine_embedding_dim(model_name: str) -> int: - if get_model_dimension: - try: - return int(get_model_dimension(model_name)) - except Exception: - pass + try: + return int(get_model_dimension(model_name)) + except Exception: + pass try: from fastembed import TextEmbedding # type: ignore @@ -1037,16 +1017,13 @@ def _determine_embedding_dim(model_name: str) -> int: def _normalize_cloned_collection_schema(*, collection_name: str, qdrant_url: str) -> None: - if QdrantClient is None: - return vector_name = None model_name = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") dim = _determine_embedding_dim(model_name) - if _sanitize_vector_name is not None: - try: - vector_name = _sanitize_vector_name(model_name) - except Exception: - vector_name = None + try: + vector_name = _sanitize_vector_name(model_name) + except Exception: + vector_name = None try: client = QdrantClient(url=qdrant_url, api_key=os.environ.get("QDRANT_API_KEY") or None) except Exception: @@ -1073,8 +1050,6 @@ def _normalize_cloned_collection_schema(*, collection_name: str, qdrant_url: str def _get_collection_point_count(*, collection_name: str, qdrant_url: str) -> Optional[int]: - if QdrantClient is None: - return None try: client = QdrantClient(url=qdrant_url, api_key=os.environ.get("QDRANT_API_KEY") or None) except Exception: @@ -1100,8 +1075,6 @@ def _wait_for_clone_points( expected_count: Optional[int], timeout_seconds: int = 60, ) -> None: - if QdrantClient is None: - return try: client = QdrantClient(url=qdrant_url, api_key=os.environ.get("QDRANT_API_KEY") or None) except Exception: @@ -1167,25 +1140,15 @@ def start_staging_rebuild(*, collection: str, work_dir: str) -> str: qdrant_url = os.environ.get("QDRANT_URL", "http://qdrant:6333") source_point_count = 
_get_collection_point_count(collection_name=collection, qdrant_url=qdrant_url) - # Use local import for thread-safety and determinism - _copy_fn: Any = copy_collection_qdrant - if _copy_fn is None: - # Re-import for container environments where module-level import may have failed - from scripts.collection_admin import copy_collection_qdrant as _ccq - _copy_fn = _ccq - - if not callable(_copy_fn): - raise RuntimeError("copy_collection_qdrant unavailable (import failed)") - try: print(f"[staging] Copying collection {collection} -> {old_collection} (overwrite=True)") try: print( - f"[staging] copy_collection_qdrant callable={callable(_copy_fn)} type={type(_copy_fn)} module={getattr(_copy_fn, '__module__', '?')}" + f"[staging] copy_collection_qdrant module={getattr(copy_collection_qdrant, '__module__', '?')}" ) except Exception: pass - _copy_fn( + copy_collection_qdrant( source=collection, target=old_collection, qdrant_url=qdrant_url, @@ -1291,9 +1254,9 @@ def start_staging_rebuild(*, collection: str, work_dir: str) -> str: pending_env = state.get("indexing_env_pending") or dict(os.environ) env_hash = pending_hash or current_env_indexing_hash() - if not pending_cfg and get_indexing_config_snapshot: - pending_cfg = get_indexing_config_snapshot() if callable(get_indexing_config_snapshot) else get_indexing_config_snapshot - if not pending_hash and pending_cfg and compute_indexing_config_hash: + if not pending_cfg: + pending_cfg = get_indexing_config_snapshot() + if not pending_hash and pending_cfg: pending_hash = compute_indexing_config_hash(pending_cfg) if set_staging_state: diff --git a/scripts/ingest/__init__.py b/scripts/ingest/__init__.py index 6b802a95..e273a90e 100644 --- a/scripts/ingest/__init__.py +++ b/scripts/ingest/__init__.py @@ -1,36 +1,9 @@ -""" -Ingest package - Code indexing subsystem. 
- -This package contains extracted modules from ingest_code.py: -- config: Environment-based configuration and constants -- tree_sitter: Tree-sitter setup and language loading -- vectors: Vector generation utilities (lex hash, mini projection) -- exclusions: File and directory exclusion logic -- chunking: Code chunking utilities (line, semantic, token-based) -- symbols: Symbol extraction for code analysis -- pseudo: ReFRAG pseudo-description and tag generation -- metadata: Metadata extraction (git, imports, calls) -- qdrant: Qdrant schema and I/O operations -- pipeline: Core indexing pipeline -- cli: Command-line interface +"""Code indexing subsystem package. -Usage: - from scripts.ingest import config, pipeline, qdrant - from scripts.ingest.config import LEX_VECTOR_NAME, LEX_VECTOR_DIM - from scripts.ingest.pipeline import index_repo, index_single_file - from scripts.ingest.qdrant import ensure_collection, upsert_points +Submodules are intentionally loaded on demand. Importing lightweight helpers +such as ``scripts.ingest.config`` should not initialize Qdrant, tree-sitter, or +the indexing pipeline. 
""" -from scripts.ingest import config -from scripts.ingest import tree_sitter -from scripts.ingest import vectors -from scripts.ingest import exclusions -from scripts.ingest import chunking -from scripts.ingest import symbols -from scripts.ingest import pseudo -from scripts.ingest import metadata -from scripts.ingest import qdrant -from scripts.ingest import pipeline -from scripts.ingest import cli __all__ = [ "config", diff --git a/scripts/ingest/chunking.py b/scripts/ingest/chunking.py index a90ffdf6..9a1e36a4 100644 --- a/scripts/ingest/chunking.py +++ b/scripts/ingest/chunking.py @@ -14,12 +14,9 @@ from scripts.ingest.config import ROOT_DIR from scripts.ingest.tree_sitter import _use_tree_sitter, _TS_LANGUAGES -# Import AST analyzer for enhanced semantic chunking -try: - from scripts.ast_analyzer import get_ast_analyzer, chunk_code_semantically - _AST_ANALYZER_AVAILABLE = True -except ImportError: - _AST_ANALYZER_AVAILABLE = False +from scripts.ast_analyzer import get_ast_analyzer, chunk_code_semantically + +_AST_ANALYZER_AVAILABLE = True # Cache tokenizers loaded from TOKENIZER_JSON (or default) to avoid repeatedly @@ -57,10 +54,6 @@ def chunk_semantic( _ast_supported = False if use_enhanced and _AST_ANALYZER_AVAILABLE: try: - # ast_analyzer internally respects USE_TREE_SITTER when constructing the analyzer - # (see scripts/ast_analyzer.py:get_ast_analyzer). - from scripts.ast_analyzer import get_ast_analyzer # type: ignore - analyzer = get_ast_analyzer() lang_key = str(language or "").strip().lower() # Supported either via builtin ast (python) or via tree-sitter when enabled. 
diff --git a/scripts/ingest/cli.py b/scripts/ingest/cli.py index 8561a9c2..1e089a16 100644 --- a/scripts/ingest/cli.py +++ b/scripts/ingest/cli.py @@ -9,15 +9,20 @@ import os import argparse +import logging from pathlib import Path from scripts.ingest.config import ( is_multi_repo_mode, get_collection_name, ) +from scripts.pseudo_config import env_bool, effective_pseudo_mode +from scripts.collection_health import clear_indexing_caches as _clear_indexing_caches_impl from scripts.ingest.pipeline import index_repo from scripts.ingest.pseudo import generate_pseudo_tags +logger = logging.getLogger(__name__) + def parse_args(): """Parse command-line arguments.""" @@ -40,6 +45,11 @@ def parse_args(): action="store_true", help="Do not skip files whose content hash matches existing index", ) + parser.add_argument( + "--clear-indexing-caches", + action="store_true", + help="Clear local indexing caches (file hash/symbol caches) before indexing", + ) parser.add_argument( "--schema-mode", type=str, @@ -186,13 +196,31 @@ def main(): ) return + def _clear_indexing_caches(workspace_root: Path, repo_name: str | None) -> None: + try: + _clear_indexing_caches_impl(str(workspace_root), repo_name=repo_name) + except Exception as e: + logger.warning( + "Failed to clear indexing caches for workspace=%s repo=%s: %s", + workspace_root, + repo_name, + e, + exc_info=True, + ) + qdrant_url = os.environ.get("QDRANT_URL", "http://localhost:6333") api_key = os.environ.get("QDRANT_API_KEY") collection = os.environ.get("COLLECTION_NAME") or os.environ.get("DEFAULT_COLLECTION") or "codebase" model_name = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") # Resolve collection name based on multi-repo mode - multi_repo = bool(is_multi_repo_mode and is_multi_repo_mode()) + force_collection = (os.environ.get("CTXCE_FORCE_COLLECTION_NAME") or "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + multi_repo = bool(is_multi_repo_mode and is_multi_repo_mode()) and not force_collection if 
multi_repo: print("[multi_repo] Multi-repo mode enabled - will create separate collections per repository") @@ -231,6 +259,9 @@ def main(): if not repo_collection: repo_collection = "codebase" + if args.clear_indexing_caches: + _clear_indexing_caches(root_path, repo_name) + index_repo( repo_root, qdrant_url, @@ -240,25 +271,25 @@ def main(): args.recreate, dedupe=(not args.no_dedupe), skip_unchanged=(not args.no_skip_unchanged), - pseudo_mode="off" if (os.environ.get("PSEUDO_DEFER_TO_WORKER") or "").strip().lower() in {"1", "true", "yes", "on"} else "full", + pseudo_mode=effective_pseudo_mode( + defer_to_worker=env_bool("PSEUDO_DEFER_TO_WORKER"), + backfill_enabled=env_bool("PSEUDO_BACKFILL_ENABLED"), + ), schema_mode=args.schema_mode, ) return else: - if get_collection_name: - try: - resolved = get_collection_name(str(Path(args.root).resolve())) - placeholders = {"", "default-collection", "my-collection", "codebase"} - if resolved and collection in placeholders: - collection = resolved - except Exception: - pass if not collection: collection = os.environ.get("COLLECTION_NAME", "codebase") print(f"[single_repo] Single-repo mode enabled - using collection: {collection}") - flag = (os.environ.get("PSEUDO_DEFER_TO_WORKER") or "").strip().lower() - pseudo_mode = "off" if flag in {"1", "true", "yes", "on"} else "full" + pseudo_mode = effective_pseudo_mode( + defer_to_worker=env_bool("PSEUDO_DEFER_TO_WORKER"), + backfill_enabled=env_bool("PSEUDO_BACKFILL_ENABLED"), + ) + + if args.clear_indexing_caches: + _clear_indexing_caches(Path(args.root).resolve(), None) index_repo( Path(args.root).resolve(), diff --git a/scripts/ingest/config.py b/scripts/ingest/config.py index c4001eaf..89f10374 100644 --- a/scripts/ingest/config.py +++ b/scripts/ingest/config.py @@ -221,68 +221,41 @@ def _env_truthy(val: str | None, default: bool) -> bool: # --------------------------------------------------------------------------- -# Workspace state function imports (optional) +# Workspace 
state function imports # --------------------------------------------------------------------------- -# These are imported at module load time for convenience, with fallbacks -try: - from scripts.workspace_state import ( - is_multi_repo_mode, - get_collection_name, - logical_repo_reuse_enabled, - ) -except ImportError: - is_multi_repo_mode = None # type: ignore - get_collection_name = None # type: ignore +from scripts.workspace_state import ( + is_multi_repo_mode, + get_collection_name, + logical_repo_reuse_enabled, + log_activity, + get_cached_file_hash, + set_cached_file_hash, + remove_cached_file, + update_indexing_status, + update_workspace_state, + get_cached_symbols, + set_cached_symbols, + remove_cached_symbols, + compare_symbol_changes, + get_cached_pseudo, + set_cached_pseudo, + update_symbols_with_pseudo, + get_workspace_state, + get_cached_file_meta, + indexing_lock, + file_indexing_lock, + is_file_locked, +) - def logical_repo_reuse_enabled() -> bool: # type: ignore[no-redef] - return False +def _detect_repo_for_file(path): + """Defer watcher routing import to avoid the ingest/watch bootstrap cycle.""" + from scripts.watch_index_core.routing import _detect_repo_for_file as _impl -# Import watcher's repo detection for surgical fix -try: - from scripts.watch_index_core.routing import _detect_repo_for_file, _get_collection_for_file -except ImportError: - _detect_repo_for_file = None # type: ignore - _get_collection_for_file = None # type: ignore + return _impl(path) -# Import other workspace state functions (optional) -try: - from scripts.workspace_state import ( - log_activity, - get_cached_file_hash, - set_cached_file_hash, - remove_cached_file, - update_indexing_status, - update_workspace_state, - get_cached_symbols, - set_cached_symbols, - remove_cached_symbols, - compare_symbol_changes, - get_cached_pseudo, - set_cached_pseudo, - update_symbols_with_pseudo, - get_workspace_state, - get_cached_file_meta, - indexing_lock, - file_indexing_lock, - 
is_file_locked, - ) -except ImportError: - # State integration is optional; continue if not available - log_activity = None # type: ignore - get_cached_file_hash = None # type: ignore - set_cached_file_hash = None # type: ignore - remove_cached_file = None # type: ignore - update_indexing_status = None # type: ignore - update_workspace_state = None # type: ignore - get_cached_symbols = None # type: ignore - set_cached_symbols = None # type: ignore - remove_cached_symbols = None # type: ignore - get_cached_pseudo = None # type: ignore - set_cached_pseudo = None # type: ignore - update_symbols_with_pseudo = None # type: ignore - compare_symbol_changes = None # type: ignore - get_workspace_state = None # type: ignore - get_cached_file_meta = None # type: ignore - indexing_lock = None # type: ignore - file_indexing_lock = None # type: ignore - is_file_locked = None # type: ignore + +def _get_collection_for_file(path): + """Defer watcher routing import to avoid the ingest/watch bootstrap cycle.""" + from scripts.watch_index_core.routing import _get_collection_for_file as _impl + + return _impl(path) diff --git a/scripts/ingest/graph_edges.py b/scripts/ingest/graph_edges.py new file mode 100644 index 00000000..3751d970 --- /dev/null +++ b/scripts/ingest/graph_edges.py @@ -0,0 +1,473 @@ +#!/usr/bin/env python3 +""" +ingest/graph_edges.py - Materialized graph edges in Qdrant. 
+ +This is a small, MIT-safe reimplementation of the "graph edges collection" idea: +- Maintain a dedicated Qdrant collection named `_graph` +- Store payload-only edge docs for fast lookups: + - callers/importers queries become simple keyword filters on an indexed payload field + +Design goals for this branch: +- Keep this as an *accelerator* (symbol_graph still works without it) +- Avoid Neo4j/PageRank/GraphRAG complexity +- Avoid CLI flags; watcher can backfill opportunistically +""" + +from __future__ import annotations + +import hashlib +import logging +import os +from typing import Any, Dict, Iterable, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +GRAPH_COLLECTION_SUFFIX = "_graph" + +EDGE_TYPE_CALLS = "calls" +EDGE_TYPE_IMPORTS = "imports" + +GRAPH_INDEX_FIELDS: Tuple[str, ...] = ( + "caller_path", + "callee_symbol", + "edge_type", + "repo", +) + +_ENSURED_GRAPH_COLLECTIONS: set[str] = set() +_GRAPH_VECTOR_MODE: dict[str, str] = {} +_MISSING_GRAPH_COLLECTIONS: set[str] = set() +_BACKFILL_OFFSETS: dict[tuple[str, Optional[str]], Any] = {} + +_EDGE_VECTOR_NAME = "_edge" +_EDGE_VECTOR_VALUE = [0.0] + + +def _normalize_path(path: str) -> str: + if not path: + return "" + try: + normalized = os.path.normpath(str(path)) + except Exception: + normalized = str(path) + return normalized.replace("\\", "/") + + +def normalize_caller_path(path: str) -> str: + """Normalize a caller path exactly as graph edge payloads do. + + This is used by both the writer (upsert/delete) and any readers/verifiers so + cross-platform path separators (Windows vs POSIX) do not cause mismatches. 
+ """ + + return _normalize_path(path) + + +def get_graph_collection_name(base_collection: str) -> str: + return f"{base_collection}{GRAPH_COLLECTION_SUFFIX}" + + +def _edge_vector_for_upsert(graph_collection: str) -> dict: + mode = _GRAPH_VECTOR_MODE.get(graph_collection) + if mode == "named": + return {_EDGE_VECTOR_NAME: _EDGE_VECTOR_VALUE} + return {} + + +def ensure_graph_collection(client: Any, base_collection: str) -> Optional[str]: + """Ensure `_graph` exists and has payload indexes.""" + from qdrant_client import models as qmodels + from qdrant_client.http.exceptions import UnexpectedResponse + + if not base_collection: + return None + graph_coll = get_graph_collection_name(base_collection) + if graph_coll in _ENSURED_GRAPH_COLLECTIONS: + return graph_coll + + def _detect_vector_mode(info: Any) -> str: + try: + vectors = getattr( + getattr(getattr(info, "config", None), "params", None), "vectors", None + ) + if isinstance(vectors, dict): + return "none" if not vectors else "named" + return "none" if vectors is None else "named" + except Exception: + return "named" + + try: + info = client.get_collection(graph_coll) + _GRAPH_VECTOR_MODE[graph_coll] = _detect_vector_mode(info) + _ENSURED_GRAPH_COLLECTIONS.add(graph_coll) + _MISSING_GRAPH_COLLECTIONS.discard(graph_coll) + return graph_coll + except UnexpectedResponse as e: + # Only a 404 means "missing"; any other HTTP failure should be visible. + if getattr(e, "status_code", None) != 404: + logger.exception( + "Failed to get graph collection %s (status=%s): %s", + graph_coll, + getattr(e, "status_code", None), + e, + ) + return None + except Exception as e: + logger.exception("Failed to get graph collection %s: %s", graph_coll, e) + return None + + try: + # Prefer vector-less collection when supported by server/client. 
+ try: + client.create_collection( + collection_name=graph_coll, + vectors_config={}, + ) + _GRAPH_VECTOR_MODE[graph_coll] = "none" + except Exception as vec_exc: + logger.debug( + "Vector-less creation failed for %s, trying named vector: %s", + graph_coll, + vec_exc, + ) + client.create_collection( + collection_name=graph_coll, + vectors_config={ + _EDGE_VECTOR_NAME: qmodels.VectorParams( + size=1, distance=qmodels.Distance.COSINE + ) + }, + ) + _GRAPH_VECTOR_MODE[graph_coll] = "named" + + # Create payload indexes (best-effort). + for field in GRAPH_INDEX_FIELDS: + try: + client.create_payload_index( + collection_name=graph_coll, + field_name=field, + field_schema=qmodels.PayloadSchemaType.KEYWORD, + ) + except Exception as e: + logger.debug( + "Failed to create graph payload index '%s' for %s: %s", + field, + graph_coll, + e, + exc_info=True, + ) + + _ENSURED_GRAPH_COLLECTIONS.add(graph_coll) + _MISSING_GRAPH_COLLECTIONS.discard(graph_coll) + return graph_coll + except Exception as e: + logger.debug("Failed to ensure graph collection %s: %s", graph_coll, e) + return None + + +def _edge_id(edge_type: str, repo: str, caller_path: str, callee_symbol: str) -> str: + key = f"{edge_type}\x00{repo}\x00{caller_path}\x00{callee_symbol}" + return hashlib.sha256(key.encode("utf-8", errors="ignore")).hexdigest()[:32] + + +def _iter_edges( + *, + caller_path: str, + repo: str, + calls: Iterable[str] = (), + imports: Iterable[str] = (), +) -> List[Dict[str, Any]]: + norm_path = _normalize_path(caller_path) + repo_s = (repo or "").strip() or "default" + + edges: List[Dict[str, Any]] = [] + for sym in calls or []: + s = str(sym).strip() + if not s: + continue + edges.append( + { + "id": _edge_id(EDGE_TYPE_CALLS, repo_s, norm_path, s), + "payload": { + "caller_path": norm_path, + "callee_symbol": s, + "edge_type": EDGE_TYPE_CALLS, + "repo": repo_s, + }, + } + ) + for sym in imports or []: + s = str(sym).strip() + if not s: + continue + edges.append( + { + "id": 
_edge_id(EDGE_TYPE_IMPORTS, repo_s, norm_path, s), + "payload": { + "caller_path": norm_path, + "callee_symbol": s, + "edge_type": EDGE_TYPE_IMPORTS, + "repo": repo_s, + }, + } + ) + return edges + + +def upsert_file_edges( + client: Any, + base_collection: str, + *, + caller_path: str, + repo: str | None, + calls: List[str] | None = None, + imports: List[str] | None = None, +) -> int: + graph_coll = ensure_graph_collection(client, base_collection) + if not graph_coll: + return 0 + edges = _iter_edges( + caller_path=caller_path, + repo=repo or "default", + calls=calls or [], + imports=imports or [], + ) + if not edges: + return 0 + + from qdrant_client import models as qmodels + + points = [ + qmodels.PointStruct( + id=e["id"], + vector=_edge_vector_for_upsert(graph_coll), + payload=e["payload"], + ) + for e in edges + ] + try: + client.upsert(collection_name=graph_coll, points=points, wait=True) + return len(points) + except Exception as e: + logger.debug("Graph edge upsert failed for %s: %s", caller_path, e) + return 0 + + +def delete_edges_by_path( + client: Any, + base_collection: str, + *, + caller_path: str, + repo: str | None = None, +) -> int: + from qdrant_client.http.exceptions import UnexpectedResponse + graph_coll = get_graph_collection_name(base_collection) + if graph_coll in _MISSING_GRAPH_COLLECTIONS: + return 0 + + from qdrant_client import models as qmodels + + norm_path = _normalize_path(caller_path) + must: list[Any] = [ + qmodels.FieldCondition( + key="caller_path", match=qmodels.MatchValue(value=norm_path) + ) + ] + if repo: + r = str(repo).strip() + if r and r != "*": + must.append( + qmodels.FieldCondition(key="repo", match=qmodels.MatchValue(value=r)) + ) + flt = qmodels.Filter(must=must) + + # Probe first so callers can distinguish "no matching rows" (0) from a real delete. + # This is important for fallback logic (e.g., retry path-only delete when repo tag drifted). 
+ try: + existing, _ = client.scroll( + collection_name=graph_coll, + scroll_filter=flt, + limit=1, + with_payload=False, + with_vectors=False, + ) + if not existing: + return 0 + except UnexpectedResponse as e: + if getattr(e, "status_code", None) == 404: + _MISSING_GRAPH_COLLECTIONS.add(graph_coll) + return 0 + logger.debug( + "Graph edge probe failed for %s in %s (status=%s): %s", + norm_path, + graph_coll, + getattr(e, "status_code", None), + e, + exc_info=True, + ) + return 0 + except Exception as e: + logger.debug( + "Graph edge probe failed for %s in %s: %s", + norm_path, + graph_coll, + e, + exc_info=True, + ) + return 0 + + try: + resp = client.delete( + collection_name=graph_coll, + points_selector=qmodels.FilterSelector(filter=flt), + ) + result_status = getattr(getattr(resp, "result", None), "status", None) + if result_status is None: + result_status = getattr(resp, "status", None) + if result_status is None: + return 1 + status_s = str(result_status).strip().lower() + return 1 if status_s in {"acknowledged", "completed", "ok", "success"} else 0 + except UnexpectedResponse as e: + if getattr(e, "status_code", None) == 404: + _MISSING_GRAPH_COLLECTIONS.add(graph_coll) + return 0 + logger.debug( + "Graph edge delete failed for %s in %s (status=%s): %s", + norm_path, + graph_coll, + getattr(e, "status_code", None), + e, + exc_info=True, + ) + return 0 + except Exception as e: + logger.debug( + "Graph edge delete failed for %s in %s: %s", + norm_path, + graph_coll, + e, + exc_info=True, + ) + return 0 + + +def graph_edges_backfill_tick( + client: Any, + base_collection: str, + *, + repo_name: str | None = None, + max_files: int = 128, +) -> int: + """Best-effort incremental backfill from `` into `_graph`. + + This scans the main collection and upserts file-level edges into the graph collection. + It's idempotent (deterministic IDs) and safe to run continuously in a watcher worker. 
+ """ + from qdrant_client import models as qmodels + + if not base_collection or max_files <= 0: + return 0 + + graph_coll = ensure_graph_collection(client, base_collection) + if not graph_coll: + return 0 + + must: list[Any] = [] + if repo_name: + must.append( + qmodels.FieldCondition( + key="metadata.repo", match=qmodels.MatchValue(value=repo_name) + ) + ) + flt = qmodels.Filter(must=must or None) + + processed_files = 0 + seen_paths: set[str] = set() + + key = (base_collection, repo_name) + next_offset = _BACKFILL_OFFSETS.get(key) + + # We may need to overscan because the main collection is chunked. + overscan = max_files * 8 + while processed_files < max_files: + attempts = 0 + while True: + try: + points, next_offset = client.scroll( + collection_name=base_collection, + scroll_filter=flt, + limit=min(64, overscan), + with_payload=True, + with_vectors=False, + offset=next_offset, + ) + break + except Exception as e: + attempts += 1 + logger.exception( + "Graph edge backfill scroll failed (collection=%s repo=%s offset=%s attempt=%d): %s", + base_collection, + repo_name or "default", + next_offset, + attempts, + e, + ) + # Retry a couple times for transient errors, then raise so failures are not silent. 
+ if attempts >= 3: + raise + import time + + time.sleep(0.25 * (2 ** (attempts - 1))) + + if not points: + break + + for rec in points: + if processed_files >= max_files: + break + payload = getattr(rec, "payload", None) or {} + md = payload.get("metadata") or {} + path = md.get("path") or "" + if not path: + continue + norm_path = _normalize_path(str(path)) + if norm_path in seen_paths: + continue + seen_paths.add(norm_path) + + repo = md.get("repo") or repo_name or "default" + calls = md.get("calls") or [] + imports = md.get("imports") or [] + if not isinstance(calls, list): + calls = [] + if not isinstance(imports, list): + imports = [] + + upsert_file_edges( + client, + base_collection, + caller_path=norm_path, + repo=str(repo), + calls=[str(x) for x in calls if x], + imports=[str(x) for x in imports if x], + ) + processed_files += 1 + + if next_offset is None: + break + + _BACKFILL_OFFSETS[key] = next_offset + return processed_files + + +__all__ = [ + "GRAPH_COLLECTION_SUFFIX", + "EDGE_TYPE_CALLS", + "EDGE_TYPE_IMPORTS", + "get_graph_collection_name", + "ensure_graph_collection", + "upsert_file_edges", + "delete_edges_by_path", + "graph_edges_backfill_tick", +] diff --git a/scripts/ingest/pipeline.py b/scripts/ingest/pipeline.py index 45716049..fb02eb6c 100644 --- a/scripts/ingest/pipeline.py +++ b/scripts/ingest/pipeline.py @@ -14,7 +14,18 @@ from pathlib import Path from typing import List, Dict, Any, Optional, TYPE_CHECKING -from qdrant_client import QdrantClient, models +if TYPE_CHECKING: + from qdrant_client import QdrantClient, models as models +else: + QdrantClient = Any # type: ignore + + class _LazyQdrantModels: + def __getattr__(self, name: str) -> Any: + from qdrant_client import models as _models + + return getattr(_models, name) + + models = _LazyQdrantModels() from scripts.ingest.config import ( ROOT_DIR, @@ -86,11 +97,13 @@ def _detect_repo_name_from_path(path: Path) -> str: """Wrapper function to use workspace_state repository detection.""" - 
try: - from scripts.workspace_state import _extract_repo_name_from_path as _ws_detect - return _ws_detect(str(path)) - except ImportError: - return path.name if path.is_dir() else path.parent.name + from scripts.workspace_state import _extract_repo_name_from_path as _ws_detect + + # `_extract_repo_name_from_path` expects a workspace/repo path shape, not a file path. + # Always normalize file inputs to their parent directory to avoid falling back to + # file basenames (which can poison metadata.repo and graph edge repo tags). + candidate = path if path.is_dir() else path.parent + return _ws_detect(str(candidate)) def detect_language(path: Path) -> str: @@ -109,7 +122,8 @@ def detect_language(path: Path) -> str: _TEXT_LIKE_LANGS = {"unknown", "markdown", "text"} -def _is_text_like_language(language: str) -> bool: +def is_text_like_language(language: str) -> bool: + """Classify whether a detected language should skip smart reindexing.""" return str(language or "").strip().lower() in _TEXT_LIKE_LANGS @@ -227,6 +241,87 @@ def _normalize_info_for_dense(s: str) -> str: return text +def _sync_graph_edges_best_effort( + client: QdrantClient, + collection: str, + file_path: str, + repo: str | None, + calls: list[str] | None, + imports: list[str] | None, +) -> None: + """Best-effort sync of file-level graph edges. Safe to skip on failure.""" + enabled = str(os.environ.get("GRAPH_EDGES_ENABLE", "1") or "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + if not enabled: + return + try: + from scripts.ingest.graph_edges import ( + delete_edges_by_path, + ensure_graph_collection, + upsert_file_edges, + ) + + ensure_graph_collection(client, collection) + # Important: delete stale edges for this file before upserting the new set. 
+ delete_edges_by_path( + client, + collection, + caller_path=str(file_path), + repo=repo, + ) + upsert_file_edges( + client, + collection, + caller_path=str(file_path), + repo=repo, + calls=calls, + imports=imports, + ) + except Exception as exc: + try: + print(f"[graph_edges] best-effort sync failed for {file_path}: {exc}") + except Exception: + pass + + +def _symbols_to_metadata_dict(language: str, text: str) -> dict: + """Build symbol metadata dict from in-memory source text.""" + symbols = {} + try: + symbols_list = _extract_symbols(language, text) + lines = text.split("\n") + for sym in symbols_list or []: + kind = str(sym.get("kind") or "") + name = str(sym.get("name") or "") + start = int(sym.get("start") or 0) + end = int(sym.get("end") or 0) + if not kind or not name or start <= 0 or end < start: + continue + symbol_id = f"{kind}_{name}_{start}" + content = "\n".join(lines[start - 1 : end]) + content_hash = hashlib.sha1( + content.encode("utf-8", errors="ignore") + ).hexdigest() + symbols[symbol_id] = { + "name": name, + "type": kind, + "start_line": start, + "end_line": end, + "content_hash": content_hash, + "content": content, + "pseudo": "", + "tags": [], + "qdrant_ids": [], + } + except Exception: + return {} + return symbols + + def build_information( language: str, path: Path, start: int, end: int, first_line: str ) -> str: @@ -251,16 +346,31 @@ def index_single_file( repo_name_for_cache: str | None = None, allowed_vectors: set[str] | None = None, allowed_sparse: set[str] | None = None, + preloaded_text: str | None = None, + preloaded_file_hash: str | None = None, + preloaded_language: str | None = None, ) -> bool: """Index a single file path. 
Returns True if indexed, False if skipped.""" + repo_for_graph = repo_name_for_cache or _detect_repo_name_from_path(file_path) try: if _should_skip_explicit_file_by_excluder(file_path): try: delete_points_by_path(client, collection, str(file_path)) except Exception: pass + # Clean up graph edges for excluded file + _sync_graph_edges_best_effort( + client, + collection, + str(file_path), + repo_for_graph, + None, # No calls when file is excluded + None, # No imports when file is excluded + ) print(f"Skipping excluded file: {file_path}") return False + except NameError: + raise except Exception: return False @@ -283,6 +393,9 @@ def index_single_file( repo_name_for_cache=repo_name_for_cache, allowed_vectors=allowed_vectors, allowed_sparse=allowed_sparse, + preloaded_text=preloaded_text, + preloaded_file_hash=preloaded_file_hash, + preloaded_language=preloaded_language, ) finally: if _file_lock_ctx is not None: @@ -306,6 +419,9 @@ def _index_single_file_inner( repo_name_for_cache: str | None = None, allowed_vectors: set[str] | None = None, allowed_sparse: set[str] | None = None, + preloaded_text: str | None = None, + preloaded_file_hash: str | None = None, + preloaded_language: str | None = None, ) -> bool: """Inner implementation of index_single_file (after lock is acquired).""" if trust_cache is None: @@ -317,7 +433,12 @@ def _index_single_file_inner( trust_cache = False fast_fs = _env_truthy(os.environ.get("INDEX_FS_FASTPATH"), False) - if skip_unchanged and fast_fs and get_cached_file_meta is not None: + if ( + preloaded_text is None + and skip_unchanged + and fast_fs + and get_cached_file_meta is not None + ): try: repo_for_cache = repo_name_for_cache or _detect_repo_name_from_path(file_path) meta = get_cached_file_meta(str(file_path), repo_for_cache) or {} @@ -333,15 +454,17 @@ def _index_single_file_inner( except Exception: pass - try: - text = file_path.read_text(encoding="utf-8", errors="ignore") - except Exception as e: - print(f"Skipping {file_path}: {e}") - 
return False + if preloaded_text is None: + try: + text = file_path.read_text(encoding="utf-8", errors="ignore") + except Exception as e: + print(f"Skipping {file_path}: {e}") + return False + else: + text = preloaded_text - language = detect_language(file_path) - is_text_like = _is_text_like_language(language) - file_hash = hashlib.sha1(text.encode("utf-8", errors="ignore")).hexdigest() + language = preloaded_language or detect_language(file_path) + file_hash = preloaded_file_hash or hashlib.sha1(text.encode("utf-8", errors="ignore")).hexdigest() repo_tag = repo_name_for_cache or _detect_repo_name_from_path(file_path) @@ -376,7 +499,10 @@ def _index_single_file_inner( if get_cached_symbols and set_cached_symbols: cached_symbols = get_cached_symbols(str(file_path)) if cached_symbols: - current_symbols = extract_symbols_with_tree_sitter(str(file_path)) + if preloaded_text is not None: + current_symbols = _symbols_to_metadata_dict(language, preloaded_text) + else: + current_symbols = extract_symbols_with_tree_sitter(str(file_path)) _, changed = compare_symbol_changes(cached_symbols, current_symbols) for symbol_data in current_symbols.values(): symbol_id = f"{symbol_data['type']}_{symbol_data['name']}_{symbol_data['start_line']}" @@ -653,6 +779,22 @@ def make_point(pid, dense_vec, lex_vec, payload, lex_text: str = "", code_text: for i, v, lx, m, lt, ct in zip(batch_ids, vectors, batch_lex, batch_meta, batch_lex_text, batch_code) ] upsert_points(client, collection, points) + + # Optional: materialize file-level graph edges in a companion `_graph` store. + # This is an accelerator for symbol_graph callers/importers and is safe to skip on failure. + # IMPORTANT: Sync must run after upserts (or after delete-only reindex) to ensure graph + # edges stay consistent. When a file reindexes to zero chunks, batch_texts is empty but + # we still need to sync graph edges to remove stale entries. 
+ _sync_graph_edges_best_effort( + client, + collection, + str(file_path), + repo_tag, + calls, + imports, + ) + + if batch_texts: try: ws = os.environ.get("WATCH_ROOT") or os.environ.get("WORKSPACE_PATH") or "/work" if set_cached_file_hash: @@ -717,14 +859,10 @@ def index_repo( except Exception: pass - try: - from scripts.embedder import get_embedding_model, get_model_dimension - model = get_embedding_model(model_name) - dim = get_model_dimension(model_name) - except ImportError: - from fastembed import TextEmbedding - model = TextEmbedding(model_name=model_name) - dim = len(next(model.embed(["dimension probe"]))) + from scripts.embedder import get_embedding_model, get_model_dimension + + model = get_embedding_model(model_name) + dim = get_model_dimension(model_name) client = QdrantClient( url=qdrant_url, @@ -881,6 +1019,7 @@ def process_file_with_smart_reindexing( model, vector_name: str | None, *, + model_dim: int | None = None, allowed_vectors: set[str] | None = None, allowed_sparse: set[str] | None = None, ) -> str: @@ -890,18 +1029,10 @@ def process_file_with_smart_reindexing( - Reusing existing embeddings/lexical vectors for unchanged chunks (by code content), and - Re-embedding only for changed chunks. """ - # Allow test monkeypatching on ingest_code.* to be honored here. - # Must be done FIRST before any helper calls. 
- _ingest_mod = None - try: - import importlib - _ingest_mod = importlib.import_module("scripts.ingest_code") - except Exception: - _ingest_mod = None - _embed_batch = getattr(_ingest_mod, "embed_batch", embed_batch) if _ingest_mod else embed_batch - _upsert_points_fn = getattr(_ingest_mod, "upsert_points", upsert_points) if _ingest_mod else upsert_points - _delete_points_fn = getattr(_ingest_mod, "delete_points_by_path", delete_points_by_path) if _ingest_mod else delete_points_by_path - _should_process_pseudo = getattr(_ingest_mod, "should_process_pseudo_for_chunk", should_process_pseudo_for_chunk) if _ingest_mod else should_process_pseudo_for_chunk + _embed_batch = embed_batch + _upsert_points_fn = upsert_points + _delete_points_fn = delete_points_by_path + _should_process_pseudo = should_process_pseudo_for_chunk try: p = Path(str(file_path)) @@ -910,8 +1041,19 @@ def process_file_with_smart_reindexing( _delete_points_fn(client, current_collection, str(p)) except Exception: pass + # Clean up graph edges for excluded file + _sync_graph_edges_best_effort( + client, + current_collection, + str(p), + per_file_repo or _detect_repo_name_from_path(file_path), + None, # No calls when file is excluded + None, # No imports when file is excluded + ) print(f"[SMART_REINDEX] Skipping excluded file: {file_path}") return "skipped" + except NameError: + raise except Exception: return "skipped" @@ -927,6 +1069,13 @@ def process_file_with_smart_reindexing( except Exception: file_path = Path(fp) + is_text_like = is_text_like_language(language) + if is_text_like: + print( + f"[SMART_REINDEX] {file_path}: text-like language '{language}', " + "skipping smart reindex and using full reindex path" + ) + return "failed" file_hash = hashlib.sha1(text.encode("utf-8", errors="ignore")).hexdigest() if allowed_vectors is None and allowed_sparse is None: @@ -988,8 +1137,31 @@ def process_file_with_smart_reindexing( changed_set = set(changed_symbols) if len(changed_symbols) == 0 and 
cached_symbols: - print(f"[SMART_REINDEX] {file_path}: 0 changes detected, skipping") - return "skipped" + prev_hash = None + try: + if get_cached_file_hash: + prev_hash = get_cached_file_hash(fp, per_file_repo) + except Exception: + prev_hash = None + if prev_hash and file_hash and prev_hash == file_hash: + print(f"[SMART_REINDEX] {file_path}: 0 changes detected, skipping") + return "skipped" + print( + f"[SMART_REINDEX] {file_path}: non-symbol change detected; " + "falling back to full reindex" + ) + return "failed" + + if model_dim and vector_name: + try: + ensure_collection_and_indexes_once( + client, + current_collection, + int(model_dim), + vector_name, + ) + except Exception: + pass existing_points = [] try: @@ -1085,7 +1257,6 @@ def process_file_with_smart_reindexing( else: chunks = chunk_lines(text, CHUNK_LINES, CHUNK_OVERLAP) - is_text_like = _is_text_like_language(language) symbol_spans = _extract_symbols(language, text) reused_points: list[models.PointStruct] = [] @@ -1102,6 +1273,30 @@ def process_file_with_smart_reindexing( pseudo_batch_concurrency = int(os.environ.get("PSEUDO_BATCH_CONCURRENCY", "1") or 1) use_batch_pseudo = pseudo_batch_concurrency > 1 + def _apply_symbol_pseudo( + symbol_name: str, + kind: str, + start_line: int, + pseudo_text: str, + pseudo_tags: list[str], + ) -> None: + if not symbol_name or not kind: + return + sid = f"{kind}_{symbol_name}_{start_line}" + target = symbol_meta.get(sid) + if target is None: + for candidate in symbol_meta.values(): + if str(candidate.get("type") or "") != str(kind): + continue + if str(candidate.get("name") or "") != str(symbol_name): + continue + target = candidate + break + if target is None: + return + target["pseudo"] = pseudo_text + target["tags"] = list(pseudo_tags or []) + chunk_data_sr: list[dict] = [] for ch in chunks: info = build_information( @@ -1190,6 +1385,14 @@ def process_file_with_smart_reindexing( start_line = ch.get("start", 0) sid = f"{k}_{symbol_name}_{start_line}" 
set_cached_pseudo(fp, sid, pseudo, tags, file_hash) + _apply_symbol_pseudo( + symbol_name, + ch.get("kind", "unknown"), + ch.get("start", 0), + pseudo, + tags, + ) + ch["_pseudo_applied"] = True except Exception as e: print(f"[PSEUDO_BATCH] Smart reindex batch failed, falling back: {e}") use_batch_pseudo = False @@ -1211,9 +1414,26 @@ def process_file_with_smart_reindexing( sid = f"{k}_{symbol_name}_{start_line}" if set_cached_pseudo: set_cached_pseudo(fp, sid, pseudo, tags, file_hash) + _apply_symbol_pseudo( + symbol_name, + k, + start_line, + pseudo, + tags, + ) + ch["_pseudo_applied"] = True except Exception: pass + if (pseudo or tags) and not ch.get("_pseudo_applied"): + _apply_symbol_pseudo( + ch.get("symbol", ""), + ch.get("kind", "unknown"), + ch.get("start", 0), + pseudo, + tags, + ) + if pseudo: payload["pseudo"] = pseudo if tags: @@ -1368,6 +1588,19 @@ def process_file_with_smart_reindexing( if all_points: _upsert_points_fn(client, current_collection, all_points) + # Optional: materialize file-level graph edges (best-effort). + # IMPORTANT: Sync must run after upserts OR after delete-only reindex to ensure graph + # edges stay consistent. When a file reindexes to zero chunks, all_points is empty but + # we still need to sync graph edges to remove stale entries. 
+ _sync_graph_edges_best_effort( + client, + current_collection, + str(file_path), + per_file_repo, + calls, + imports, + ) + try: if set_cached_symbols: set_cached_symbols(fp, symbol_meta, file_hash) diff --git a/scripts/ingest/pseudo.py b/scripts/ingest/pseudo.py index ea157e2b..0b02db2e 100644 --- a/scripts/ingest/pseudo.py +++ b/scripts/ingest/pseudo.py @@ -7,11 +7,13 @@ """ from __future__ import annotations +import logging import os from typing import Tuple, List from scripts.ingest.config import ( get_cached_pseudo, + get_cached_symbols, set_cached_pseudo, compare_symbol_changes, ) @@ -130,25 +132,58 @@ def should_process_pseudo_for_chunk( start_line = chunk.get("start", 0) symbol_id = f"{kind}_{symbol_name}_{start_line}" + def _lookup_cached() -> Tuple[str, List[str]]: + if get_cached_pseudo: + try: + cached_pseudo, cached_tags = get_cached_pseudo(file_path, symbol_id) + if cached_pseudo or cached_tags: + return cached_pseudo, cached_tags + except Exception as exc: + logging.getLogger(__name__).debug( + "get_cached_pseudo failed for %s/%s: %s", + file_path, + symbol_id, + exc, + exc_info=True, + ) + if get_cached_symbols: + try: + cached_symbols = get_cached_symbols(file_path) or {} + for info in cached_symbols.values(): + if str(info.get("type") or "") != str(kind): + continue + if str(info.get("name") or "") != str(symbol_name): + continue + cached_pseudo = info.get("pseudo", "") + cached_tags = info.get("tags", []) + if not isinstance(cached_pseudo, str): + cached_pseudo = "" + if not isinstance(cached_tags, list): + cached_tags = [] + cached_tags = [str(tag) for tag in cached_tags if str(tag)] + if cached_pseudo or cached_tags: + return cached_pseudo, cached_tags + except Exception as exc: + logging.getLogger(__name__).debug( + "get_cached_symbols failed for %s: %s", + file_path, + exc, + exc_info=True, + ) + return "", [] + # If we don't have any change information, best effort: try reusing cached pseudo when present - if not changed_symbols and 
get_cached_pseudo: - try: - cached_pseudo, cached_tags = get_cached_pseudo(file_path, symbol_id) - if cached_pseudo or cached_tags: - return False, cached_pseudo, cached_tags - except Exception: - pass + if not changed_symbols: + cached_pseudo, cached_tags = _lookup_cached() + if cached_pseudo or cached_tags: + return False, cached_pseudo, cached_tags return True, "", [] # Unchanged symbol: prefer reuse when cached pseudo/tags exist if symbol_id not in changed_symbols: - if get_cached_pseudo: - try: - cached_pseudo, cached_tags = get_cached_pseudo(file_path, symbol_id) - if cached_pseudo or cached_tags: - return False, cached_pseudo, cached_tags - except Exception: - pass + cached_pseudo, cached_tags = _lookup_cached() + if cached_pseudo or cached_tags: + return False, cached_pseudo, cached_tags # Unchanged but no cached data yet – process once return True, "", [] @@ -162,7 +197,6 @@ def should_use_smart_reindexing(file_path: str, file_hash: str) -> Tuple[bool, s Returns: (use_smart, reason) """ - from scripts.ingest.config import get_cached_symbols, compare_symbol_changes from scripts.ingest.symbols import extract_symbols_with_tree_sitter if not _smart_symbol_reindexing_enabled(): diff --git a/scripts/ingest/qdrant.py b/scripts/ingest/qdrant.py index f98207ca..fb4eb957 100644 --- a/scripts/ingest/qdrant.py +++ b/scripts/ingest/qdrant.py @@ -11,9 +11,20 @@ import time import hashlib from pathlib import Path -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, TYPE_CHECKING -from qdrant_client import QdrantClient, models +if TYPE_CHECKING: + from qdrant_client import QdrantClient, models as models +else: + QdrantClient = Any # type: ignore + + class _LazyQdrantModels: + def __getattr__(self, name: str) -> Any: + from qdrant_client import models as _models + + return getattr(_models, name) + + models = _LazyQdrantModels() from scripts.ingest.config import ( LEX_VECTOR_NAME, @@ -31,6 +42,7 @@ # 
--------------------------------------------------------------------------- ENSURED_COLLECTIONS: set[str] = set() ENSURED_COLLECTIONS_LAST_CHECK: dict[str, float] = {} +ENSURED_PAYLOAD_INDEX_COLLECTIONS: set[str] = set() class CollectionNeedsRecreateError(Exception): @@ -535,6 +547,9 @@ def recreate_collection(client: QdrantClient, name: str, dim: int, vector_name: if not name: print("[BUG] recreate_collection called with name=None! Fix the caller - collection name is required.", flush=True) return + ENSURED_COLLECTIONS.discard(name) + ENSURED_COLLECTIONS_LAST_CHECK.pop(name, None) + ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard(name) try: client.delete_collection(name) except Exception: @@ -580,6 +595,20 @@ def recreate_collection(client: QdrantClient, name: str, dim: int, vector_name: def ensure_payload_indexes(client: QdrantClient, collection: str): """Create helpful payload indexes if they don't exist (idempotent).""" + if not collection: + return + + # On memo hit, verify collection still exists and indexes are present + if collection in ENSURED_PAYLOAD_INDEX_COLLECTIONS: + try: + info = client.get_collection(collection) + if not _missing_payload_indexes(info): + # Memo is still valid + return + except Exception: + # Collection doesn't exist or error accessing it; remove from memo + ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard(collection) + for field in PAYLOAD_INDEX_FIELDS: try: client.create_payload_index( @@ -589,6 +618,15 @@ def ensure_payload_indexes(client: QdrantClient, collection: str): ) except Exception: pass + try: + info = client.get_collection(collection) + except Exception: + return + if _missing_payload_indexes(info): + # Do not memoize; a later call should retry. + return + # Even if create_payload_index threw, get_collection confirms indexes exist. 
+ ENSURED_PAYLOAD_INDEX_COLLECTIONS.add(collection) def ensure_collection_and_indexes_once( @@ -629,6 +667,10 @@ def ensure_collection_and_indexes_once( ENSURED_COLLECTIONS_LAST_CHECK.pop(collection, None) except Exception: pass + try: + ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard(collection) + except Exception: + pass ensure_collection(client, collection, dim, vector_name, schema_mode=mode) if mode in {"legacy", "migrate"}: ensure_payload_indexes(client, collection) diff --git a/scripts/ingest/vectors.py b/scripts/ingest/vectors.py index a4905389..27c7bae1 100644 --- a/scripts/ingest/vectors.py +++ b/scripts/ingest/vectors.py @@ -107,12 +107,10 @@ def _get_pattern_tools(): """Lazy load pattern extraction tools.""" global _PATTERN_EXTRACTOR, _PATTERN_ENCODER if _PATTERN_EXTRACTOR is None: - try: - from scripts.pattern_detection import PatternExtractor, PatternEncoder - _PATTERN_EXTRACTOR = PatternExtractor() - _PATTERN_ENCODER = PatternEncoder() - except ImportError: - pass + from scripts.pattern_detection import PatternExtractor, PatternEncoder + + _PATTERN_EXTRACTOR = PatternExtractor() + _PATTERN_ENCODER = PatternEncoder() return _PATTERN_EXTRACTOR, _PATTERN_ENCODER diff --git a/scripts/ingest_code.py b/scripts/ingest_code.py index 2da24911..1d6363b5 100644 --- a/scripts/ingest_code.py +++ b/scripts/ingest_code.py @@ -24,18 +24,12 @@ from __future__ import annotations import os -import sys import hashlib import time from pathlib import Path from datetime import datetime from typing import List, Dict, Any, Optional, TYPE_CHECKING -# Ensure project root is on sys.path when run as a script -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - from qdrant_client import QdrantClient, models # --------------------------------------------------------------------------- @@ -203,6 +197,7 @@ from scripts.ingest.pipeline import ( _detect_repo_name_from_path, + is_text_like_language, detect_language, 
build_information, pseudo_backfill_tick, @@ -212,6 +207,15 @@ index_repo, process_file_with_smart_reindexing, ) + +# --------------------------------------------------------------------------- +# Graph edges (optional accelerator) +# --------------------------------------------------------------------------- +from scripts.ingest.graph_edges import ( + graph_edges_backfill_tick, + delete_edges_by_path as delete_graph_edges_by_path, + upsert_file_edges as upsert_graph_edges_for_file, +) # --------------------------------------------------------------------------- # Re-exports from ingest/cli.py # --------------------------------------------------------------------------- @@ -222,11 +226,9 @@ # --------------------------------------------------------------------------- # Additional imports for backward compatibility # --------------------------------------------------------------------------- -try: - from scripts.embedder import get_embedding_model as _get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False +from scripts.embedder import get_embedding_model as _get_embedding_model + +_EMBEDDER_FACTORY = True if TYPE_CHECKING: from fastembed import TextEmbedding @@ -240,11 +242,9 @@ from scripts.utils import lex_hash_vector_text as _lex_hash_vector_text from scripts.utils import lex_sparse_vector_text as _lex_sparse_vector_text -try: - from scripts.ast_analyzer import get_ast_analyzer, chunk_code_semantically - _AST_ANALYZER_AVAILABLE = True -except ImportError: - _AST_ANALYZER_AVAILABLE = False +from scripts.ast_analyzer import get_ast_analyzer, chunk_code_semantically + +_AST_ANALYZER_AVAILABLE = True try: from tqdm import tqdm @@ -332,12 +332,17 @@ def main(): "embed_batch", # Pipeline "_detect_repo_name_from_path", + "is_text_like_language", "detect_language", "build_information", "index_single_file", "index_repo", "process_file_with_smart_reindexing", "pseudo_backfill_tick", + # Graph edges (optional) + 
"graph_edges_backfill_tick", + "delete_graph_edges_by_path", + "upsert_graph_edges_for_file", # CLI "main", # Backward compat diff --git a/scripts/ingest_history.py b/scripts/ingest_history.py index 3f42715c..522e10ff 100644 --- a/scripts/ingest_history.py +++ b/scripts/ingest_history.py @@ -4,12 +4,11 @@ import subprocess import shlex import hashlib +import logging from typing import List, Dict, Any import re import time import json -import sys -from pathlib import Path from qdrant_client import QdrantClient, models @@ -19,22 +18,13 @@ API_KEY = os.environ.get("QDRANT_API_KEY") REPO_NAME = os.environ.get("REPO_NAME", "workspace") -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -# Import TextEmbedding for type hints and fallback -from fastembed import TextEmbedding - -# Use embedder factory for Qwen3 support; fallback to direct fastembed -try: - from scripts.embedder import get_embedding_model as _get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False +from scripts.embedder import get_embedding_model as _get_embedding_model from scripts.utils import sanitize_vector_name as _sanitize_vector_name +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + def _manifest_run_id(manifest_path: str) -> str: try: @@ -365,30 +355,60 @@ def _ingest_from_manifest( vec_name: str, include_body: bool, per_batch: int, -) -> int: +) -> tuple[int, bool]: try: with open(manifest_path, "r", encoding="utf-8") as f: data = json.load(f) except Exception as e: print(f"Failed to read manifest {manifest_path}: {e}") - return 0 + return 0, False commits = data.get("commits") or [] if not commits: print("No commits in manifest.") - return 0 + return 0, False run_id = _manifest_run_id(manifest_path) mode = str(data.get("mode") or "delta").strip().lower() or "delta" points: List[models.PointStruct] = [] - count = 0 - for c in commits: + total_commits = 
len(commits) + prepared_count = 0 + persisted_count = 0 + invalid_commit_records = 0 + embed_failures = 0 + point_build_failures = 0 + upsert_failures = 0 + processed_count = 0 + progress_step = max(1, total_commits // 10) if total_commits > 0 else 1 + + def _log_progress(force: bool = False) -> None: + if not force and processed_count % progress_step != 0: + return + logger.info( + "[ingest_history] progress run_id=%s processed=%d/%d prepared=%d persisted=%d invalid=%d embed_failures=%d point_failures=%d upsert_failures=%d", + run_id, + processed_count, + total_commits, + prepared_count, + persisted_count, + invalid_commit_records, + embed_failures, + point_build_failures, + upsert_failures, + ) + + for idx, c in enumerate(commits, start=1): + processed_count += 1 try: if not isinstance(c, dict): + invalid_commit_records += 1 + _log_progress() continue commit_id = str(c.get("commit_id") or "").strip() if not commit_id: + invalid_commit_records += 1 + _log_progress() continue author_name = str(c.get("author_name") or "") authored_date = str(c.get("authored_date") or "") @@ -406,7 +426,15 @@ def _ingest_from_manifest( text = build_text(md, include_body=include_body) try: vec = next(model.embed([text])).tolist() - except Exception: + except Exception as e: + embed_failures += 1 + logger.warning( + "[ingest_history] embed failed for commit=%s idx=%d: %s", + commit_id, + idx, + e, + ) + _log_progress() continue goal: str = "" @@ -451,28 +479,96 @@ def _ingest_from_manifest( pid = stable_id(commit_id) pt = models.PointStruct(id=pid, vector={vec_name: vec}, payload=payload) points.append(pt) - count += 1 + prepared_count += 1 if len(points) >= per_batch: - client.upsert(collection_name=COLLECTION, points=points) - points.clear() + batch_size = len(points) + try: + client.upsert(collection_name=COLLECTION, points=points) + persisted_count += batch_size + except Exception as e: + upsert_failures += batch_size + logger.exception( + "[ingest_history] upsert batch failed 
(size=%d): %s", + batch_size, + e, + ) + finally: + points.clear() + _log_progress() except Exception: + point_build_failures += 1 + logger.warning( + "[ingest_history] commit processing failed idx=%d", + idx, + exc_info=True, + ) + _log_progress() continue if points: - client.upsert(collection_name=COLLECTION, points=points) - try: - _prune_old_commit_points(client, run_id, mode=mode) - except Exception: - pass - try: - _cleanup_manifest_files(manifest_path) - except Exception: - pass - print(f"Ingested {count} commits into {COLLECTION} from manifest {manifest_path}.") - return count + batch_size = len(points) + try: + client.upsert(collection_name=COLLECTION, points=points) + persisted_count += batch_size + except Exception as e: + upsert_failures += batch_size + logger.exception( + "[ingest_history] final upsert failed (size=%d): %s", + batch_size, + e, + ) + _log_progress(force=True) + ingest_successful = ( + prepared_count > 0 + and invalid_commit_records == 0 + and embed_failures == 0 + and point_build_failures == 0 + and upsert_failures == 0 + and persisted_count == prepared_count + ) + # Only prune snapshot runs that completed cleanly + prune_safe = mode == "snapshot" and ingest_successful + if prune_safe: + try: + _prune_old_commit_points(client, run_id, mode=mode) + except Exception as e: + logger.warning("[ingest_history] prune failed for run_id=%s: %s", run_id, e) + elif mode == "snapshot": + logger.warning( + "[ingest_history] skipping prune for run_id=%s because the snapshot ingest was incomplete", + run_id, + ) + + # Only cleanup manifest if ingest completed successfully + ingest_complete = ingest_successful + if ingest_complete: + try: + _cleanup_manifest_files(manifest_path) + except Exception as e: + logger.warning("[ingest_history] manifest cleanup failed for %s: %s", manifest_path, e) + else: + logger.warning( + "[ingest_history] keeping manifest %s because ingest was incomplete", + manifest_path, + ) + + logger.info( + "Ingested commits from 
manifest %s into %s: persisted=%d prepared=%d invalid=%d " + "embed_failures=%d point_failures=%d upsert_failures=%d", + manifest_path, + COLLECTION, + persisted_count, + prepared_count, + invalid_commit_records, + embed_failures, + point_build_failures, + upsert_failures, + ) + return persisted_count, ingest_complete def main(): + logging.basicConfig(level=logging.INFO) ap = argparse.ArgumentParser( description="Ingest Git history into Qdrant deterministically" ) @@ -512,16 +608,12 @@ def main(): ) args = ap.parse_args() - # Use embedder factory for Qwen3 support - if _EMBEDDER_FACTORY: - model = _get_embedding_model(MODEL_NAME) - else: - model = TextEmbedding(model_name=MODEL_NAME) + model = _get_embedding_model(MODEL_NAME) vec_name = _sanitize_vector_name(MODEL_NAME) client = QdrantClient(url=QDRANT_URL, api_key=API_KEY or None) if args.manifest_json: - _ingest_from_manifest( + persisted_count, ingest_complete = _ingest_from_manifest( args.manifest_json, model, client, @@ -529,6 +621,8 @@ def main(): args.include_body, args.per_batch, ) + if not ingest_complete: + raise SystemExit(1) return commits = list_commits(args) @@ -537,6 +631,8 @@ def main(): return points: List[models.PointStruct] = [] + persisted_count = 0 + upsert_failures = 0 for sha in commits: md = commit_metadata(sha) text = build_text(md, include_body=args.include_body) @@ -583,11 +679,40 @@ def main(): point = models.PointStruct(id=pid, vector={vec_name: vec}, payload=payload) points.append(point) if len(points) >= args.per_batch: - client.upsert(collection_name=COLLECTION, points=points) - points.clear() + batch_size = len(points) + try: + client.upsert(collection_name=COLLECTION, points=points) + persisted_count += batch_size + except Exception as e: + upsert_failures += batch_size + logger.exception( + "[ingest_history] batch upsert failed collection=%s repo=%s size=%d path=%s: %s", + COLLECTION, + REPO_NAME, + batch_size, + args.path or "", + e, + ) + finally: + points.clear() if points: - 
client.upsert(collection_name=COLLECTION, points=points) - print(f"Ingested {len(commits)} commits into {COLLECTION}.") + final_size = len(points) + try: + client.upsert(collection_name=COLLECTION, points=points) + persisted_count += final_size + except Exception as e: + upsert_failures += final_size + logger.exception( + "[ingest_history] final upsert failed collection=%s repo=%s size=%d path=%s: %s", + COLLECTION, + REPO_NAME, + final_size, + args.path or "", + e, + ) + if upsert_failures: + raise SystemExit(1) + print(f"Ingested {persisted_count} commits into {COLLECTION}.") if __name__ == "__main__": diff --git a/scripts/k8s_uploader.py b/scripts/k8s_uploader.py index 4f2947c3..03e6d1a8 100755 --- a/scripts/k8s_uploader.py +++ b/scripts/k8s_uploader.py @@ -173,8 +173,6 @@ def trigger_indexing( # Build Python command to call qdrant_index via MCP server # Use qdrant_index with subdir parameter to index specific repo python_cmd = f""" -import sys -sys.path.insert(0, '/app') from scripts.mcp_indexer_server import qdrant_index import asyncio import json @@ -309,4 +307,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/scripts/learning_reranker_worker.py b/scripts/learning_reranker_worker.py index 39cff7e5..48c72afe 100644 --- a/scripts/learning_reranker_worker.py +++ b/scripts/learning_reranker_worker.py @@ -26,34 +26,26 @@ import argparse import json import os -import sys import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np -# Add project root to path -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -from scripts.rerank_events import ( +from scripts.rerank_tools.events import ( RERANK_EVENTS_DIR, RERANK_EVENTS_RETENTION_DAYS, read_events, list_event_files, cleanup_old_events, ) -from scripts.rerank_recursive import ( - TinyScorer, - LatentRefiner, - RecursiveReranker, - VICReg, - LearnedProjection, - LearnedHybridWeights, 
- QueryExpander, -) +from scripts.rerank_recursive.expander import QueryExpander +from scripts.rerank_recursive.hybrid_weights import LearnedHybridWeights +from scripts.rerank_recursive.projection import LearnedProjection +from scripts.rerank_recursive.recursive import RecursiveReranker +from scripts.rerank_recursive.refiner import LatentRefiner +from scripts.rerank_recursive.scorer import TinyScorer +from scripts.rerank_recursive.vicreg import VICReg # Configuration BATCH_SIZE = int(os.environ.get("RERANK_LEARNING_BATCH_SIZE", "32")) @@ -540,7 +532,7 @@ def _learn_from_batch(self, events: List[Dict[str, Any]]): def _maybe_fill_teacher_scores(self, events: List[Dict[str, Any]]): """Compute teacher scores for events that don't already have them.""" try: - from scripts.rerank_local import rerank_local + from scripts.rerank_tools.local import rerank_local except Exception: rerank_local = None diff --git a/scripts/mcp_admin_tools.py b/scripts/mcp_admin_tools.py deleted file mode 100644 index c4ea0378..00000000 --- a/scripts/mcp_admin_tools.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. 
See scripts/mcp/admin_tools.py""" -from scripts.mcp_impl.admin_tools import * - diff --git a/scripts/mcp_auth.py b/scripts/mcp_auth.py index 2b13c791..2ae3d220 100644 --- a/scripts/mcp_auth.py +++ b/scripts/mcp_auth.py @@ -1,45 +1,13 @@ import os from typing import Any, Dict, Optional -try: - from scripts.logger import ValidationError -except Exception: - - class ValidationError(Exception): - pass - - -try: - from scripts.auth_backend import ( - AUTH_ENABLED as AUTH_ENABLED_AUTH, - ACL_ALLOW_ALL as ACL_ALLOW_ALL_AUTH, - validate_session as _auth_validate_session, - has_collection_access as _has_collection_access, - ) -except Exception as _auth_backend_import_exc: - _AUTH_BACKEND_IMPORT_ERROR = repr(_auth_backend_import_exc) - AUTH_ENABLED_AUTH = ( - str(os.environ.get("CTXCE_AUTH_ENABLED", "0")).strip().lower() in {"1", "true", "yes", "on"} - ) - ACL_ALLOW_ALL_AUTH = ( - str(os.environ.get("CTXCE_ACL_ALLOW_ALL", "0")).strip().lower() in {"1", "true", "yes", "on"} - ) - - def _auth_validate_session(session_id: str): # type: ignore[no-redef] - if AUTH_ENABLED_AUTH: - raise ValidationError( - f"Auth backend unavailable (import failed): {_AUTH_BACKEND_IMPORT_ERROR}" - ) - return None - - def _has_collection_access( - user_id: str, qdrant_collection: str, permission: str = "read" - ) -> bool: # type: ignore[no-redef] - if AUTH_ENABLED_AUTH: - raise ValidationError( - f"Auth backend unavailable (import failed): {_AUTH_BACKEND_IMPORT_ERROR}" - ) - return True +from scripts.logger import ValidationError +from scripts.auth_backend import ( + AUTH_ENABLED as AUTH_ENABLED_AUTH, + ACL_ALLOW_ALL as ACL_ALLOW_ALL_AUTH, + validate_session as _auth_validate_session, + has_collection_access as _has_collection_access, +) ACL_ENFORCE = ( diff --git a/scripts/mcp_code_signals.py b/scripts/mcp_code_signals.py deleted file mode 100644 index 29027b55..00000000 --- a/scripts/mcp_code_signals.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. 
See scripts/mcp/code_signals.py""" -from scripts.mcp_impl.code_signals import * - diff --git a/scripts/mcp_context_answer.py b/scripts/mcp_context_answer.py deleted file mode 100644 index 2259671b..00000000 --- a/scripts/mcp_context_answer.py +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python3 -""" -Backward-compatibility shim for scripts.mcp_context_answer. - -New location: scripts.mcp.context_answer -""" -from scripts.mcp_impl.context_answer import * # noqa: F401,F403 - diff --git a/scripts/mcp_http_client.py b/scripts/mcp_http_client.py new file mode 100644 index 00000000..5b01a56a --- /dev/null +++ b/scripts/mcp_http_client.py @@ -0,0 +1,153 @@ +"""Small HTTP client for calling MCP tools. + +This is intentionally transport glue, not an intent router. Agent/tool selection +belongs to the MCP client using the exposed tools directly. +""" +from __future__ import annotations + +import json +import time +from typing import Any, Dict, Tuple +from urllib import request + + +def _post_raw(url: str, payload: Dict[str, Any], headers: Dict[str, str], timeout: float = 60.0) -> Tuple[Dict[str, str], bytes]: + req = request.Request(url, method="POST") + for k, v in headers.items(): + req.add_header(k, v) + data = json.dumps(payload).encode("utf-8") + with request.urlopen(req, data=data, timeout=timeout) as resp: + body = resp.read() + hdrs = {k.lower(): v for k, v in resp.headers.items()} + return hdrs, body + + +def _post_raw_retry(url: str, payload: Dict[str, Any], headers: Dict[str, str], timeout: float = 60.0, retries: int = 2, backoff: float = 0.5) -> Tuple[Dict[str, str], bytes]: + last_exc: Exception | None = None + for i in range(max(0, retries) + 1): + try: + return _post_raw(url, payload, headers, timeout=timeout) + except Exception as e: + last_exc = e + if i < retries: + try: + time.sleep(backoff * (2 ** i)) + except Exception: + pass + else: + raise last_exc + raise last_exc or RuntimeError("MCP HTTP request failed") + + +def _parse_stream_or_json(body: bytes) 
-> Dict[str, Any]: + txt = body.decode("utf-8", errors="ignore") + if "data:" in txt and ("event:" in txt or txt.strip().startswith("data:")): + last = None + for line in txt.splitlines(): + if line.startswith("data:"): + last = line[len("data:"):].strip() + if last: + try: + return json.loads(last) + except Exception: + pass + return json.loads(txt) + + +def _filter_args(d: Dict[str, Any]) -> Dict[str, Any]: + return {k: v for k, v in d.items() if v not in (None, "")} + + +def _mcp_handshake(base_url: str, timeout: float = 30.0) -> Dict[str, str]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + init_payload = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "context-engine-http-client", "version": "0.1.0"}, + }, + "id": 1, + } + hdrs, body = _post_raw_retry(base_url, init_payload, headers, timeout=timeout) + sid = hdrs.get("mcp-session-id") or hdrs.get("Mcp-Session-Id") + if not sid: + try: + j = _parse_stream_or_json(body) + sid = j.get("sessionId") + except Exception: + sid = None + if sid: + headers["Mcp-Session-Id"] = sid + try: + _post_raw_retry(base_url, {"jsonrpc": "2.0", "method": "notifications/initialized"}, headers, timeout=timeout) + except Exception: + pass + return headers + + +def _extract_iserror_text(resp: Dict[str, Any]) -> str | None: + try: + r = resp.get("result") or {} + if isinstance(r, dict) and r.get("isError"): + content = r.get("content") + if isinstance(content, list) and content and isinstance(content[0], dict): + if content[0].get("type") == "text": + return content[0].get("text") + except Exception: + pass + return None + + +def call_tool_http(base_url: str, tool_name: str, args: Dict[str, Any], timeout: float = 120.0) -> Dict[str, Any]: + """Call an MCP tool over streamable HTTP.""" + headers = _mcp_handshake(base_url, timeout=min(timeout, 30.0)) + + def _do_call(arguments: 
Dict[str, Any]) -> Dict[str, Any]: + payload = { + "jsonrpc": "2.0", + "id": "mcp-http-client-1", + "method": "tools/call", + "params": {"name": tool_name, "arguments": arguments}, + } + _, body = _post_raw_retry(base_url, payload, headers, timeout=timeout) + return _parse_stream_or_json(body) + + args1 = _filter_args(args or {}) + resp = _do_call({"arguments": args1} if tool_name.endswith("_compat") else args1) + + def _get_structured_error(r: Dict[str, Any]) -> str | None: + try: + rr = r.get("result") or {} + sc = rr.get("structuredContent") or {} + rs = sc.get("result") or {} + err = rs.get("error") + if isinstance(err, str): + return err + except Exception: + pass + return None + + msg = _extract_iserror_text(resp) + serr = _get_structured_error(resp) + if msg: + low = msg.lower() + if ("kwargs" in low) and ("field required" in low or "missing" in low): + return _do_call({"kwargs": args1}) + if ("arguments" in low) and ("field required" in low or "missing" in low): + return _do_call({"arguments": args1}) + if (serr and serr.strip().lower() == "query required") and ("query" in args1 or "queries" in args1): + resp4 = _do_call({"kwargs": args1}) + serr2 = _get_structured_error(resp4) + if not (serr2 and serr2.strip().lower() == "query required"): + return resp4 + resp5 = _do_call({"arguments": {"kwargs": args1}}) + serr3 = _get_structured_error(resp5) + if not (serr3 and serr3.strip().lower() == "query required"): + return resp5 + return _do_call({"arguments": args1}) + return resp diff --git a/scripts/mcp_impl/__init__.py b/scripts/mcp_impl/__init__.py index 4a033f9d..2198055d 100644 --- a/scripts/mcp_impl/__init__.py +++ b/scripts/mcp_impl/__init__.py @@ -1,36 +1,9 @@ -""" -MCP (Model Context Protocol) indexer server package. 
- -This package contains extracted modules from mcp_indexer_server.py: -- utils: Type coercion, JSON parsing, tokenization, env helpers -- toon: TOON output format support -- workspace: Workspace state and collection resolution -- admin_tools: Qdrant admin operations (index, list, status, prune) -- code_signals: Code intent detection -- context_answer: LLM-assisted Q&A with retrieval -- context_search: Blended code + memory search -- query_expand: LLM-assisted query expansion +"""MCP indexer server implementation package. -Usage: - from scripts.mcp_impl import utils, toon, workspace - from scripts.mcp_impl.utils import _coerce_bool, _env_overrides - from scripts.mcp_impl.workspace import _default_collection - from scripts.mcp_impl.context_search import _context_search_impl - from scripts.mcp_impl.query_expand import _expand_query_impl +Submodules are intentionally not imported here. Several helpers pull service +stacks such as Qdrant, FastMCP, or search wiring, and package import should stay +cheap for tests and small utility imports. 
""" -from scripts.mcp_impl import utils -from scripts.mcp_impl import toon -from scripts.mcp_impl import workspace -from scripts.mcp_impl import admin_tools -from scripts.mcp_impl import code_signals -from scripts.mcp_impl import context_answer -from scripts.mcp_impl import context_search -from scripts.mcp_impl import query_expand -from scripts.mcp_impl import search -from scripts.mcp_impl import info_request -from scripts.mcp_impl import memory -from scripts.mcp_impl import search_specialized -from scripts.mcp_impl import search_history __all__ = [ "utils", @@ -42,9 +15,7 @@ "context_search", "query_expand", "search", - "info_request", "memory", - "search_specialized", + "search_profiles", "search_history", ] - diff --git a/scripts/mcp_impl/admin_tools.py b/scripts/mcp_impl/admin_tools.py index f8f39962..f02003b6 100644 --- a/scripts/mcp_impl/admin_tools.py +++ b/scripts/mcp_impl/admin_tools.py @@ -6,7 +6,6 @@ Contains: - Subprocess runner (_run_async) - Embedding model cache (_get_embedding_model) -- Router cache invalidation (_invalidate_router_scratchpad) - Repo detection (_detect_current_repo) Note: The @mcp.tool() decorated functions remain in mcp_indexer_server.py @@ -22,7 +21,6 @@ # Functions "_run_async", "_get_embedding_model", - "_invalidate_router_scratchpad", "_detect_current_repo", "_collection_map_impl", ] @@ -30,6 +28,7 @@ import asyncio import logging import os +import subprocess import threading from pathlib import Path from typing import Any, Dict, List, Optional, Tuple @@ -45,39 +44,10 @@ def _get_embedding_model(model_name: str): - """Get cached embedding model with optional Qwen3 support. + """Get cached embedding model via the centralized embedder factory.""" + from scripts.embedder import get_embedding_model - Uses the centralized embedder factory if available, with fallback - to direct fastembed initialization for backwards compatibility. 
- """ - # Try centralized embedder factory first (supports Qwen3 feature flag) - try: - from scripts.embedder import get_embedding_model - return get_embedding_model(model_name) - except ImportError: - pass - - # Fallback to original implementation - try: - from fastembed import TextEmbedding # type: ignore - except Exception: - raise - - m = _EMBED_MODEL_CACHE.get(model_name) - if m is None: - # Double-checked locking to avoid duplicate inits under concurrency - lock = _EMBED_MODEL_LOCKS.setdefault(model_name, threading.Lock()) - with lock: - m = _EMBED_MODEL_CACHE.get(model_name) - if m is None: - m = TextEmbedding(model_name=model_name) - try: - # Warmup with common patterns to optimize internal caches - _ = list(m.embed(["function", "class", "import", "def", "const"])) - except Exception: - pass - _EMBED_MODEL_CACHE[model_name] = m - return m + return get_embedding_model(model_name) # --------------------------------------------------------------------------- @@ -98,21 +68,6 @@ async def _run_async( return await run_subprocess_async(cmd, timeout=timeout, env=env) -# --------------------------------------------------------------------------- -# Router cache invalidation -# --------------------------------------------------------------------------- -def _invalidate_router_scratchpad(workspace_path: str) -> bool: - """Invalidate any cached router scratchpad for the workspace. - - This is called after indexing operations to ensure the router - picks up new/changed code. Returns True if invalidation occurred. - """ - try: - # Clear any in-memory caches that might be stale - return True - except Exception: - return False - # --------------------------------------------------------------------------- # Repo detection @@ -123,8 +78,7 @@ def _detect_current_repo() -> Optional[str]: Priority: 1. CURRENT_REPO env var (explicitly set) 2. REPO_NAME env var - 3. Detect from /work directory structure (first subdirectory with .git) - 4. Git remote origin name + 3. 
Bindmount git detection when CTXCE_BINDMOUNT_REPO_DETECTION=1 Returns: repo name or None if detection fails """ @@ -134,34 +88,40 @@ def _detect_current_repo() -> Optional[str]: if val: return val - # Try to detect from /work directory + try: + from scripts.workspace_state import bindmount_repo_detection_enabled + + allow_git_detection = bindmount_repo_detection_enabled() + except Exception: + allow_git_detection = False + + if not allow_git_detection: + return None + + # Bindmount detection from /work. Do not guess from invalid/internal + # metadata: a leaked /work/.git must not become repo "work". work_path = Path("/work") if work_path.exists(): try: - # Check for .git in /work itself if (work_path / ".git").exists(): - # Use git to get repo name from remote - try: - import subprocess - result = subprocess.run( - ["git", "-C", str(work_path), "config", "--get", "remote.origin.url"], - capture_output=True, text=True, timeout=5 - ) - if result.returncode == 0 and result.stdout.strip(): - url = result.stdout.strip() - # Extract repo name from URL - name = url.rstrip("/").rsplit("/", 1)[-1] - if name.endswith(".git"): - name = name[:-4] - if name: - return name - except Exception: - pass - # Fallback to directory name - return work_path.name - - # Check subdirectories for repos + result = subprocess.run( + ["git", "-C", str(work_path), "config", "--get", "remote.origin.url"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0 and result.stdout.strip(): + url = result.stdout.strip() + name = url.rstrip("/").rsplit("/", 1)[-1] + if name.endswith(".git"): + name = name[:-4] + if name: + return name + + internal_dirs = {".codebase", ".git", "__pycache__"} for subdir in work_path.iterdir(): + if subdir.name in internal_dirs: + continue if subdir.is_dir() and (subdir / ".git").exists(): return subdir.name except Exception: diff --git a/scripts/mcp_impl/context_answer.py b/scripts/mcp_impl/context_answer.py index 54877976..f17b8793 100644 --- 
a/scripts/mcp_impl/context_answer.py +++ b/scripts/mcp_impl/context_answer.py @@ -50,7 +50,8 @@ _primary_identifier_from_queries, ) from scripts.mcp_impl.workspace import _default_collection -from scripts.logger import safe_int, ValidationError +from scripts.logger import safe_bool, safe_float, safe_int, ValidationError +from scripts.refrag_glm import detect_glm_runtime, get_glm_model_name, get_model_config logger = logging.getLogger(__name__) @@ -114,11 +115,7 @@ def _cleanup_answer(text: str, max_chars: int | None = None) -> str: def _answer_style_guidance() -> str: """Compact instruction to keep answers direct and grounded.""" - try: - from scripts.refrag_glm import detect_glm_runtime - is_glm = detect_glm_runtime() - except ImportError: - is_glm = False + is_glm = detect_glm_runtime() if is_glm: sentence_guidance = "Write a clear, comprehensive answer in 4-8 sentences." @@ -233,11 +230,7 @@ def _answer_style_guidance() -> str: GLM models get more generous guidance (4-8 sentences) since they handle longer outputs better than Granite-4.0-Micro which needs strict 2-4 sentence limits. 
""" - try: - from scripts.refrag_glm import detect_glm_runtime - is_glm = detect_glm_runtime() - except ImportError: - is_glm = False + is_glm = detect_glm_runtime() if is_glm: # GLM models can handle longer, more detailed answers @@ -1013,30 +1006,6 @@ def _ok_lang(it: Dict[str, Any]) -> bool: model=model, repo=repo, # Cross-codebase isolation ) - # Ensure last call reflects tier-2 relaxed filters for introspection/testing - _ = run_hybrid_search( - queries=queries, - limit=int(max(lim, 1)), - per_path=int(max(ppath, 1)), - language=eff_language, - under=override_under or None, - kind=None, - symbol=None, - ext=None, - not_filter=(not_ or kwargs.get("not_") or kwargs.get("not") or None), - case=(case or kwargs.get("case") or None), - path_regex=None, - path_glob=None, - not_glob=eff_not_glob, - expand=False - if did_local_expand - else ( - str(os.environ.get("HYBRID_EXPAND", "0")).strip().lower() - in {"1", "true", "yes", "on"} - ), - model=model, - repo=repo, # Cross-codebase isolation - ) if os.environ.get("DEBUG_CONTEXT_ANSWER"): logger.debug( @@ -1518,58 +1487,56 @@ def _ok_lang(it: Dict[str, Any]) -> bool: # Filter out memory-like items without a valid path to avoid empty citations items = [it for it in items if str(it.get("path") or "").strip()] - # Apply ReFRAG span budgeting to compress context - from scripts.hybrid_search import _merge_and_budget_spans # type: ignore - - try: - if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("BUDGET_BEFORE", extra={"items": len(items)}) - _pairs = {} + if items and all(isinstance(it, dict) and it.get("span_budgeted") for it in items): + budgeted = items + else: try: - # Relax budgets for context_answer unless explicitly disabled via CTX_RELAX_BUDGETS=0 - if str(os.environ.get("CTX_RELAX_BUDGETS", "1")).strip().lower() in { - "1", - "true", - "yes", - "on", - }: - # GLM models have much larger context windows - use higher budgets - try: - from scripts.refrag_glm import detect_glm_runtime + from 
scripts.hybrid_search import _merge_and_budget_spans # type: ignore + + if os.environ.get("DEBUG_CONTEXT_ANSWER"): + logger.debug("BUDGET_BEFORE", extra={"items": len(items)}) + _pairs = {} + try: + # Relax budgets for context_answer unless explicitly disabled via CTX_RELAX_BUDGETS=0 + if str(os.environ.get("CTX_RELAX_BUDGETS", "1")).strip().lower() in { + "1", + "true", + "yes", + "on", + }: + # GLM models have much larger context windows - use higher budgets is_glm = detect_glm_runtime() - except ImportError: - is_glm = False - - if is_glm: - # GLM: 200K context allows much more code context - _default_budget = "8192" # 8x more than Granite - _default_spans = "24" # 3x more spans - else: - # Granite/llamacpp: tighter limits - _default_budget = "1024" - _default_spans = "8" - - _pairs = { - "MICRO_BUDGET_TOKENS": os.environ.get( - "MICRO_BUDGET_TOKENS", _default_budget - ), - "MICRO_OUT_MAX_SPANS": os.environ.get("MICRO_OUT_MAX_SPANS", _default_spans), - } - except Exception: - _pairs = {"MICRO_BUDGET_TOKENS": "5000", "MICRO_OUT_MAX_SPANS": "8"} - with _env_overrides(_pairs): - budgeted = _merge_and_budget_spans(items) - if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("BUDGET_AFTER", extra={"items": len(budgeted)}) - if not budgeted and items: + + if is_glm: + # GLM: 200K context allows much more code context + _default_budget = "8192" # 8x more than Granite + _default_spans = "24" # 3x more spans + else: + # Granite/llamacpp: tighter limits + _default_budget = "1024" + _default_spans = "8" + + _pairs = { + "MICRO_BUDGET_TOKENS": os.environ.get( + "MICRO_BUDGET_TOKENS", _default_budget + ), + "MICRO_OUT_MAX_SPANS": os.environ.get("MICRO_OUT_MAX_SPANS", _default_spans), + } + except Exception: + _pairs = {"MICRO_BUDGET_TOKENS": "5000", "MICRO_OUT_MAX_SPANS": "8"} + with _env_overrides(_pairs): + budgeted = _merge_and_budget_spans(items) + if os.environ.get("DEBUG_CONTEXT_ANSWER"): + logger.debug("BUDGET_AFTER", extra={"items": len(budgeted)}) + if not 
budgeted and items: + if os.environ.get("DEBUG_CONTEXT_ANSWER"): + logger.debug("BUDGET_EMPTY_FALLBACK") + budgeted = items + except (ImportError, AttributeError, KeyError): + logger.warning("Span budgeting failed, using raw items", exc_info=True) if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("BUDGET_EMPTY_FALLBACK") + logger.debug("BUDGET_FAILED", exc_info=True) budgeted = items - except (ImportError, AttributeError, KeyError): - logger.warning("Span budgeting failed, using raw items", exc_info=True) - if os.environ.get("DEBUG_CONTEXT_ANSWER"): - logger.debug("BUDGET_FAILED", exc_info=True) - budgeted = items # Enforce an output max spans knob - do this BEFORE env restore try: @@ -2055,11 +2022,7 @@ def _to_float(v, d): # Granite/llamacpp: use env var or 2000 default # GLM: dynamically use model's max_output_tokens from config - try: - from scripts.refrag_glm import detect_glm_runtime, get_glm_model_name, get_model_config - is_glm = detect_glm_runtime() - except ImportError: - is_glm = False + is_glm = detect_glm_runtime() if is_glm: # Pull dynamic limit from GLM model config (imports already succeeded above) @@ -2502,22 +2465,6 @@ async def _context_answer_impl( import time import asyncio - # Import logger utilities - try: - from scripts.logger import safe_bool, safe_float - except ImportError: - def safe_bool(val, default=False, **kw): - if val is None: - return default - if isinstance(val, bool): - return val - return str(val).strip().lower() in {"1", "true", "yes", "on"} - def safe_float(val, default=0.0, **kw): - try: - return float(val) if val is not None else default - except Exception: - return default - # Get embedding model function if get_embedding_model_fn is None: from scripts.mcp_impl.admin_tools import _get_embedding_model @@ -2784,19 +2731,6 @@ def safe_float(val, default=0.0, **kw): "query": original_queries, } - # Ensure final retrieval call reflects Tier-2 relaxed filters - try: - from scripts.hybrid_search import run_hybrid_search 
as _rh - await asyncio.to_thread( - lambda: _rh( - queries=queries, - limit=int(max(lim, 1)), - per_path=int(max(ppath, 1)), - ) - ) - except Exception: - pass - # Build citations and context payload for the decoder ( citations, @@ -3040,17 +2974,6 @@ def _k(s: Dict[str, Any]): "query": original_queries, } - # Final introspection call - try: - from scripts.hybrid_search import run_hybrid_search as _rh2 - _ = _rh2( - queries=queries, - limit=int(max(lim, 1)), - per_path=int(max(ppath, 1)), - ) - except Exception: - pass - # Optional: provide per-query answers/citations for pack mode answers_by_query = None try: diff --git a/scripts/mcp_impl/context_search.py b/scripts/mcp_impl/context_search.py index 14ffa4a5..701828d4 100644 --- a/scripts/mcp_impl/context_search.py +++ b/scripts/mcp_impl/context_search.py @@ -37,6 +37,7 @@ ) from scripts.mcp_impl.workspace import _default_collection, _MEM_COLL_CACHE from scripts.mcp_impl.toon import _should_use_toon, _format_context_results_as_toon +from scripts.mcp_http_client import call_tool_http # Environment QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333") @@ -681,8 +682,6 @@ def _maybe_dict(val: Any) -> Dict[str, Any]: used_http_fallback = False if not code_hits: try: - from scripts.mcp_router import call_tool_http # type: ignore - base = ( os.environ.get("MCP_INDEXER_HTTP_URL") or "http://localhost:8003/mcp" ).rstrip("/") @@ -1254,4 +1253,3 @@ def push_text( if _should_use_toon(output_format): return _format_context_results_as_toon(ret, compact=bool(eff_compact)) return ret - diff --git a/scripts/mcp_impl/info_request.py b/scripts/mcp_impl/info_request.py deleted file mode 100644 index e9c72601..00000000 --- a/scripts/mcp_impl/info_request.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 -""" -mcp/info_request.py - Info request helpers for MCP indexer server. - -Extracted from mcp_indexer_server.py for better modularity. 
-Contains: -- Helper functions for info_request tool -""" - -from __future__ import annotations - -__all__ = [ - "_extract_symbols_from_query", - "_extract_related_concepts", - "_format_information_field", - "_extract_relationships", - "_calculate_confidence", -] - -import re -import logging -from typing import Any, Dict, List - -logger = logging.getLogger(__name__) - -# Import _split_ident for tokenization -from scripts.mcp_impl.utils import _split_ident - - -def _extract_symbols_from_query(query: str) -> list[str]: - """Extract potential symbol names from a query string.""" - # Match CamelCase, snake_case, or standalone words that look like identifiers - patterns = [ - r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b', # CamelCase - r'\b[a-z_][a-z0-9_]*(?:_[a-z0-9]+)+\b', # snake_case - r'\b(?:def|class|function|method|async)\s+(\w+)', # explicit mentions - ] - symbols = set() - for pat in patterns: - for m in re.finditer(pat, query): - sym = m.group(1) if m.lastindex else m.group(0) - if len(sym) > 2: - symbols.add(sym) - return list(symbols)[:5] # Limit to top 5 - - -def _extract_related_concepts(query: str, results: list) -> list[str]: - """Extract related technical concepts dynamically from results (codebase-agnostic).""" - concepts = set() - - # Extract from results - this works on any codebase - for r in results[:10]: - # From symbols: split CamelCase/snake_case into meaningful parts - sym = r.get("symbol", "") or "" - if sym and len(sym) > 2: - parts = [p for p in re.split(r'(?=[A-Z])|_|-', sym) if p and len(p) > 2] - for part in parts[:3]: - concepts.add(part.lower()) - - # From file paths: extract directory/module names - path = r.get("path", "") or "" - if path: - path_parts = path.replace("\\", "/").split("/") - for pp in path_parts[-3:]: # Last 3 path segments - # Remove extension and split - name = pp.rsplit(".", 1)[0] if "." 
in pp else pp - if name and len(name) > 2 and not name.startswith("_"): - concepts.add(name.lower()) - - # From kind: function, class, method, etc. - kind = r.get("kind", "") or "" - if kind and len(kind) > 2: - concepts.add(kind.lower()) - - # From query: extract significant words (skip common words) - skip_words = {"the", "is", "are", "how", "does", "what", "where", "find", "get", "set", "for", "and", "with"} - query_parts = re.split(r'\W+', query.lower()) - for qp in query_parts: - if qp and len(qp) > 2 and qp not in skip_words: - concepts.add(qp) - - # Sort by frequency in results for relevance - return list(concepts)[:10] - - -def _format_information_field(result: dict) -> str: - """Generate human-readable information field for a result.""" - path = result.get("path", "") - symbol = result.get("symbol", "") - start = result.get("start_line", 0) - end = result.get("end_line", 0) - kind = result.get("kind", "") - - # Get just the filename - filename = path.split("/")[-1] if "/" in path else path - - if symbol and kind: - return f"Found {kind} '{symbol}' in {filename} (lines {start}-{end})" - elif symbol: - return f"Found '{symbol}' in {filename} (lines {start}-{end})" - else: - return f"Found match in {filename} (lines {start}-{end})" - - -def _extract_relationships(result: dict) -> dict: - """Extract relationship metadata (imports, calls) from a result.""" - relations = result.get("relations") or {} - # Get from relations object if present - imports = relations.get("imports") or [] - calls = relations.get("calls") or [] - symbol_path = relations.get("symbol_path") or "" - # Also check top-level metadata (fallback) - if not imports: - imports = result.get("imports") or [] - if not calls: - calls = result.get("calls") or [] - # Get related paths if available - related_paths = result.get("related_paths") or [] - - return { - "imports_from": imports[:10] if imports else [], # Limit to 10 - "calls": calls[:10] if calls else [], - "symbol_path": symbol_path, - 
"related_paths": related_paths[:5] if related_paths else [], - } - - -def _calculate_confidence(query: str, results: list) -> dict: - """Calculate confidence metrics for the search.""" - if not results: - return {"level": "none", "score": 0.0, "reason": "no_results"} - - avg_score = sum(r.get("score", 0) for r in results) / len(results) - top_score = results[0].get("score", 0) if results else 0 - - # Check if query terms match symbols - query_tokens = set(_split_ident(query.lower())) - symbol_matches = sum( - 1 for r in results[:5] - if any(tok in _split_ident((r.get("symbol", "") or "").lower()) - for tok in query_tokens) - ) - - if top_score > 0.8 and symbol_matches > 0: - level = "high" - elif avg_score > 0.6: - level = "medium" - elif results: - level = "low" - else: - level = "none" - - return { - "level": level, - "score": round(avg_score, 3), - "top_score": round(top_score, 3), - "symbol_matches": symbol_matches, - } - diff --git a/scripts/mcp_impl/pattern_search.py b/scripts/mcp_impl/pattern_search.py index 02898c9c..ae5cf33d 100644 --- a/scripts/mcp_impl/pattern_search.py +++ b/scripts/mcp_impl/pattern_search.py @@ -9,13 +9,9 @@ import re from typing import Any, Dict, List, Optional, Union -# Import logger with fallback -try: - from scripts.logger import get_logger - logger = get_logger(__name__) -except ImportError: - import logging - logger = logging.getLogger(__name__) +from scripts.logger import get_logger + +logger = get_logger(__name__) # Import pattern detection components (lazy to avoid startup penalty) _PATTERN_SEARCH_LOADED = False @@ -28,18 +24,14 @@ def _ensure_pattern_search(): global _PATTERN_SEARCH_LOADED, _pattern_search_fn, _search_by_pattern_description_fn if _PATTERN_SEARCH_LOADED: return True - try: - from scripts.pattern_detection.search import ( - pattern_search, - search_by_pattern_description, - ) - _pattern_search_fn = pattern_search - _search_by_pattern_description_fn = search_by_pattern_description - _PATTERN_SEARCH_LOADED = True 
- return True - except ImportError as e: - logger.warning(f"Pattern search not available: {e}") - return False + from scripts.pattern_detection.search import ( + pattern_search, + search_by_pattern_description, + ) + _pattern_search_fn = pattern_search + _search_by_pattern_description_fn = search_by_pattern_description + _PATTERN_SEARCH_LOADED = True + return True # Supported languages for tree-sitter parsing diff --git a/scripts/mcp_impl/search.py b/scripts/mcp_impl/search.py index e4dc3766..a01827d6 100644 --- a/scripts/mcp_impl/search.py +++ b/scripts/mcp_impl/search.py @@ -39,10 +39,14 @@ _tokens_from_queries, safe_int, ) -from scripts.mcp_impl.workspace import _default_collection, _work_script +from scripts.mcp_impl.workspace import _default_collection from scripts.mcp_impl.admin_tools import _detect_current_repo, _run_async -from scripts.mcp_toon import _should_use_toon, _format_results_as_toon +from scripts.mcp_impl.search_profiles import append_profile_globs, normalize_profile +from scripts.mcp_impl.toon import _should_use_toon, _format_results_as_toon from scripts.mcp_auth import require_collection_access as _require_collection_access +from scripts.path_scope import ( + normalize_under as _normalize_under_scope, +) # Constants QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333") @@ -54,6 +58,46 @@ ) +# Fields to strip from results when debug=False (internal/debugging fields) +_DEBUG_RESULT_FIELDS = { + "components", # Internal scoring breakdown (dense_rrf, lexical, fname_boost, etc.) 
+ "doc_id", # Internal benchmark ID (often null/opaque) + "code_id", # Internal benchmark ID (often null/opaque) + "payload", # Duplicates other fields (information, document, pseudo, tags) + "why", # Often empty []; debugging explanation list + "span_budgeted", # Internal budget flag + "relations", # Call graph info (imports, calls) - useful but often noise + "related_paths", # Optional related file paths + "budget_tokens_used", # Internal token accounting + "fname_boost", # Internal boost value (already applied to score) + "host_path", # Internal dual-path (host side) - use path/client_path instead + "container_path", # Internal dual-path (container side) - use path/client_path instead +} + +# Top-level response fields to strip when debug=False +_DEBUG_TOP_LEVEL_FIELDS = { + "rerank_counters", # Internal reranking metrics (inproc_hybrid, timeout, etc.) + "code_signals", # Internal code signal detection results +} + + +def _strip_debug_fields(item: dict, keep_paths: bool = True) -> dict: + """Strip internal/debug fields from a result item. 
+ + Args: + item: Result dict to strip + keep_paths: If True, keep host_path/container_path + + Returns: + New dict with debug fields removed + """ + strip_fields = _DEBUG_RESULT_FIELDS + if keep_paths: + strip_fields = _DEBUG_RESULT_FIELDS - {"host_path", "container_path"} + result = {k: v for k, v in item.items() if k not in strip_fields} + return result + + async def _repo_search_impl( query: Any = None, queries: Any = None, # Alias for query (many clients use this) @@ -71,6 +115,7 @@ async def _repo_search_impl( collection: Any = None, workspace_path: Any = None, mode: Any = None, + profile: Any = None, session: Any = None, ctx: Any = None, # MCP Context (passed from wrapper) # Structured filters (optional; mirrors hybrid_search flags) @@ -89,6 +134,7 @@ async def _repo_search_impl( repo: Any = None, # str, list[str], or "*" to search all repos # Response shaping compact: Any = None, + debug: Any = None, # When True, include verbose internal fields (components, rerank_counters, etc.) output_format: Any = None, # "json" (default) or "toon" for token-efficient format args: Any = None, # Compatibility shim for mcp-remote/Claude wrappers that send args/kwargs kwargs: Any = None, @@ -117,18 +163,26 @@ async def _repo_search_impl( - repo: str or list[str]. Filter by repo name(s). Use "*" to search all repos (disable auto-filter). By default, auto-detects current repo from CURRENT_REPO env and filters to it. Use repo=["frontend","backend"] to search related repos together. - - Filters (optional): language, under (path prefix), kind, symbol, ext, path_regex, + - profile: optional search profile ("tests", "config", "code") that applies useful path constraints. + - Filters (optional): language, under (recursive workspace subtree), kind, symbol, ext, path_regex, path_glob (str or list[str]), not_glob (str or list[str]), not_ (negative text), case. + - debug: bool (default false). When true, includes verbose internal fields like + components, rerank_counters, code_signals. 
Default false saves ~60-80% tokens. Returns: - Dict with keys: - - results: list of {score, path, symbol, start_line, end_line, why[, components][, relations][, related_paths][, snippet]} - - total: int; used_rerank: bool; rerank_counters: dict + - results: list of {score, path, symbol, start_line, end_line[, snippet][, tags][, host_path][, container_path]} + When debug=true, also includes: components, why, relations, related_paths, doc_id, code_id + - total: int; used_rerank: bool - If compact=true (and snippets not requested), results contain only {path,start_line,end_line}. + - If debug=true, response also includes: rerank_counters, code_signals Examples: - path_glob=["scripts/**","**/*.py"], language="python" + - profile="tests" # constrain to test files + - profile="config" # constrain to config files - symbol="context_answer", under="scripts" + - debug=true # Include internal scoring details for query tuning """ sess = require_auth_session_fn(session) if require_auth_session_fn else session @@ -252,11 +306,18 @@ async def _repo_search_impl( case = _extra.get("case") if compact in (None, "") and _extra.get("compact") is not None: compact = _extra.get("compact") + if debug in (None, "") and _extra.get("debug") is not None: + debug = _extra.get("debug") # Optional mode hint: "code_first", "docs_first", "balanced" if ( mode is None or (isinstance(mode, str) and str(mode).strip() == "") ) and _extra.get("mode") is not None: mode = _extra.get("mode") + if ( + profile is None + or (isinstance(profile, str) and str(profile).strip() == "") + ) and _extra.get("profile") is not None: + profile = _extra.get("profile") except Exception: pass @@ -295,7 +356,8 @@ def _to_str(x, default=""): per_path = _to_int(per_path, 2) include_snippet = _to_bool(include_snippet, True) context_lines = _to_int(context_lines, 2) - # Reranker: default ON; can be disabled via env or client args + # Reranker defaults come from the environment, but an explicit request-level + # opt-in/opt-out 
should still be respected by MCP/API callers. rerank_env_default = str( os.environ.get("RERANKER_ENABLED", "1") ).strip().lower() in {"1", "true", "yes", "on"} @@ -366,10 +428,13 @@ def _to_str(x, default=""): except Exception: pass - # 3) Environment default (collection only for now) + # 3) Environment defaults (collection + mode) env_coll = (os.environ.get("DEFAULT_COLLECTION") or os.environ.get("COLLECTION_NAME") or "").strip() if (not coll_hint) and env_coll: coll_hint = env_coll + env_mode = (os.environ.get("REPO_SEARCH_DEFAULT_MODE") or "").strip() + if (not mode_hint) and env_mode: + mode_hint = env_mode # Final fallback env_fallback = (os.environ.get("DEFAULT_COLLECTION") or os.environ.get("COLLECTION_NAME") or "codebase").strip() @@ -390,7 +455,7 @@ def _to_str(x, default=""): under = under_hint language = _to_str(language, "").strip() - under = _to_str(under, "").strip() + under = _normalize_under_scope(_to_str(under, "").strip()) kind = _to_str(kind, "").strip() symbol = _to_str(symbol, "").strip() path_regex = _to_str(path_regex, "").strip() @@ -414,9 +479,12 @@ def _to_str_list(x): path_globs = _to_str_list(path_glob) not_globs = _to_str_list(not_glob) + profile = normalize_profile(profile) ext = _to_str(ext, "").strip() not_ = _to_str(not_, "").strip() case = _to_str(case, "").strip() + if profile: + path_globs = append_profile_globs(path_globs, profile) # Normalize repo filter: str, list[str], or "*" (search all) # Default: auto-detect current repo unless REPO_AUTO_FILTER=0 @@ -440,12 +508,122 @@ def _to_str_list(x): if detected_repo: repo_filter = [detected_repo] + case_sensitive = str(case or "").strip().lower() in { + "sensitive", + "true", + "1", + "yes", + "on", + } + path_globs_norm = [g if case_sensitive else g.lower() for g in path_globs] + not_globs_norm = [g if case_sensitive else g.lower() for g in not_globs] + + def _norm_case(v: str) -> str: + return v if case_sensitive else v.lower() + + def _match_glob(glob_pat: str, path_val: str) -> 
bool: + import fnmatch as _fnm + if not glob_pat: + return False + p = _norm_case(path_val).replace("\\", "/").strip("/") + if _fnm.fnmatchcase(p, glob_pat): + return True + # Allow repo-relative globs (e.g., scripts/**) to match absolute paths + # by testing suffix windows of the normalized path. + if not glob_pat.startswith("/") and "/" in p: + parts = [seg for seg in p.split("/") if seg] + for i in range(1, len(parts)): + tail = "/".join(parts[i:]) + if _fnm.fnmatchcase(tail, glob_pat): + return True + return False + + def _result_passes_path_filters(item: dict) -> bool: + import re as _re + + path = str(item.get("path") or "") + if not path: + return False + + # Evaluate filters against all known path forms carried by this result. + path_vals = [] + for key in ("path", "rel_path", "client_path", "host_path", "container_path"): + v = item.get(key) + if isinstance(v, str) and v.strip(): + path_vals.append(v.strip().replace("\\", "/")) + if not path_vals: + path_vals = [path] + if path.startswith("/work/"): + path_vals.append(path[len("/work/") :]) + + # Deduplicate while preserving order. + seen = set() + norm_paths = [] + for pv in path_vals: + if pv not in seen: + norm_paths.append(pv) + seen.add(pv) + + if not_: + needle = _norm_case(str(not_)) + if any(needle in _norm_case(pv) for pv in norm_paths): + return False + + if ext: + ext_norm = str(ext).lower().lstrip(".") + if not any(_norm_case(pv).endswith("." 
+ ext_norm) for pv in norm_paths): + return False + + if path_regex: + flags = 0 if case_sensitive else _re.IGNORECASE + try: + if not any(_re.search(path_regex, pv, flags=flags) for pv in norm_paths): + return False + except _re.error as exc: + logger.warning( + "Invalid path_regex filter '%s': %s", + path_regex, + exc, + ) + return False + except Exception as exc: + logger.warning( + "Failed evaluating path_regex filter '%s': %s", + path_regex, + exc, + exc_info=True, + ) + return False + + if path_globs_norm and not any( + _match_glob(g, pv) for g in path_globs_norm for pv in norm_paths + ): + return False + + if not_globs_norm and any( + _match_glob(g, pv) for g in not_globs_norm for pv in norm_paths + ): + return False + + return True + + def _apply_result_filters(items: list[dict]) -> list[dict]: + if not items: + return [] + if not (not_ or path_regex or ext or path_globs_norm or not_globs_norm): + return items + return [it for it in items if _result_passes_path_filters(it)] + compact_raw = compact compact = _to_bool(compact, False) # If snippets are requested, do not compact (we need snippet field in results) if include_snippet: compact = False + # Debug mode: when False (default), strip internal/debug fields from results + # to reduce token bloat. Set debug=True to see components, rerank_counters, etc. + debug = _to_bool(debug, False) + # Default behavior: exclude commit-history docs (which use path=".git") from # generic repo_search calls, unless the caller explicitly asks for git # content. This prevents normal code queries from surfacing commit-index @@ -455,6 +633,7 @@ def _to_str_list(x): ): if ".git" not in not_globs: not_globs.append(".git") + not_globs_norm = [g if case_sensitive else g.lower() for g in not_globs] # Accept top-level alias `queries` as a drop-in for `query` # Many clients send queries=[...] instead of query=[...] 
@@ -548,53 +727,24 @@ def _to_str_list(x): lambda: run_pure_dense_search( query=query_text, limit=eff_limit, + per_path=( + int(per_path) + if (per_path is not None and str(per_path).strip() != "") + else None + ), collection=collection, language=language or None, under=under or None, + kind=kind or None, + symbol=symbol or None, + ext=ext or None, repo=repo_filter, ) ) - # Apply post-filters (path_regex, path_glob, not_glob, not_) that aren't - # supported by run_pure_dense_search's server-side filters - case_sensitive = str(case or "").strip().lower() in {"sensitive", "true", "1", "yes", "on"} - import fnmatch as _fnm - import re as _re - - def _norm_path(p: str) -> str: - return p if case_sensitive else p.lower() - - path_globs_norm = [g if case_sensitive else g.lower() for g in path_globs] - not_globs_norm = [g if case_sensitive else g.lower() for g in not_globs] - path_regex_norm = path_regex or "" - - def _match_glob(glob_pat: str, path_val: str) -> bool: - if not glob_pat: - return False - return _fnm.fnmatchcase(_norm_path(path_val), glob_pat) - for item in items: path = item.get("path") or "" - - # Apply path_regex filter - if path_regex_norm: - flags = 0 if case_sensitive else _re.IGNORECASE - try: - if not _re.search(path_regex_norm, path, flags=flags): - continue - except Exception: - pass - - # Apply path_glob filter - if path_globs_norm and not any(_match_glob(g, path) for g in path_globs_norm): - continue - - # Apply not_glob filter - if not_globs_norm and any(_match_glob(g, path) for g in not_globs_norm): - continue - - # Apply not_ text filter - if not_ and not_.lower() in _norm_path(path): + if not _result_passes_path_filters(item): continue payload = item.get("payload") or {} @@ -640,8 +790,7 @@ def _match_glob(glob_pat: str, path_val: str) -> bool: rt = 0 if rt > eff_limit: eff_limit = rt - # In-process path_glob/not_glob accept a single string; reduce list inputs safely - print(f"[debug] 
DEBUG_SEARCH_TIMING={os.environ.get('DEBUG_SEARCH_TIMING', 'not set')}", flush=True) + # In-process path_glob/not_glob accept list inputs. items = await asyncio.to_thread( lambda: run_hybrid_search( queries=queries, @@ -696,7 +845,8 @@ def _match_glob(glob_pat: str, path_val: str) -> bool: eff_limit = rt cmd = [ "python", - _work_script("hybrid_search.py"), + "-m", + "scripts.hybrid_search", "--limit", str(eff_limit), "--json", @@ -794,7 +944,7 @@ def _match_glob(glob_pat: str, path_val: str) -> bool: if use_learning_rerank and json_lines: try: - from scripts.rerank_recursive import rerank_with_learning + from scripts.rerank_recursive.recursive import rerank_with_learning rq = queries[0] if queries else "" cand_objs = list(json_lines[: int(rerank_top_n)]) @@ -867,7 +1017,7 @@ def _match_glob(glob_pat: str, path_val: str) -> bool: if use_rerank_inproc and not used_rerank: try: if json_lines: - from scripts.rerank_local import rerank_local as _rr_local # type: ignore + from scripts.rerank_tools.local import rerank_local as _rr_local # type: ignore import concurrent.futures as _fut rq = queries[0] if queries else "" @@ -1065,7 +1215,7 @@ def _doc_for(obj: dict) -> str: if not used_rerank: if use_rerank_inproc: try: - from scripts.rerank_local import rerank_in_process # type: ignore + from scripts.rerank_tools.local import rerank_in_process # type: ignore model_name = os.environ.get( "EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5" @@ -1092,7 +1242,8 @@ def _doc_for(obj: dict) -> str: rq = queries[0] if queries else "" rcmd = [ "python", - _work_script("rerank_local.py"), + "-m", + "scripts.rerank_tools.local", "--query", rq, "--topk", @@ -1217,6 +1368,10 @@ def _doc_for(obj: dict) -> str: item["tags"] = obj.get("tags") results.append(item) + # Enforce strict filter semantics regardless of retrieval/rerank branch. + # This closes gaps where fallback rerank paths may bypass path_glob/not_glob. 
+ results = _apply_result_filters(results) + # Mode-aware reordering: nudge core implementation code vs docs and non-core when requested def _is_doc_path(p: str) -> bool: pl = str(p or "").lower() @@ -1491,6 +1646,20 @@ def _read_snip(args): } for r in results ] + elif not debug: + # Strip debug/internal fields from results to reduce token bloat + # Keeps: score, path, host_path, container_path, symbol, snippet, + # start_line, end_line, tags, pseudo + results = [_strip_debug_fields(r) for r in results] + + _res_ok = bool(res.get("ok", True)) if isinstance(res, dict) else True + try: + _res_code = int((res or {}).get("code", 0)) + except Exception: + _res_code = 0 + if results: + _res_ok = True + _res_code = 0 response = { "args": { @@ -1504,6 +1673,7 @@ def _read_snip(args): "rerank_return_m": int(rerank_return_m), "rerank_timeout_ms": int(rerank_timeout_ms), "collection": collection, + "profile": profile, "language": language, "under": under, "kind": kind, @@ -1518,13 +1688,23 @@ def _read_snip(args): "compact": (_to_bool(compact_raw, compact)), }, "used_rerank": bool(used_rerank), - "rerank_counters": rerank_counters, - "code_signals": code_signals if code_signals.get("has_code_signals") else None, "total": len(results), "results": results, - **res, + "ok": _res_ok, + "code": _res_code, } + # Expose a concise failure reason without leaking raw subprocess streams by default. 
+ if (not _res_ok or _res_code != 0) and not results: + response["error"] = "search backend execution failed" + + # Only include debug fields when explicitly requested + if debug: + response["subprocess"] = res + response["rerank_counters"] = rerank_counters + if code_signals.get("has_code_signals"): + response["code_signals"] = code_signals + # Apply TOON formatting if requested or enabled globally # Full mode (compact=False) still saves tokens vs JSON while preserving all fields if _should_use_toon(output_format): diff --git a/scripts/mcp_impl/search_profiles.py b/scripts/mcp_impl/search_profiles.py new file mode 100644 index 00000000..f1f43ea7 --- /dev/null +++ b/scripts/mcp_impl/search_profiles.py @@ -0,0 +1,81 @@ +"""Shared repo_search profile definitions. + +Profiles are intentionally small path constraints, not alternate search +algorithms. They let callers express common scopes without exposing separate +MCP tools for every preset. +""" + +from __future__ import annotations + +from typing import Iterable + +TEST_GLOBS = [ + "tests/**", + "test/**", + "**/*test*.*", + "**/*_test.*", + "**/Test*/**", +] + +CONFIG_GLOBS = [ + "**/*.yml", + "**/*.yaml", + "**/*.json", + "**/*.toml", + "**/*.ini", + "**/*.env", + "**/*.config", + "**/*.conf", + "**/*.properties", + "**/*.csproj", + "**/*.props", + "**/*.targets", + "**/*.xml", + "**/appsettings*.json", +] + +CODE_GLOBS = [ + "**/*.py", + "**/*.js", + "**/*.ts", + "**/*.tsx", + "**/*.jsx", + "**/*.mjs", + "**/*.cjs", + "**/*.go", + "**/*.java", + "**/*.cs", + "**/*.rb", + "**/*.php", + "**/*.rs", + "**/*.c", + "**/*.h", + "**/*.cpp", + "**/*.hpp", +] + +PROFILE_GLOBS = { + "test": TEST_GLOBS, + "tests": TEST_GLOBS, + "config": CONFIG_GLOBS, + "configs": CONFIG_GLOBS, + "code": CODE_GLOBS, +} + + +def normalize_profile(profile: object) -> str: + return str(profile or "").strip().lower().replace("-", "_") + + +def globs_for_profile(profile: object) -> list[str]: + return 
list(PROFILE_GLOBS.get(normalize_profile(profile), [])) + + +def append_profile_globs(path_globs: Iterable[str], profile: object) -> list[str]: + merged = [str(g).strip() for g in path_globs if str(g).strip()] + seen = set(merged) + for glob in globs_for_profile(profile): + if glob not in seen: + merged.append(glob) + seen.add(glob) + return merged diff --git a/scripts/mcp_impl/search_specialized.py b/scripts/mcp_impl/search_specialized.py deleted file mode 100644 index 62f06ef4..00000000 --- a/scripts/mcp_impl/search_specialized.py +++ /dev/null @@ -1,259 +0,0 @@ -#!/usr/bin/env python3 -""" -mcp/search_specialized.py - Specialized search implementations for MCP indexer server. - -Extracted from mcp_indexer_server.py for better modularity. -Contains: -- _search_tests_for_impl: Search for test files -- _search_config_for_impl: Search for config files -- _search_callers_for_impl: Search for callers/usages -- _search_importers_for_impl: Search for importers - -Note: The @mcp.tool() decorated functions remain in mcp_indexer_server.py -as thin wrappers that call these implementations. 
-""" - -from __future__ import annotations - -__all__ = [ - "_search_tests_for_impl", - "_search_config_for_impl", - "_search_callers_for_impl", - "_search_importers_for_impl", -] - -import logging -from typing import Any, Dict, Optional - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Imports from sibling modules -# --------------------------------------------------------------------------- -from scripts.mcp_impl.utils import _extract_kwargs_payload - - -# Test file globs -TEST_GLOBS = [ - "tests/**", - "test/**", - "**/*test*.*", - "**/*_test.*", - "**/Test*/**", -] - -# Config file globs -CONFIG_GLOBS = [ - "**/*.yml", - "**/*.yaml", - "**/*.json", - "**/*.toml", - "**/*.ini", - "**/*.env", - "**/*.config", - "**/*.conf", - "**/*.properties", - "**/*.csproj", - "**/*.props", - "**/*.targets", - "**/*.xml", - "**/appsettings*.json", -] - -# Code file globs for importers -CODE_GLOBS = [ - "**/*.py", - "**/*.js", - "**/*.ts", - "**/*.tsx", - "**/*.jsx", - "**/*.mjs", - "**/*.cjs", - "**/*.go", - "**/*.java", - "**/*.cs", - "**/*.rb", - "**/*.php", - "**/*.rs", - "**/*.c", - "**/*.h", - "**/*.cpp", - "**/*.hpp", -] - - -async def _search_tests_for_impl( - query: Any = None, - limit: Any = None, - include_snippet: Any = None, - context_lines: Any = None, - under: Any = None, - language: Any = None, - session: Any = None, - compact: Any = None, - kwargs: Any = None, - ctx: Any = None, - repo_search_fn=None, -) -> Dict[str, Any]: - """Find test files related to a query. - - What it does: - - Presets common test file globs and forwards to repo_search - - Accepts extra filters via kwargs (e.g., language, under, case) - - Parameters: - - query: str or list[str]; limit; include_snippet/context_lines; under; language; compact - - Returns: repo_search result shape. 
- """ - globs = list(TEST_GLOBS) - # Allow caller to add more with path_glob kwarg - # Handle kwargs being passed as a string by some MCP clients - _kwargs = _extract_kwargs_payload(kwargs) if kwargs else {} - extra_glob = _kwargs.get("path_glob") - if extra_glob: - if isinstance(extra_glob, (list, tuple)): - globs.extend([str(x) for x in extra_glob]) - else: - globs.append(str(extra_glob)) - - if repo_search_fn is None: - from scripts.mcp_impl.search import _repo_search_impl - repo_search_fn = _repo_search_impl - - return await repo_search_fn( - query=query, - limit=limit, - include_snippet=include_snippet, - context_lines=context_lines, - under=under, - language=language, - path_glob=globs, - session=session, - compact=compact, - ctx=ctx, - kwargs={k: v for k, v in _kwargs.items() if k not in {"path_glob"}}, - ) - - -async def _search_config_for_impl( - query: Any = None, - limit: Any = None, - include_snippet: Any = None, - context_lines: Any = None, - under: Any = None, - session: Any = None, - compact: Any = None, - kwargs: Any = None, - ctx: Any = None, - repo_search_fn=None, -) -> Dict[str, Any]: - """Find likely configuration files for a service/query. - - What it does: - - Presets config file globs (yaml/json/toml/etc.) and forwards to repo_search - - Accepts extra filters via kwargs - - Returns: repo_search result shape. 
- """ - globs = list(CONFIG_GLOBS) - # Handle kwargs being passed as a string by some MCP clients - _kwargs = _extract_kwargs_payload(kwargs) if kwargs else {} - extra_glob = _kwargs.get("path_glob") - if extra_glob: - if isinstance(extra_glob, (list, tuple)): - globs.extend([str(x) for x in extra_glob]) - else: - globs.append(str(extra_glob)) - - if repo_search_fn is None: - from scripts.mcp_impl.search import _repo_search_impl - repo_search_fn = _repo_search_impl - - return await repo_search_fn( - query=query, - limit=limit, - include_snippet=include_snippet, - context_lines=context_lines, - under=under, - session=session, - path_glob=globs, - compact=compact, - ctx=ctx, - kwargs={k: v for k, v in _kwargs.items() if k not in {"path_glob"}}, - ) - - -async def _search_callers_for_impl( - query: Any = None, - limit: Any = None, - language: Any = None, - session: Any = None, - kwargs: Any = None, - ctx: Any = None, - repo_search_fn=None, -) -> Dict[str, Any]: - """Heuristic search for callers/usages of a symbol. - - When to use: - - You want files that reference/invoke a function/class - - Notes: - - Thin wrapper over repo_search today; pass language or path_glob to narrow - - Returns repo_search result shape - """ - if repo_search_fn is None: - from scripts.mcp_impl.search import _repo_search_impl - repo_search_fn = _repo_search_impl - - return await repo_search_fn( - query=query, - limit=limit, - language=language, - session=session, - ctx=ctx, - kwargs=kwargs, - ) - - -async def _search_importers_for_impl( - query: Any = None, - limit: Any = None, - language: Any = None, - session: Any = None, - kwargs: Any = None, - ctx: Any = None, - repo_search_fn=None, -) -> Dict[str, Any]: - """Find files likely importing or referencing a module/symbol. - - What it does: - - Presets code globs across common languages; forwards to repo_search - - Accepts additional filters via kwargs (e.g., under, case) - - Returns: repo_search result shape. 
- """ - globs = list(CODE_GLOBS) - # Handle kwargs being passed as a string by some MCP clients - _kwargs = _extract_kwargs_payload(kwargs) if kwargs else {} - extra_glob = _kwargs.get("path_glob") - if extra_glob: - if isinstance(extra_glob, (list, tuple)): - globs.extend([str(x) for x in extra_glob]) - else: - globs.append(str(extra_glob)) - - if repo_search_fn is None: - from scripts.mcp_impl.search import _repo_search_impl - repo_search_fn = _repo_search_impl - - # Forward to repo_search with preset path_glob; caller can still pass other filters - return await repo_search_fn( - query=query, - limit=limit, - language=language, - path_glob=globs, - session=session, - ctx=ctx, - kwargs={k: v for k, v in _kwargs.items() if k not in {"path_glob"}}, - ) diff --git a/scripts/mcp_impl/symbol_graph.py b/scripts/mcp_impl/symbol_graph.py index da518ac4..f3a8f926 100644 --- a/scripts/mcp_impl/symbol_graph.py +++ b/scripts/mcp_impl/symbol_graph.py @@ -20,10 +20,54 @@ import logging import os import re +import time from typing import Any, Dict, List, Optional, Set +from scripts.path_scope import ( + normalize_under as _normalize_under_scope, + metadata_matches_under as _metadata_matches_under, + path_matches_under as _path_matches_under, +) + logger = logging.getLogger(__name__) +try: + from scripts.ingest.graph_edges import GRAPH_COLLECTION_SUFFIX as _GRAPH_SUFFIX +except Exception: + _GRAPH_SUFFIX = "_graph" + +GRAPH_COLLECTION_SUFFIX = _GRAPH_SUFFIX +# Time-based cache: collection -> expiry timestamp (5 minutes TTL) +_MISSING_GRAPH_COLLECTIONS: dict[str, float] = {} +_MISSING_GRAPH_TTL = 300 # 5 minutes + + +def _clean_expired_missing_graphs() -> None: + """Remove expired entries from the missing graph cache.""" + now = time.monotonic() + expired = [coll for coll, expiry in _MISSING_GRAPH_COLLECTIONS.items() if expiry <= now] + for coll in expired: + _MISSING_GRAPH_COLLECTIONS.pop(coll, None) + + +def _is_graph_missing(collection: str) -> bool: + """Check if a graph 
collection is marked as missing (with expiration).""" + _clean_expired_missing_graphs() + if collection in _MISSING_GRAPH_COLLECTIONS: + return _MISSING_GRAPH_COLLECTIONS.get(collection, 0) > time.monotonic() + return False + + +def _mark_graph_missing(collection: str) -> None: + """Mark a graph collection as missing (with TTL).""" + _MISSING_GRAPH_COLLECTIONS[collection] = time.monotonic() + _MISSING_GRAPH_TTL + + +def _clear_graph_missing(collection: str) -> None: + """Remove a collection from the missing graph cache (e.g., after successful creation).""" + _MISSING_GRAPH_COLLECTIONS.pop(collection, None) + + __all__ = [ "_symbol_graph_impl", "_format_symbol_graph_toon", @@ -105,23 +149,18 @@ def _symbol_variants(symbol: str) -> List[str]: return list(dict.fromkeys(variants)) # Dedupe preserving order def _norm_under(u: Optional[str]) -> Optional[str]: - """Normalize an `under` path to match ingest's stored `metadata.path_prefix` values. + """Normalize user-facing `under` to recursive subtree scope token.""" + return _normalize_under_scope(u) - This mirrors the engine's convention: normalize to a /work/... style path. - Note: `under` in this engine is an exact directory filter (not recursive). - """ - if not u: - return None - s = str(u).strip().replace("\\", "/") - s = "/".join([p for p in s.split("/") if p]) - if not s: - return None - # Normalize to /work/... 
- if not s.startswith("/"): - v = "/work/" + s - else: - v = "/work/" + s.lstrip("/") if not s.startswith("/work/") else s - return v.rstrip("/") + +def _point_matches_under(pt: Any, under: Optional[str]) -> bool: + if not under: + return True + payload = getattr(pt, "payload", None) or {} + md = payload.get("metadata", payload) + if not isinstance(md, dict): + md = {} + return _metadata_matches_under(md, under) async def _symbol_graph_impl( @@ -142,7 +181,7 @@ async def _symbol_graph_impl( query_type: One of "callers", "definition", "importers" limit: Maximum number of results language: Optional language filter - under: Optional path prefix filter + under: Optional recursive workspace subtree filter collection: Optional collection override session: Optional session ID for collection routing ctx: MCP context (optional) @@ -193,18 +232,32 @@ async def _symbol_graph_impl( results = [] + norm_under = _norm_under(under) + try: if query_type == "callers": - # Find chunks where metadata.calls array contains the symbol (exact match) - results = await _query_array_field( + # Prefer graph edges collection when available (fast keyword filters). + results = await _query_graph_edges_collection( client=client, collection=coll, - field_key="metadata.calls", - value=symbol, + symbol=symbol, + edge_type="calls", limit=limit, language=language, - under=_norm_under(under), + repo_filter=None, + under=norm_under, ) + if not results: + # Fall back to array field lookup in the main collection. 
+ results = await _query_array_field( + client=client, + collection=coll, + field_key="metadata.calls", + value=symbol, + limit=limit, + language=language, + under=norm_under, + ) elif query_type == "definition": # Find chunks where symbol_path matches the symbol results = await _query_definition( @@ -213,19 +266,30 @@ async def _symbol_graph_impl( symbol=symbol, limit=limit, language=language, - under=_norm_under(under), + under=norm_under, ) elif query_type == "importers": - # Find chunks where metadata.imports array contains the symbol - results = await _query_array_field( + results = await _query_graph_edges_collection( client=client, collection=coll, - field_key="metadata.imports", - value=symbol, + symbol=symbol, + edge_type="imports", limit=limit, language=language, - under=_norm_under(under), + repo_filter=None, + under=norm_under, ) + if not results: + # Fall back to array field lookup in the main collection. + results = await _query_array_field( + client=client, + collection=coll, + field_key="metadata.imports", + value=symbol, + limit=limit, + language=language, + under=norm_under, + ) # If no results, fall back to semantic search if not results: @@ -234,6 +298,7 @@ async def _symbol_graph_impl( query_type=query_type, limit=limit, language=language, + under=norm_under, collection=coll, session=session, ) @@ -246,6 +311,7 @@ async def _symbol_graph_impl( query_type=query_type, limit=limit, language=language, + under=norm_under, collection=coll, session=session, ) @@ -259,6 +325,155 @@ async def _symbol_graph_impl( } +async def _query_graph_edges_collection( + client: Any, + collection: str, + symbol: str, + edge_type: str, + limit: int, + language: Optional[str] = None, + repo_filter: str | None = None, + under: str | None = None, +) -> List[Dict[str, Any]]: + """Query `_graph` and hydrate results from the main collection. 
+ + The graph collection stores file-level edges: + - caller_path -> callee_symbol (calls/imports) + """ + from qdrant_client import models as qmodels + + graph_coll = f"{collection}{GRAPH_COLLECTION_SUFFIX}" + if _is_graph_missing(graph_coll): + return [] + + # Build graph filter + must: list[Any] = [ + qmodels.FieldCondition( + key="edge_type", match=qmodels.MatchValue(value=str(edge_type)) + ) + ] + if repo_filter: + rf = str(repo_filter).strip() + if rf and rf != "*": + must.append( + qmodels.FieldCondition(key="repo", match=qmodels.MatchValue(value=rf)) + ) + + # Try exact match, then symbol variants. + callee_variants = _symbol_variants(symbol) or [symbol] + seen_paths: set[str] = set() + caller_paths: List[str] = [] + + for variant in callee_variants: + if len(caller_paths) >= limit: + break + v = str(variant).strip() + if not v: + continue + flt = qmodels.Filter( + must=must + + [ + qmodels.FieldCondition( + key="callee_symbol", match=qmodels.MatchValue(value=v) + ) + ] + ) + + def _scroll(_flt=flt): + return client.scroll( + collection_name=graph_coll, + scroll_filter=_flt, + limit=max(32, limit * 4), + with_payload=True, + with_vectors=False, + ) + + try: + points, _ = await asyncio.to_thread(_scroll) + except Exception as e: + err = str(e).lower() + if "404" in err or "doesn't exist" in err or "not found" in err: + _mark_graph_missing(graph_coll) + return [] + logger.exception( + "_query_graph_edges_collection scroll failed for %s", graph_coll + ) + raise + + for rec in points or []: + payload = getattr(rec, "payload", None) or {} + p = payload.get("caller_path") or "" + if not p: + continue + path_s = str(p) + if under and not _path_matches_under( + path_s, under, repo_hint=(payload.get("repo") or repo_filter) + ): + continue + if path_s in seen_paths: + continue + seen_paths.add(path_s) + caller_paths.append(path_s) + if len(caller_paths) >= limit: + break + + if not caller_paths: + return [] + + # Hydrate caller paths back into normal symbol_graph 
point-shaped results. + hydrated: List[Dict[str, Any]] = [] + for p in caller_paths[:limit]: + if len(hydrated) >= limit: + break + + def _scroll_main(_p=p, _language=language): + must = [ + qmodels.FieldCondition( + key="metadata.path", match=qmodels.MatchValue(value=_p) + ) + ] + if _language: + must.append( + qmodels.FieldCondition( + key="metadata.language", + match=qmodels.MatchValue(value=str(_language).lower()), + ) + ) + return client.scroll( + collection_name=collection, + scroll_filter=qmodels.Filter( + must=must + ), + limit=1, + with_payload=True, + with_vectors=False, + ) + + try: + pts, _ = await asyncio.to_thread(_scroll_main) + except Exception: + pts = [] + + if pts: + hydrated.append(_format_point(pts[0])) + else: + # If language filtering was requested but no matching main-collection doc + # exists (or hydration failed), skip returning a placeholder to avoid + # producing language-inconsistent results. + if not language: + hydrated.append( + { + "path": p, + "symbol": "", + "symbol_path": "", + "start_line": 0, + "end_line": 0, + } + ) + + return hydrated + + async def _query_array_field( client: Any, collection: str, @@ -290,14 +505,6 @@ async def _query_array_field( match=qmodels.MatchValue(value=language.lower()), ) ) - if under: - base_conditions.append( - qmodels.FieldCondition( - key="metadata.path_prefix", - match=qmodels.MatchValue(value=under), - ) - ) - # Strategy 1: Exact match with MatchAny (most reliable for array fields) try: filter1 = qmodels.Filter( @@ -321,6 +528,8 @@ def scroll1(): scroll_result = await asyncio.to_thread(scroll1) points = scroll_result[0] if scroll_result else [] for pt in points: + if under and not _point_matches_under(pt, under): + continue pt_id = str(getattr(pt, "id", id(pt))) if pt_id not in seen_ids: seen_ids.add(pt_id) @@ -356,6 +565,8 @@ def scroll2(): scroll_result = await asyncio.to_thread(scroll2) points = scroll_result[0] if scroll_result else [] for pt in points: + if under and not 
_point_matches_under(pt, under): + continue pt_id = str(getattr(pt, "id", id(pt))) if pt_id not in seen_ids: seen_ids.add(pt_id) @@ -387,6 +598,8 @@ def scroll3(): scroll_result = await asyncio.to_thread(scroll3) points = scroll_result[0] if scroll_result else [] for pt in points: + if under and not _point_matches_under(pt, under): + continue pt_id = str(getattr(pt, "id", id(pt))) if pt_id not in seen_ids: seen_ids.add(pt_id) @@ -422,14 +635,6 @@ async def _query_definition( match=qmodels.MatchValue(value=language.lower()), ) ) - if under: - base_conditions.append( - qmodels.FieldCondition( - key="metadata.path_prefix", - match=qmodels.MatchValue(value=under), - ) - ) - # Strategy 1: Exact match on symbol_path (e.g., "MyClass.my_method") try: filter1 = qmodels.Filter( @@ -514,6 +719,8 @@ def scroll3(): seen_ids = set() unique_results = [] for pt in results: + if under and not _point_matches_under(pt, under): + continue pt_id = getattr(pt, "id", None) if pt_id not in seen_ids: seen_ids.add(pt_id) @@ -570,6 +777,7 @@ async def _fallback_semantic_search( query_type: str, limit: int = 20, language: Optional[str] = None, + under: Optional[str] = None, collection: Optional[str] = None, session: Optional[str] = None, ) -> List[Dict[str, Any]]: @@ -591,6 +799,8 @@ async def _fallback_semantic_search( query=query, limit=limit, language=language, + under=under, + collection=collection, session=session, output_format="json", # Avoid TOON encoding for internal calls ) @@ -598,7 +808,7 @@ async def _fallback_semantic_search( # Handle case where results might be TOON-encoded string (shouldn't happen with output_format="json") results = search_result.get("results", []) if isinstance(results, str): - # If somehow still a string, return empty - TOON decoding is not worth it here + # Internal callers require structured rows; skip malformed text-only responses. 
logger.debug("Fallback search returned TOON-encoded results, skipping") return [] return results @@ -655,7 +865,7 @@ async def _compute_called_by( symbol: The symbol name to find callers for limit: Maximum number of callers to return language: Optional language filter - under: Optional path prefix filter + under: Optional recursive workspace subtree filter collection: Optional collection override Returns: @@ -703,13 +913,6 @@ async def _compute_called_by( ) ) norm_under = _norm_under(under) - if norm_under: - base_conditions.append( - qmodels.FieldCondition( - key="metadata.path_prefix", - match=qmodels.MatchValue(value=norm_under), - ) - ) callers: List[Dict[str, Any]] = [] seen_ids: Set[str] = set() @@ -743,6 +946,8 @@ def do_scroll(): points = scroll_result[0] if scroll_result else [] for pt in points: + if norm_under and not _point_matches_under(pt, norm_under): + continue pt_id = str(getattr(pt, "id", id(pt))) if pt_id in seen_ids: continue diff --git a/scripts/mcp_impl/toon.py b/scripts/mcp_impl/toon.py index f1a0f209..02b5bead 100644 --- a/scripts/mcp_impl/toon.py +++ b/scripts/mcp_impl/toon.py @@ -21,6 +21,8 @@ import os from typing import Any, Dict +from scripts.toon_encoder import encode_context_results, encode_search_results + logger = logging.getLogger(__name__) @@ -51,9 +53,9 @@ def _should_use_toon(output_format: Any) -> bool: # TOON response formatting # --------------------------------------------------------------------------- def _format_results_as_toon(response: Dict[str, Any], compact: bool = False) -> Dict[str, Any]: - """Convert response to use TOON-formatted results string instead of JSON array. + """Add a TOON-formatted render while preserving structured results. - Replaces 'results' array with 'results' string in TOON format to save tokens. + Keeps 'results' as JSON-compatible structured data for machine callers. Always adds output_format marker when TOON is requested, even for empty results. 
Args: @@ -61,21 +63,14 @@ def _format_results_as_toon(response: Dict[str, Any], compact: bool = False) -> compact: If True, use more compact TOON encoding Returns: - Modified response with TOON-encoded results + Modified response with TOON-encoded text """ try: - from scripts.toon_encoder import encode_search_results - results = response.get("results", []) if isinstance(results, list): - # Replace JSON array with TOON string (handles empty arrays too) - toon_results = encode_search_results(results, compact=compact) - response["results"] = toon_results + response["text"] = encode_search_results(results, compact=compact) response["output_format"] = "toon" - return response - except ImportError: - logger.warning("TOON encoder not available, returning JSON format") return response except Exception as e: logger.debug(f"TOON encoding failed: {e}") @@ -83,32 +78,26 @@ def _format_results_as_toon(response: Dict[str, Any], compact: bool = False) -> def _format_context_results_as_toon(response: Dict[str, Any], compact: bool = False) -> Dict[str, Any]: - """Convert context_search response to TOON format, handling mixed code/memory results. + """Add a TOON render for context_search mixed code/memory results. Uses encode_context_results which properly handles memory entries (content/score) vs code entries (path/line), avoiding blank rows or dropped content. + Keeps 'results' as JSON-compatible structured data for machine callers. 
Args: response: Context search response dict with 'results' key compact: If True, use more compact TOON encoding Returns: - Modified response with TOON-encoded results + Modified response with TOON-encoded text """ try: - from scripts.toon_encoder import encode_context_results - results = response.get("results", []) if isinstance(results, list): - toon_results = encode_context_results(results, compact=compact) - response["results"] = toon_results + response["text"] = encode_context_results(results, compact=compact) response["output_format"] = "toon" - return response - except ImportError: - logger.warning("TOON encoder not available, returning JSON format") return response except Exception as e: logger.debug(f"TOON encoding failed: {e}") return response - diff --git a/scripts/mcp_impl/workspace.py b/scripts/mcp_impl/workspace.py index 6115f8f6..5b56aced 100644 --- a/scripts/mcp_impl/workspace.py +++ b/scripts/mcp_impl/workspace.py @@ -26,7 +26,6 @@ "_state_file_path", "_read_ws_state", "_default_collection", - "_work_script", ] import json @@ -98,23 +97,3 @@ def _default_collection() -> str: if isinstance(coll, str) and coll.strip(): return coll.strip() return DEFAULT_COLLECTION - - -def _work_script(name: str) -> str: - """Return path to script respecting bind mounts first, then /app, then local fallback.""" - try: - work_path = os.path.join("/work", "scripts", name) - if os.path.exists(work_path): - return work_path - except Exception: - pass - - try: - app_path = os.path.join("/app", "scripts", name) - if os.path.exists(app_path): - return app_path - except Exception: - pass - - return os.path.join(os.getcwd(), "scripts", name) - diff --git a/scripts/mcp_indexer_server.py b/scripts/mcp_indexer_server.py index d12aee9b..68ef533b 100644 --- a/scripts/mcp_indexer_server.py +++ b/scripts/mcp_indexer_server.py @@ -29,19 +29,8 @@ # CRITICAL: OpenLit must be initialized BEFORE any qdrant_client imports # to properly instrument vector DB calls. 
This import must come first! # --------------------------------------------------------------------------- -import os as _os -import sys as _sys -_roots_env = _os.environ.get("WORK_ROOTS", "") -_roots = [p.strip() for p in _roots_env.split(",") if p.strip()] or ["/work", "/app"] -for _root in _roots: - if _root and _root not in _sys.path: - _sys.path.insert(0, _root) - -# Now import OpenLit init (before any other scripts imports) -try: - from scripts import openlit_init # noqa: F401 - triggers early instrumentation -except ImportError: - pass # OpenLit not available +from scripts import openlit_init # noqa: F401 - triggers early instrumentation + import json import asyncio import re @@ -68,31 +57,13 @@ def _json_dumps_bytes(obj) -> bytes: from typing import Any, Dict, Optional, List, Tuple from pathlib import Path -import sys - -# Import structured logging and error handling (after sys.path setup) -# Will be imported after sys.path is configured below -import contextlib - -# Ensure code roots are on sys.path so absolute imports like 'from scripts.x import y' work -# when this file is executed directly (sys.path[0] may be /work/scripts). -# Supports multiple roots via WORK_ROOTS env (comma-separated), defaults to /work and /app. 
-_roots_env = os.environ.get("WORK_ROOTS", "") -_roots = [p.strip() for p in _roots_env.split(",") if p.strip()] or ["/work", "/app"] -try: - for _root in _roots: - if _root and _root not in sys.path: - sys.path.insert(0, _root) -except Exception: - pass +import qdrant_client # Note: OpenLit initialization is handled by early import of scripts.openlit_init # at the top of this file (before any qdrant_client imports) -# Session state imported from mcp_workspace shim (-> scripts.mcp.workspace) -# Must be after sys.path setup -from scripts.mcp_workspace import ( +from scripts.mcp_impl.workspace import ( _MEM_COLL_CACHE, SESSION_DEFAULTS, SESSION_DEFAULTS_BY_SESSION, @@ -100,31 +71,20 @@ def _json_dumps_bytes(obj) -> bytes: _SESSION_CTX_LOCK, ) -# Import structured logging and error handling (after sys.path setup) -try: - from scripts.logger import ( - get_logger, - ContextLogger, - RetrievalError, - IndexingError, - DecoderError, - ValidationError, - ConfigurationError, - safe_int, - safe_float, - safe_bool, - ) - - logger = get_logger(__name__) -except ImportError: - # Fallback if logger module not available - import logging - - logger = logging.getLogger(__name__) - logging.basicConfig(level=logging.INFO) +from scripts.logger import ( + get_logger, + ContextLogger, + RetrievalError, + IndexingError, + DecoderError, + ValidationError, + ConfigurationError, + safe_int, + safe_float, + safe_bool, +) - # Import safe conversion functions from utils (single source of truth) - from scripts.mcp_impl.utils import safe_int, safe_float, safe_bool +logger = get_logger(__name__) from scripts.mcp_auth import ( @@ -135,7 +95,7 @@ def _json_dumps_bytes(obj) -> bytes: # --------------------------------------------------------------------------- # Re-exports from extracted modules (backwards compatibility) # --------------------------------------------------------------------------- -from scripts.mcp_utils import ( +from scripts.mcp_impl.utils import ( _coerce_bool, _coerce_int, 
_coerce_str, @@ -152,7 +112,7 @@ def _json_dumps_bytes(obj) -> bytes: _primary_identifier_from_queries, ) -from scripts.mcp_toon import ( +from scripts.mcp_impl.toon import ( _is_toon_output_enabled, _should_use_toon, _format_results_as_toon, @@ -163,20 +123,7 @@ def _json_dumps_bytes(obj) -> bytes: from scripts.mcp_impl.context_search import _context_search_impl from scripts.mcp_impl.query_expand import _expand_query_impl from scripts.mcp_impl.search import _repo_search_impl -from scripts.mcp_impl.info_request import ( - _extract_symbols_from_query, - _extract_related_concepts, - _format_information_field, - _extract_relationships, - _calculate_confidence, -) from scripts.mcp_impl.admin_tools import _collection_map_impl -from scripts.mcp_impl.search_specialized import ( - _search_tests_for_impl, - _search_config_for_impl, - _search_callers_for_impl, - _search_importers_for_impl, -) from scripts.mcp_impl.search_history import ( _search_commits_for_impl, _change_history_for_path_impl, @@ -191,25 +138,13 @@ def _json_dumps_bytes(obj) -> bytes: _ENV_LOCK = threading.Lock() # Shared utilities (lex hashing, snippet highlighter) -try: - from scripts.utils import highlight_snippet as _do_highlight_snippet -except Exception as e: - logger.warning(f"Failed to import rich for syntax highlighting: {e}") - _do_highlight_snippet = None # fallback guarded at call site +from scripts.utils import highlight_snippet as _do_highlight_snippet # Back-compat shim for tests expecting _highlight_snippet in this module # Delegates to scripts.utils.highlight_snippet when available -try: - - def _highlight_snippet(snippet, tokens): # type: ignore - return ( - _do_highlight_snippet(snippet, tokens) if _do_highlight_snippet else snippet - ) -except Exception: - - def _highlight_snippet(snippet, tokens): # type: ignore - return snippet +def _highlight_snippet(snippet, tokens): # type: ignore + return _do_highlight_snippet(snippet, tokens) try: @@ -243,7 +178,7 @@ def _highlight_snippet(snippet, 
tokens): # type: ignore try: from scripts.workspace_state import get_collection_name as _ws_get_collection_name # type: ignore - if DEFAULT_COLLECTION in {"", "default-collection", "my-collection", "codebase"}: + if DEFAULT_COLLECTION in {"", "codebase"}: resolved = _ws_get_collection_name(None) if resolved: DEFAULT_COLLECTION = resolved @@ -281,15 +216,14 @@ def _highlight_snippet(snippet, tokens): # type: ignore ) # Disable strict identifier requirement -# --- TOON functions imported from scripts.mcp_toon --- +# --- TOON functions imported from scripts.mcp_impl.toon --- # (see imports at top of file for backwards compatibility re-exports) -# --- Workspace state functions imported from mcp_workspace shim --- -from scripts.mcp_workspace import ( +# --- Workspace state functions imported from workspace helper module --- +from scripts.mcp_impl.workspace import ( _state_file_path, _read_ws_state, _default_collection, - _work_script, ) # Disable DNS rebinding protection - breaks Docker internal networking (Host: mcp:8000) @@ -300,6 +234,23 @@ def _highlight_snippet(snippet, tokens): # type: ignore ) mcp = FastMCP(APP_NAME, transport_security=_security_settings) +# Minimal resource so MCP clients can verify resource wiring. 
+@mcp.resource( + "resource://context-engine/indexer/info", + name="context-engine-indexer-info", + title="Context Engine Indexer Info", + description="Basic metadata about the running indexer MCP server.", + mime_type="application/json", +) +def _indexer_info_resource(): + return { + "app": APP_NAME, + "host": HOST, + "port": PORT, + "qdrant_url": QDRANT_URL, + "default_collection": DEFAULT_COLLECTION, + } + # Capture tool registry automatically by wrapping the decorator once _TOOLS_REGISTRY: list[dict] = [] @@ -441,98 +392,26 @@ def log_message(self, *args, **kwargs): return False -# Import the new subprocess manager -try: - from scripts.subprocess_manager import run_subprocess_async -except ImportError: - # Fallback if subprocess_manager not available - logger.warning("subprocess_manager not available, using fallback implementation") - - async def run_subprocess_async( - cmd: List[str], - timeout: Optional[float] = None, - env: Optional[Dict[str, str]] = None, - ) -> Dict[str, Any]: - """Fallback subprocess runner if subprocess_manager is not available.""" - proc: Optional[asyncio.subprocess.Process] = None - try: - proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=env, - ) - # Default timeout from env if not provided by caller - if timeout is None: - timeout = MCP_TOOL_TIMEOUT_SECS - try: - stdout_b, stderr_b = await asyncio.wait_for( - proc.communicate(), timeout=timeout - ) - code = proc.returncode - except asyncio.TimeoutError: - try: - proc.kill() - except Exception: - pass - return { - "ok": False, - "code": -1, - "stdout": "", - "stderr": f"Command timed out after {timeout}s", - } - stdout = (stdout_b or b"").decode("utf-8", errors="ignore") - stderr = (stderr_b or b"").decode("utf-8", errors="ignore") - - def _cap_tail(s: str) -> str: - if not s: - return s - return ( - s - if len(s) <= MAX_LOG_TAIL - else ("...[tail truncated]\n" + s[-MAX_LOG_TAIL:]) - ) - - return { - "ok": 
code == 0, - "code": code, - "stdout": _cap_tail(stdout), - "stderr": _cap_tail(stderr), - } - except Exception as e: - return {"ok": False, "code": -2, "stdout": "", "stderr": str(e)} - finally: - try: - if proc is not None: - if proc.stdout is not None: - proc.stdout.close() - if proc.stderr is not None: - proc.stderr.close() - # Ensure the process is reaped - with contextlib.suppress(Exception): - await proc.wait() - except Exception: - pass +from scripts.subprocess_manager import run_subprocess_async -# --- Admin tool helpers imported from mcp_admin_tools shim --- -from scripts.mcp_admin_tools import ( +# --- Admin tool helpers imported from admin helper module --- +from scripts.mcp_impl.admin_tools import ( _EMBED_MODEL_CACHE, _EMBED_MODEL_LOCKS, _run_async, _get_embedding_model, - _invalidate_router_scratchpad, _detect_current_repo, ) # Lenient argument normalization to tolerate buggy clients (e.g., JSON-in-kwargs, booleans where strings expected) -# Note: _maybe_parse_jsonish and other parsing helpers are now imported from scripts.mcp_utils +# Note: _maybe_parse_jsonish and other parsing helpers are now imported from scripts.mcp_impl.utils from typing import Any as _Any, Dict as _Dict # Extra parsing helpers for quirky clients that send stringified kwargs import urllib.parse as _urlparse, ast as _ast -# --- Utility functions imported from scripts.mcp_utils --- +# --- Utility functions imported from scripts.mcp_impl.utils --- # (see imports at top of file for backwards compatibility re-exports: # _parse_kv_string, _coerce_value_string, _to_str_list_relaxed, # _extract_kwargs_payload, _looks_jsonish_string, _coerce_bool, @@ -588,7 +467,7 @@ async def qdrant_index_root( ) # type: ignore if _ws_is_multi_repo_mode(): - coll = _ws_get_collection_name("/work") or _default_collection() + coll = _default_collection() else: coll = _ws_get_collection_name(None) or _default_collection() except Exception: @@ -600,18 +479,12 @@ async def qdrant_index_root( 
env["QDRANT_URL"] = QDRANT_URL env["COLLECTION_NAME"] = coll - cmd = ["python", _work_script("ingest_code.py"), "--root", "/work"] + cmd = ["python", "-m", "scripts.ingest_code", "--root", "/work"] if recreate: cmd.append("--recreate") res = await _run_async(cmd, env=env) ret = {"args": {"root": "/work", "collection": coll, "recreate": recreate}, **res} - try: - if ret.get("ok") and int(ret.get("code", 1)) == 0: - if _invalidate_router_scratchpad("/work"): - ret["invalidated_router_scratchpad"] = True - except Exception: - pass return ret @@ -630,17 +503,13 @@ async def qdrant_list(kwargs: Any = None) -> Dict[str, Any]: - {"collections": [str, ...]} or {"error": "..."} """ try: - from qdrant_client import QdrantClient - - client = QdrantClient( + client = qdrant_client.QdrantClient( url=QDRANT_URL, api_key=os.environ.get("QDRANT_API_KEY"), timeout=float(os.environ.get("QDRANT_TIMEOUT", "20") or 20), ) cols_info = await asyncio.to_thread(client.get_collections) return {"collections": [c.name for c in cols_info.collections]} - except ImportError: - return {"error": "qdrant_client is not installed in this container"} except Exception as e: return {"error": str(e)} @@ -759,10 +628,9 @@ async def qdrant_status( pass coll = collection or _default_collection() try: - from qdrant_client import QdrantClient import datetime as _dt - client = QdrantClient( + client = qdrant_client.QdrantClient( url=QDRANT_URL, api_key=os.environ.get("QDRANT_API_KEY"), timeout=float(os.environ.get("QDRANT_TIMEOUT", "20") or 20), @@ -909,7 +777,7 @@ async def qdrant_index( ) # type: ignore if _ws_is_multi_repo_mode(): - coll = _ws_get_collection_name(root) or _default_collection() + coll = _default_collection() else: coll = _ws_get_collection_name(None) or _default_collection() except Exception: @@ -923,7 +791,8 @@ async def qdrant_index( cmd = [ "python", - _work_script("ingest_code.py"), + "-m", + "scripts.ingest_code", "--root", root, ] @@ -932,12 +801,6 @@ async def qdrant_index( res = 
await _run_async(cmd, env=env) ret = {"args": {"root": root, "collection": coll, "recreate": recreate}, **res} - try: - if ret.get("ok") and int(ret.get("code", 1)) == 0: - if _invalidate_router_scratchpad("/work"): - ret["invalidated_router_scratchpad"] = True - except Exception: - pass return ret @@ -1028,7 +891,7 @@ async def qdrant_prune(kwargs: Any = None, **ignored: Any) -> Dict[str, Any]: env = os.environ.copy() env["PRUNE_ROOT"] = "/work" - cmd = ["python", _work_script("prune.py")] + cmd = ["python", "-m", "scripts.prune"] res = await _run_async(cmd, env=env) return res @@ -1036,7 +899,7 @@ async def qdrant_prune(kwargs: Any = None, **ignored: Any) -> Dict[str, Any]: # --------------------------------------------------------------------------- # Code signal detection imported from mcp_code_signals shim # --------------------------------------------------------------------------- -from scripts.mcp_code_signals import ( +from scripts.mcp_impl.code_signals import ( _CODE_INTENT_CACHE, _CODE_INTENT_LOCK, _CODE_QUERY_ARCHETYPES, @@ -1068,6 +931,7 @@ async def repo_search( collection: Any = None, workspace_path: Any = None, mode: Any = None, + profile: Any = None, session: Any = None, ctx: Context = None, language: Any = None, @@ -1082,6 +946,7 @@ async def repo_search( case: Any = None, repo: Any = None, compact: Any = None, + debug: Any = None, output_format: Any = None, args: Any = None, kwargs: Any = None, @@ -1098,12 +963,14 @@ async def repo_search( - per_path: int (default 2). Max results per file. - include_snippet/context_lines: return inline snippets near hits when true. - rerank_*: ONNX reranker is ON by default for best relevance; timeouts fall back to hybrid. + - profile: Optional useful path profile: tests, config, or code. + - debug: bool (default false). Include verbose internal fields (components, rerank_counters, etc). - output_format: "json" (default) or "toon" for token-efficient TOON format. - collection: str. 
Target collection; defaults to workspace state or env COLLECTION_NAME. - repo: str or list[str]. Filter by repo name(s). Use "*" to search all repos. Returns: - - Dict with keys: results, total, used_rerank, rerank_counters + - Dict with keys: results, total, used_rerank, [rerank_counters if debug=true] """ return await _repo_search_impl( query=query, @@ -1120,6 +987,7 @@ async def repo_search( collection=collection, workspace_path=workspace_path, mode=mode, + profile=profile, session=session, ctx=ctx, language=language, @@ -1134,6 +1002,7 @@ async def repo_search( case=case, repo=repo, compact=compact, + debug=debug, output_format=output_format, args=args, kwargs=kwargs, @@ -1195,6 +1064,7 @@ async def repo_search_compat(**arguments) -> Dict[str, Any]: "not_": not_value, "case": args.get("case"), "compact": args.get("compact"), + "debug": args.get("debug"), "mode": args.get("mode"), "repo": args.get("repo"), # Cross-codebase isolation "output_format": args.get("output_format"), # "json" or "toon" @@ -1256,138 +1126,8 @@ async def context_answer_compat(arguments: Any = None) -> Dict[str, Any]: # --------------------------------------------------------------------------- -# Specialized search tools - thin wrappers delegating to extracted impls +# symbol_graph - graph query tool # --------------------------------------------------------------------------- -@mcp.tool() -async def search_tests_for( - query: Any = None, - limit: Any = None, - include_snippet: Any = None, - context_lines: Any = None, - under: Any = None, - language: Any = None, - session: Any = None, - compact: Any = None, - kwargs: Any = None, - ctx: Context = None, -) -> Dict[str, Any]: - """Find test files related to a query. 
- - What it does: - - Presets common test file globs and forwards to repo_search - - Accepts extra filters via kwargs (e.g., language, under, case) - - Parameters: - - query: str or list[str]; limit; include_snippet/context_lines; under; language; compact - - Returns: repo_search result shape. - """ - return await _search_tests_for_impl( - query=query, - limit=limit, - include_snippet=include_snippet, - context_lines=context_lines, - under=under, - language=language, - session=session, - compact=compact, - kwargs=kwargs, - ctx=ctx, - repo_search_fn=repo_search, - ) - - -@mcp.tool() -async def search_config_for( - query: Any = None, - limit: Any = None, - include_snippet: Any = None, - context_lines: Any = None, - under: Any = None, - session: Any = None, - compact: Any = None, - kwargs: Any = None, - ctx: Context = None, -) -> Dict[str, Any]: - """Find likely configuration files for a service/query. - - What it does: - - Presets config file globs (yaml/json/toml/etc.) and forwards to repo_search - - Accepts extra filters via kwargs - - Returns: repo_search result shape. - """ - return await _search_config_for_impl( - query=query, - limit=limit, - include_snippet=include_snippet, - context_lines=context_lines, - under=under, - session=session, - compact=compact, - kwargs=kwargs, - ctx=ctx, - repo_search_fn=repo_search, - ) - - -@mcp.tool() -async def search_callers_for( - query: Any = None, - limit: Any = None, - language: Any = None, - session: Any = None, - kwargs: Any = None, - ctx: Context = None, -) -> Dict[str, Any]: - """Heuristic search for callers/usages of a symbol. 
- - When to use: - - You want files that reference/invoke a function/class - - Notes: - - Thin wrapper over repo_search today; pass language or path_glob to narrow - - Returns repo_search result shape - """ - return await _search_callers_for_impl( - query=query, - limit=limit, - language=language, - session=session, - kwargs=kwargs, - ctx=ctx, - repo_search_fn=repo_search, - ) - - -@mcp.tool() -async def search_importers_for( - query: Any = None, - limit: Any = None, - language: Any = None, - session: Any = None, - kwargs: Any = None, - ctx: Context = None, -) -> Dict[str, Any]: - """Find files likely importing or referencing a module/symbol. - - What it does: - - Presets code globs across common languages; forwards to repo_search - - Accepts additional filters via kwargs (e.g., under, case) - - Returns: repo_search result shape. - """ - return await _search_importers_for_impl( - query=query, - limit=limit, - language=language, - session=session, - kwargs=kwargs, - ctx=ctx, - repo_search_fn=repo_search, - ) - - @mcp.tool() async def symbol_graph( symbol: str = None, @@ -1395,6 +1135,7 @@ async def symbol_graph( limit: Any = None, language: Any = None, under: Any = None, + collection: Any = None, session: Any = None, output_format: Any = None, ctx: Context = None, @@ -1411,7 +1152,8 @@ async def symbol_graph( - query_type: str. One of "callers", "definition", "importers". - limit: int (default 20). Maximum results to return. - language: str (optional). Filter by programming language. - - under: str (optional). Filter by path prefix. + - under: str (optional). Filter by recursive workspace subtree (e.g., "scripts" -> scripts/**). + - collection: str (optional). Target collection; defaults to env/WS collection. - output_format: "json" (default) or "toon" for token-efficient format. 
Returns: @@ -1434,6 +1176,7 @@ async def symbol_graph( limit=_limit, language=str(language).strip() if language else None, under=str(under).strip() if under else None, + collection=str(collection).strip() if collection else None, session=str(session).strip() if session else None, ctx=ctx, ) @@ -1510,8 +1253,8 @@ async def change_history_for_path( search_commits_fn=search_commits_for, ) -# --- context_answer helpers imported from mcp_context_answer shim --- -from scripts.mcp_context_answer import ( +# --- context_answer helpers imported from context_answer helper module --- +from scripts.mcp_impl.context_answer import ( _cleanup_answer, _answer_style_guidance, _strip_preamble_labels, @@ -1617,275 +1360,6 @@ async def context_answer( prepare_filters_and_retrieve_fn=_ca_prepare_filters_and_retrieve, ) -@mcp.tool() -async def code_search( - query: Any = None, - limit: Any = None, - per_path: Any = None, - include_snippet: Any = None, - context_lines: Any = None, - rerank_enabled: Any = None, - rerank_top_n: Any = None, - rerank_return_m: Any = None, - rerank_timeout_ms: Any = None, - highlight_snippet: Any = None, - collection: Any = None, - language: Any = None, - under: Any = None, - kind: Any = None, - symbol: Any = None, - path_regex: Any = None, - path_glob: Any = None, - not_glob: Any = None, - ext: Any = None, - not_: Any = None, - case: Any = None, - session: Any = None, - compact: Any = None, - kwargs: Any = None, -) -> Dict[str, Any]: - """Exact alias of repo_search (hybrid code search with reranking enabled by default). - - Prefer repo_search; this name exists for discoverability in some IDEs/agents. - Same parameters and return shape as repo_search. - Reranking (rerank_enabled=true) is ON by default for optimal result quality. 
- """ - return await repo_search( - query=query, - limit=limit, - per_path=per_path, - include_snippet=include_snippet, - context_lines=context_lines, - rerank_enabled=rerank_enabled, - rerank_top_n=rerank_top_n, - rerank_return_m=rerank_return_m, - rerank_timeout_ms=rerank_timeout_ms, - highlight_snippet=highlight_snippet, - collection=collection, - language=language, - under=under, - kind=kind, - symbol=symbol, - path_regex=path_regex, - path_glob=path_glob, - not_glob=not_glob, - ext=ext, - not_=not_, - case=case, - session=session, - compact=compact, - kwargs=kwargs, - ) - - -# --------------------------------------------------------------------------- -# info_request: Simplified codebase retrieval with explanation mode -# (helpers imported from scripts.mcp_impl.info_request) -# --------------------------------------------------------------------------- -@mcp.tool() -async def info_request( - # Primary parameter - info_request: str = None, - information_request: str = None, # Alias - # Explanation mode - include_explanation: bool = None, - # Relationship mapping - include_relationships: bool = None, - # Auth/session (passed through to repo_search) - session: str = None, - # Optional filters (pass-through to repo_search) - limit: int = None, - language: str = None, - under: str = None, - repo: Any = None, - path_glob: Any = None, - # Additional options - include_snippet: bool = None, - context_lines: int = None, - # Output format - output_format: Any = None, # "json" (default) or "toon" for token-efficient format - kwargs: Any = None, -) -> Dict[str, Any]: - """Simplified codebase retrieval with optional explanation mode. - - When to use: - - Simple, single-parameter code search with human-readable descriptions - - When you want optional explanation mode for richer context - - Drop-in replacement for basic codebase retrieval tools - - Key parameters: - - info_request: str. Natural language description of the code you're looking for. - - information_request: str. 
Alias for info_request. - - include_explanation: bool (default false). Add summary, primary_locations, related_concepts. - - include_relationships: bool (default false). Add imports_from, calls, related_paths to results. - - limit: int (default 10). Maximum results to return. - - language: str. Filter by programming language. - - under: str. Limit search to specific directory. - - repo: str or list[str]. Filter by repository name(s). - - output_format: "json" (default) or "toon" for token-efficient TOON format. - - Returns: - - Compact mode (default): results with information field and relevance_score alias - - Explanation mode: adds summary, primary_locations, related_concepts, query_understanding - - Example: - - {"info_request": "database connection pooling"} - - {"info_request": "authentication middleware", "include_explanation": true} - """ - # Resolve query from either parameter - query = info_request or information_request - if not query or not str(query).strip(): - return {"ok": False, "error": "info_request parameter is required", "results": []} - query = str(query).strip() - - # Resolve defaults from env - _default_limit = safe_int( - os.environ.get("INFO_REQUEST_LIMIT", "10"), default=10, logger=logger - ) - _default_context = safe_int( - os.environ.get("INFO_REQUEST_CONTEXT_LINES", "5"), default=5, logger=logger - ) - _default_explain = str( - os.environ.get("INFO_REQUEST_EXPLAIN_DEFAULT", "0") - ).strip().lower() in {"1", "true", "yes", "on"} - _default_relationships = str( - os.environ.get("INFO_REQUEST_RELATIONSHIPS", "0") - ).strip().lower() in {"1", "true", "yes", "on"} - - # Apply defaults - eff_limit = limit if limit is not None else _default_limit - eff_context = context_lines if context_lines is not None else _default_context - eff_snippet = include_snippet if include_snippet is not None else True - eff_explain = include_explanation if include_explanation is not None else _default_explain - eff_relationships = include_relationships if 
include_relationships is not None else _default_relationships - - # Smart limits based on query characteristics (only if user didn't override) - if limit is None: - query_words = len(query.split()) - query_lower = query.lower() - if query_words <= 2: # Short query like "auth handler" - eff_limit = 15 # More results for broad queries - elif "how does" in query_lower or "what is" in query_lower: - eff_limit = 8 # Questions need focused results - - # Call repo_search (always JSON - we format TOON ourselves after enhancement) - search_result = await repo_search( - query=query, - limit=eff_limit, - per_path=3, # Better default for info requests - session=session, - include_snippet=eff_snippet, - context_lines=eff_context, - language=language, - under=under, - repo=repo, - path_glob=path_glob, - output_format="json", # Always get JSON to iterate results - kwargs=kwargs, - ) - - # Extract results - results = search_result.get("results", []) - total = search_result.get("total", len(results)) - used_rerank = search_result.get("used_rerank", False) - - # Enhance each result with information field and optional relationships - enhanced_results = [] - for r in results: - enhanced = dict(r) - enhanced["information"] = _format_information_field(r) - enhanced["relevance_score"] = r.get("score", 0.0) # Alias - # Add relationships if requested - if eff_relationships: - enhanced["relationships"] = _extract_relationships(r) - enhanced_results.append(enhanced) - - # Build better search strategy string - strategy_parts = ["hybrid"] - if used_rerank: - strategy_parts.append("rerank") - if repo: - strategy_parts.append("repo_filtered") - if language: - strategy_parts.append(f"lang:{language}") - if under: - strategy_parts.append("path_filtered") - search_strategy = "+".join(strategy_parts) - - # Build response - response: Dict[str, Any] = { - "ok": True, - "results": enhanced_results, - "total": total, - "search_strategy": search_strategy, - } - - # Add explanation if requested - if 
eff_explain: - # Primary locations: unique file paths - seen_paths = set() - primary_locations = [] - for r in results: - p = r.get("path", "") - if p and p not in seen_paths: - seen_paths.add(p) - primary_locations.append(p) - if len(primary_locations) >= 5: - break - - # Related concepts - related_concepts = _extract_related_concepts(query, results) - - # Detected symbols from query - detected_symbols = _extract_symbols_from_query(query) - - # Summary - n_files = len(seen_paths) - summary = f"Found {total} results related to '{query}' across {n_files} file{'s' if n_files != 1 else ''}" - - # Group results by file - files_map: Dict[str, list] = {} - for r in enhanced_results: - p = r.get("path", "") - if p not in files_map: - files_map[p] = [] - files_map[p].append({ - "symbol": r.get("symbol", ""), - "line": r.get("start_line", 0), - "score": r.get("score", 0.0), - }) - - grouped_results = { - "by_file": { - path: { - "count": len(items), - "top_symbols": [i["symbol"] for i in sorted(items, key=lambda x: -x["score"])[:3] if i["symbol"]], - } - for path, items in files_map.items() - } - } - - # Calculate confidence - confidence = _calculate_confidence(query, enhanced_results) - - response["summary"] = summary - response["primary_locations"] = primary_locations - response["related_concepts"] = related_concepts - response["grouped_results"] = grouped_results - response["confidence"] = confidence - response["query_understanding"] = { - "intent": "search_for_code", - "detected_language": language or None, - "detected_symbols": detected_symbols, - "search_strategy": search_strategy, - } - - # Apply TOON formatting if requested or enabled globally - if _should_use_toon(output_format): - return _format_results_as_toon(response, compact=False) # Keep info_request fields - return response - - # --------------------------------------------------------------------------- # context_search - thin wrapper delegating to _context_search_impl # 
--------------------------------------------------------------------------- @@ -2153,7 +1627,7 @@ async def pattern_search( "on", }: try: - from scripts.rerank_local import _get_rerank_session # type: ignore + from scripts.rerank_tools.local import _get_rerank_session # type: ignore _ = _get_rerank_session() except Exception: @@ -2165,7 +1639,8 @@ async def pattern_search( _env["COLLECTION_NAME"] = _default_collection() _cmd = [ "python", - "/work/scripts/rerank_local.py", + "-m", + "scripts.rerank_tools.local", "--query", "warmup", "--topk", diff --git a/scripts/mcp_memory_server.py b/scripts/mcp_memory_server.py index 78712c10..896c9017 100644 --- a/scripts/mcp_memory_server.py +++ b/scripts/mcp_memory_server.py @@ -3,20 +3,8 @@ # to properly instrument vector DB calls. # --------------------------------------------------------------------------- import os -import sys as _sys -# Ensure repo roots are importable so 'scripts' resolves inside container -_roots_env = os.environ.get("WORK_ROOTS", "") -_roots = [p.strip() for p in _roots_env.split(",") if p.strip()] or ["/work", "/app"] -for _root in _roots: - if _root and _root not in _sys.path: - _sys.path.insert(0, _root) - -# Now import OpenLit init (before any other scripts imports that may use qdrant) -try: - from scripts import openlit_init # noqa: F401 - triggers early instrumentation -except ImportError: - pass # OpenLit not available +from scripts import openlit_init # noqa: F401 - triggers early instrumentation import json import threading @@ -41,24 +29,28 @@ from qdrant_client import QdrantClient, models -# Import connection pooling for proper resource management -try: - from scripts.qdrant_client_manager import ( - get_qdrant_client, - return_qdrant_client, - pooled_qdrant_client, - ) - _POOL_AVAILABLE = True -except ImportError: - _POOL_AVAILABLE = False +from scripts.qdrant_client_manager import ( + get_qdrant_client, + return_qdrant_client, +) + +def _env_flag(name: str, default: bool = False) -> bool: 
+ raw = os.environ.get(name) + if raw is None: + return default + return str(raw).strip().lower() in {"1", "true", "yes", "on"} + + +def _resolve_default_collection() -> str: + raw = (os.environ.get("DEFAULT_COLLECTION") or os.environ.get("COLLECTION_NAME") or "").strip() + if _env_flag("MULTI_REPO_MODE") and raw in {"", "codebase"}: + return "" + return raw or "codebase" + # Env QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333") -DEFAULT_COLLECTION = ( - os.environ.get("DEFAULT_COLLECTION") - or os.environ.get("COLLECTION_NAME") - or "codebase" -) +DEFAULT_COLLECTION = _resolve_default_collection() LEX_VECTOR_NAME = os.environ.get("LEX_VECTOR_NAME", "lex") LEX_VECTOR_DIM = int(os.environ.get("LEX_VECTOR_DIM", "4096") or 4096) EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") @@ -76,9 +68,9 @@ # I/O-safety knobs for memory server behavior # These env vars allow tuning startup latency vs. first-call latency, especially important # on slow storage backends (e.g., Ceph + HDD). See comments below for rationale. -MEMORY_ENSURE_ON_START = str(os.environ.get("MEMORY_ENSURE_ON_START", "1")).strip().lower() in {"1", "true", "yes", "on"} -MEMORY_COLD_SKIP_DENSE = str(os.environ.get("MEMORY_COLD_SKIP_DENSE", "0")).strip().lower() in {"1", "true", "yes", "on"} -MEMORY_PROBE_EMBED_DIM = str(os.environ.get("MEMORY_PROBE_EMBED_DIM", "1")).strip().lower() in {"1", "true", "yes", "on"} +MEMORY_ENSURE_ON_START = _env_flag("MEMORY_ENSURE_ON_START", False) +MEMORY_COLD_SKIP_DENSE = _env_flag("MEMORY_COLD_SKIP_DENSE", False) +MEMORY_PROBE_EMBED_DIM = _env_flag("MEMORY_PROBE_EMBED_DIM", True) try: MEMORY_VECTOR_DIM = int(os.environ.get("MEMORY_VECTOR_DIM") or os.environ.get("EMBED_DIM") or "768") except Exception: @@ -90,15 +82,8 @@ # Use the centralized embedder from scripts.embedder for consistent caching. # This eliminates duplicate model loading and ensures consistent behavior. 
-# Reference to the centralized embedder for cold-skip detection -try: - from scripts.embedder import get_embedding_model as _centralized_get_embedding_model - from scripts.embedder import is_model_cached as _is_model_cached - _EMBEDDER_AVAILABLE = True -except ImportError: - _EMBEDDER_AVAILABLE = False - def _is_model_cached(model_name: str = "") -> bool: # type: ignore[misc] - return False # Fallback: assume not cached +from scripts.embedder import get_embedding_model as _centralized_get_embedding_model +from scripts.embedder import is_model_cached as _is_model_cached def _get_embedding_model(): """Get the embedding model using the centralized embedder. @@ -108,12 +93,7 @@ def _get_embedding_model(): - Qwen3 model support with feature flags - Automatic cache invalidation on corrupted downloads """ - if _EMBEDDER_AVAILABLE: - return _centralized_get_embedding_model(EMBEDDING_MODEL) - - # Fallback for environments without centralized embedder (rare) - from fastembed import TextEmbedding - return TextEmbedding(model_name=EMBEDDING_MODEL) + return _centralized_get_embedding_model(EMBEDDING_MODEL) # Track ensured collections to reduce redundant ensure calls. # RATIONALE: Avoid repeated Qdrant network calls for the same collection. @@ -270,31 +250,19 @@ def log_message(self, *args, **kwargs): # --------------------------------------------------------------------------- # Qdrant Client Management # --------------------------------------------------------------------------- -# Use connection pooling when available, fallback to creating clients on-demand. -# This prevents socket exhaustion under load and improves connection reuse. 
- def _get_qdrant_client() -> QdrantClient: - """Get a Qdrant client from pool or create one.""" - if _POOL_AVAILABLE: - return get_qdrant_client( - url=QDRANT_URL, - api_key=os.environ.get("QDRANT_API_KEY") - ) - return QdrantClient(url=QDRANT_URL, api_key=os.environ.get("QDRANT_API_KEY")) + """Get a Qdrant client from the shared pool.""" + return get_qdrant_client( + url=QDRANT_URL, + api_key=os.environ.get("QDRANT_API_KEY") + ) def _return_qdrant_client(client: QdrantClient): """Return a client to the pool, or close it if pooling unavailable.""" if client is None: return - if _POOL_AVAILABLE: - return_qdrant_client(client) - else: - # Fallback path: close client to avoid socket leak - try: - client.close() - except Exception: - pass # Best effort cleanup + return_qdrant_client(client) # Ensure collection exists with dual vectors @@ -303,13 +271,9 @@ def _return_qdrant_client(client: QdrantClient): def _ensure_collection(name: str): """Create collection if missing. - Default behavior mirrors the original implementation for PR compatibility: - - Probe the embedding model to detect the dense vector dimension (MEMORY_PROBE_EMBED_DIM=1) - - Eager ensure on startup (MEMORY_ENSURE_ON_START=1) - For slow storage backends (e.g., Ceph + HDD), set the following in your env: - MEMORY_PROBE_EMBED_DIM=0 -> skip model probing; use MEMORY_VECTOR_DIM/EMBED_DIM - - MEMORY_ENSURE_ON_START=0 -> ensure lazily on first tool call + - MEMORY_ENSURE_ON_START=1 -> eagerly create DEFAULT_COLLECTION at startup """ client = _get_qdrant_client() try: @@ -379,9 +343,8 @@ def _ensure_collection(name: str): _return_qdrant_client(client) -# Optional eager collection ensure on startup (enabled by default for backward compatibility). -# Set MEMORY_ENSURE_ON_START=0 to defer ensure to first tool call (recommended on slow storage). -if MEMORY_ENSURE_ON_START: +# Optional eager collection ensure for single-collection deployments. 
+if MEMORY_ENSURE_ON_START and DEFAULT_COLLECTION: try: _ensure_collection(DEFAULT_COLLECTION) except Exception: @@ -767,7 +730,10 @@ def _resolve_collection( except Exception: pass - return coll or DEFAULT_COLLECTION + resolved = coll or DEFAULT_COLLECTION + if not resolved: + raise ValueError("collection is required in multi-repo memory server mode") + return resolved if __name__ == "__main__": diff --git a/scripts/mcp_router.py b/scripts/mcp_router.py deleted file mode 100644 index c872954a..00000000 --- a/scripts/mcp_router.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 -""" -Backwards-compatibility shim for mcp_router. - -NOTE: This file is NOT used by Python when scripts/mcp_router/ package exists. -Python's import system prioritizes packages (directories with __init__.py) over -modules (.py files) with the same name. This file exists for: - -1. Documentation: Shows all available exports at a glance -2. Symmetry: Matches the pattern used by rerank_recursive.py -3. Fallback: Would work if the package directory were removed - -All imports resolve to scripts.mcp_router/ (the package): - from scripts.mcp_router import build_plan, classify_intent - # or - from scripts import mcp_router - mcp_router.build_plan("query") - -Usage: - python -m scripts.mcp_router --plan "How do I ...?" - python -m scripts.mcp_router --run "What is hybrid search?" 
-""" -from __future__ import annotations - -# Re-export everything from the package -from scripts.mcp_router import ( - # Config - HTTP_URL_INDEXER, - HTTP_URL_MEMORY, - DEFAULT_HTTP_URL, - HEALTH_PORT_INDEXER, - HEALTH_PORT_MEMORY, - LANGS, - cache_ttl_sec, - scratchpad_ttl_sec, - divergence_thresholds, - divergence_is_fatal_for, - # Intent constants - INTENT_ANSWER, - INTENT_SEARCH, - INTENT_SEARCH_TESTS, - INTENT_SEARCH_CONFIG, - INTENT_SEARCH_CALLERS, - INTENT_SEARCH_IMPORTERS, - INTENT_MEMORY_STORE, - INTENT_MEMORY_FIND, - INTENT_INDEX, - INTENT_PRUNE, - INTENT_STATUS, - INTENT_LIST, - # Intent functions - classify_intent, - get_last_intent_debug, - _classify_intent_rules, - # Memory - parse_memory_store_payload, - # Client - call_tool_http, - is_failure_response, - discover_tool_endpoints, - default_tool_endpoints, - tools_describe_cached, - _mcp_handshake, - _post_raw, - _post_raw_retry, - _parse_stream_or_json, - _filter_args, - # Scratchpad - scratchpad_path, - load_scratchpad, - save_scratchpad, - looks_like_repeat, - looks_like_same_filters, - looks_like_expand, - # Hints - parse_repo_hints, - clean_query_and_dsl, - select_best_search_tool_by_signature, - # Batching - BatchingContextAnswerClient, - get_batch_client, - # Validation - is_result_good, - extract_metric_from_resp, - material_drop, - # Planning - build_plan, - # CLI - main, - # Legacy aliases - _is_failure_response, - _is_result_good, - _discover_tool_endpoints, -) - -__all__ = [ - # Config - "HTTP_URL_INDEXER", - "HTTP_URL_MEMORY", - "DEFAULT_HTTP_URL", - "HEALTH_PORT_INDEXER", - "HEALTH_PORT_MEMORY", - "LANGS", - "cache_ttl_sec", - "scratchpad_ttl_sec", - "divergence_thresholds", - "divergence_is_fatal_for", - # Intent - "INTENT_ANSWER", - "INTENT_SEARCH", - "INTENT_SEARCH_TESTS", - "INTENT_SEARCH_CONFIG", - "INTENT_SEARCH_CALLERS", - "INTENT_SEARCH_IMPORTERS", - "INTENT_MEMORY_STORE", - "INTENT_MEMORY_FIND", - "INTENT_INDEX", - "INTENT_PRUNE", - "INTENT_STATUS", - "INTENT_LIST", - 
"classify_intent", - "get_last_intent_debug", - # Memory - "parse_memory_store_payload", - # Client - "call_tool_http", - "is_failure_response", - "discover_tool_endpoints", - "default_tool_endpoints", - "tools_describe_cached", - # Scratchpad - "scratchpad_path", - "load_scratchpad", - "save_scratchpad", - "looks_like_repeat", - "looks_like_same_filters", - "looks_like_expand", - # Hints - "parse_repo_hints", - "clean_query_and_dsl", - "select_best_search_tool_by_signature", - # Batching - "BatchingContextAnswerClient", - "get_batch_client", - # Validation - "is_result_good", - "extract_metric_from_resp", - "material_drop", - # Planning - "build_plan", - # CLI - "main", -] - -if __name__ == "__main__": - import sys - raise SystemExit(main(sys.argv[1:])) diff --git a/scripts/mcp_router/__init__.py b/scripts/mcp_router/__init__.py deleted file mode 100644 index 0bf56bac..00000000 --- a/scripts/mcp_router/__init__.py +++ /dev/null @@ -1,206 +0,0 @@ -""" -mcp_router - Modular MCP routing package. - -This package provides intent classification, tool planning, and HTTP execution -for routing queries to the appropriate MCP tools. 
- -Public API: -- classify_intent: Determine query intent -- build_plan: Create execution plan for query -- call_tool_http: Execute MCP tool over HTTP -- discover_tool_endpoints: Find available tools -""" -from __future__ import annotations - -# Config exports -from .config import ( - HTTP_URL_INDEXER, - HTTP_URL_MEMORY, - DEFAULT_HTTP_URL, - HEALTH_PORT_INDEXER, - HEALTH_PORT_MEMORY, - LANGS, - cache_ttl_sec, - scratchpad_ttl_sec, - divergence_thresholds, - divergence_is_fatal_for, -) - -# Intent exports -from .intent import ( - INTENT_ANSWER, - INTENT_SEARCH, - INTENT_SEARCH_TESTS, - INTENT_SEARCH_CONFIG, - INTENT_SEARCH_CALLERS, - INTENT_SEARCH_IMPORTERS, - INTENT_MEMORY_STORE, - INTENT_MEMORY_FIND, - INTENT_INDEX, - INTENT_PRUNE, - INTENT_STATUS, - INTENT_LIST, - classify_intent, - get_last_intent_debug, -) - -# Memory exports -from .memory import parse_memory_store_payload - -# Client exports -from .client import ( - call_tool_http, - is_failure_response, - discover_tool_endpoints, - default_tool_endpoints, - tools_describe_cached, - _mcp_handshake, - _post_raw, - _post_raw_retry, - _parse_stream_or_json, - _filter_args, -) - -# Scratchpad exports -from .scratchpad import ( - scratchpad_path, - load_scratchpad, - save_scratchpad, - looks_like_repeat, - looks_like_same_filters, - looks_like_expand, -) - -# Hints exports -from .hints import ( - parse_repo_hints, - clean_query_and_dsl, - select_best_search_tool_by_signature, -) - -# Batching exports -from .batching import ( - BatchingContextAnswerClient, - get_batch_client, -) - -# Validation exports -from .validation import ( - is_result_good, - extract_metric_from_resp, - material_drop, -) - -# Planning exports -from .planning import build_plan, route_query - -# --------------------------------------------------------------------------- -# Private function imports for backward compatibility -# --------------------------------------------------------------------------- -from .intent import _classify_intent_rules 
- -# --------------------------------------------------------------------------- -# Legacy aliases (underscore-prefixed) for backward compatibility -# --------------------------------------------------------------------------- -_LAST_INTENT_DEBUG = {} # Use get_last_intent_debug() instead -_BATCH_CLIENT = None # Lazy initialized - -def _get_batch_client(): - global _BATCH_CLIENT - if _BATCH_CLIENT is None: - _BATCH_CLIENT = get_batch_client() - return _BATCH_CLIENT - -# Function aliases -_parse_memory_store_payload = parse_memory_store_payload -_looks_like_repeat = looks_like_repeat -_looks_like_same_filters = looks_like_same_filters -_looks_like_expand = looks_like_expand -_load_scratchpad = load_scratchpad -_save_scratchpad = save_scratchpad -_scratchpad_path = scratchpad_path -_scratchpad_ttl_sec = scratchpad_ttl_sec -_cache_ttl_sec = cache_ttl_sec -_discover_tool_endpoints = discover_tool_endpoints -_default_tool_endpoints = default_tool_endpoints -_tools_describe_cached = tools_describe_cached -_is_failure_response = is_failure_response -_is_result_good = is_result_good -_extract_metric_from_resp = extract_metric_from_resp -_material_drop = material_drop -_divergence_thresholds = divergence_thresholds -_divergence_is_fatal_for = divergence_is_fatal_for -_parse_repo_hints = parse_repo_hints -_clean_query_and_dsl = clean_query_and_dsl -_select_best_search_tool_by_signature = select_best_search_tool_by_signature - -# Health port aliases -_HEALTH_PORT_INDEXER = HEALTH_PORT_INDEXER -_HEALTH_PORT_MEMORY = HEALTH_PORT_MEMORY - -# Language set alias -_LANGS = LANGS - - -__all__ = [ - # Config - "HTTP_URL_INDEXER", - "HTTP_URL_MEMORY", - "DEFAULT_HTTP_URL", - "HEALTH_PORT_INDEXER", - "HEALTH_PORT_MEMORY", - "LANGS", - "cache_ttl_sec", - "scratchpad_ttl_sec", - "divergence_thresholds", - "divergence_is_fatal_for", - # Intent - "INTENT_ANSWER", - "INTENT_SEARCH", - "INTENT_SEARCH_TESTS", - "INTENT_SEARCH_CONFIG", - "INTENT_SEARCH_CALLERS", - "INTENT_SEARCH_IMPORTERS", - 
"INTENT_MEMORY_STORE", - "INTENT_MEMORY_FIND", - "INTENT_INDEX", - "INTENT_PRUNE", - "INTENT_STATUS", - "INTENT_LIST", - "classify_intent", - "get_last_intent_debug", - # Memory - "parse_memory_store_payload", - # Client - "call_tool_http", - "is_failure_response", - "discover_tool_endpoints", - "default_tool_endpoints", - "tools_describe_cached", - # Scratchpad - "scratchpad_path", - "load_scratchpad", - "save_scratchpad", - "looks_like_repeat", - "looks_like_same_filters", - "looks_like_expand", - # Hints - "parse_repo_hints", - "clean_query_and_dsl", - "select_best_search_tool_by_signature", - # Batching - "BatchingContextAnswerClient", - "get_batch_client", - # Validation - "is_result_good", - "extract_metric_from_resp", - "material_drop", - # Planning - "build_plan", - "route_query", - # CLI - "main", -] - -# Import main for CLI compatibility -from .cli import main diff --git a/scripts/mcp_router/__main__.py b/scripts/mcp_router/__main__.py deleted file mode 100644 index f7c2d058..00000000 --- a/scripts/mcp_router/__main__.py +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python3 -""" -Allow running as: python -m scripts.mcp_router "query" -""" -from .cli import main - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/mcp_router/batching.py b/scripts/mcp_router/batching.py deleted file mode 100644 index 7300a87e..00000000 --- a/scripts/mcp_router/batching.py +++ /dev/null @@ -1,331 +0,0 @@ -""" -mcp_router/batching.py - Context answer batching client. -""" -from __future__ import annotations - -import json -import os -import re -import sys -import threading -import time -from typing import Any, Dict - -from .config import HTTP_URL_INDEXER -from .client import call_tool_http - - -class BatchingContextAnswerClient: - """Lightweight in-memory batching for context_answer calls. 
- - - Queues short-lived requests keyed by (base_url, collection, filters_fingerprint) - - Flushes after a small window or when batch size cap is hit - - For multi-item batches, sends query=[...] with mode="pack" - - Shares the same response with all enqueued requests - """ - - def __init__(self, call_func=None, enable: bool | None = None, window_ms: int | None = None, - max_batch: int | None = None, budget_ms: int | None = None): - self._call = call_func or call_tool_http - if enable is None: - env_enabled = os.environ.get("ROUTER_BATCH_ENABLED") - if env_enabled is None: - env_enabled = os.environ.get("ROUTER_BATCH_ENABLE", "0") - self.enabled = str(env_enabled).strip().lower() in {"1", "true", "yes", "on"} - else: - self.enabled = bool(enable) - self.window_ms = int(os.environ.get("ROUTER_BATCH_WINDOW_MS", str(window_ms if window_ms is not None else 100)) or 100) - env_max = os.environ.get("ROUTER_BATCH_MAX_SIZE") - if env_max is None: - env_max = os.environ.get("ROUTER_BATCH_MAX") - self.max_batch = int(env_max or (max_batch if max_batch is not None else 8)) - env_budget = os.environ.get("ROUTER_BATCH_LATENCY_BUDGET_MS") - if env_budget is None: - env_budget = os.environ.get("ROUTER_BATCH_BUDGET_MS") - self.budget_ms = int(env_budget or (budget_ms if budget_ms is not None else 2000)) - self._lock = threading.RLock() - self._groups: dict[str, dict[str, Any]] = {} - - def _should_bypass(self, args: Dict[str, Any]) -> bool: - try: - if isinstance(args, dict): - v = args.get("immediate") - if v is not None and str(v).strip().lower() in {"1", "true", "yes", "on"}: - return True - except Exception: - pass - if str(os.environ.get("ROUTER_BATCH_BYPASS", "0")).strip().lower() in {"1", "true", "yes", "on"}: - return True - try: - q = str((args or {}).get("query") or "") - if "immediate answer" in q.lower(): - return True - except Exception: - pass - return False - - def _norm_query(self, q: str) -> str: - try: - return re.sub(r"\s+", " ", str(q or "").strip()) - except 
Exception: - return str(q) - - def _filters_fingerprint(self, args: Dict[str, Any]) -> str: - keep = { - "collection", "language", "under", "kind", "symbol", "ext", - "path_regex", "path_glob", "not_glob", "not_", "case", - "limit", "per_path", "include_snippet", - } - try: - filt = {k: args.get(k) for k in keep if k in args} - def _norm(v): - if v is None: - return None - if isinstance(v, (list, tuple)): - return [str(x) for x in v] - return v - clean = {k: _norm(v) for k, v in filt.items()} - return json.dumps(clean, sort_keys=True, ensure_ascii=False) - except Exception: - return "{}" - - def _group_key(self, base_url: str, args: Dict[str, Any]) -> str: - coll = str(args.get("collection") or "") - fp = self._filters_fingerprint(args) - repo = os.getcwd() - return f"{base_url}|{coll}|answer|{fp}|{repo}" - - def call_or_enqueue(self, base_url: str, tool: str, args: Dict[str, Any], timeout: float = 120.0) -> Dict[str, Any]: - if not self.enabled: - return self._call(base_url, tool, args, timeout=timeout) - if self._should_bypass(args): - return self._call(base_url, tool, args, timeout=timeout) - - start_ts = time.time() - key = self._group_key(base_url, args or {}) - norm_q = self._norm_query((args or {}).get("query") or "") - ev = threading.Event() - slot = {"event": ev, "result": None, "error": None, "query": norm_q, "args": dict(args or {})} - - with self._lock: - g = self._groups.get(key) - if not g: - g = { - "created": time.time(), - "items": [], - "timer": None, - } - self._groups[key] = g - g["items"].append(slot) - if g["timer"] is None: - delay = max(0.0, float(self.window_ms) / 1000.0) - t = threading.Timer(delay, self._flush, args=(key,)) - g["timer"] = t - t.daemon = True - t.start() - if len(g["items"]) >= self.max_batch: - t = g.get("timer") - if t: - try: - t.cancel() - except Exception: - pass - g["timer"] = None - threading.Thread(target=self._flush, args=(key,), daemon=True).start() - - remain = max(0.05, self.budget_ms / 1000.0) - 
ev.wait(timeout=min(timeout, remain)) - if not ev.is_set(): - try: - res = self._call(base_url, tool, args, timeout=timeout) - slot["result"] = res - ev.set() - try: - with self._lock: - gg = self._groups.get(key) - if gg: - lst = gg.get("items") or [] - if slot in lst: - try: - lst.remove(slot) - except Exception: - pass - if not lst: - t2 = gg.get("timer") - if t2: - try: - t2.cancel() - except Exception: - pass - self._groups.pop(key, None) - except Exception: - pass - try: - print(json.dumps({"router": {"batch_fallback": True, "elapsed_ms": int((time.time()-start_ts)*1000)}}), file=sys.stderr) - except Exception: - pass - return res - except Exception as e: - slot["error"] = e - ev.set() - raise - - if slot.get("error") is not None: - raise slot["error"] - return slot.get("result") or {} - - def _flush(self, key: str) -> None: - with self._lock: - g = self._groups.get(key) - if not g: - return - items = g.get("items") or [] - g["items"] = [] - g["timer"] = None - if not items: - self._groups.pop(key, None) - return - - unique_q: list[str] = [] - seen_q = set() - for it in items: - q = it.get("query") or "" - if q not in seen_q: - seen_q.add(q) - unique_q.append(q) - first_args = dict(items[0].get("args") or {}) - forward = {k: v for k, v in first_args.items() if k not in {"query", "queries"}} - base_url = None - try: - base_url = key.split("|")[0] - except Exception: - base_url = HTTP_URL_INDEXER - - started = time.time() - results_by_q: Dict[str, Any] = {} - errors_by_q: Dict[str, Exception] = {} - calls = 0 - try: - import copy as _copy - except Exception: - _copy = None - - if len(unique_q) > 1: - args_all = dict(forward) - args_all["query"] = list(unique_q) - args_all["mode"] = args_all.get("mode") or "pack" - try: - agg_res = self._call(base_url, "context_answer", args_all, timeout=120.0) - calls = 1 - try: - payload = ((agg_res or {}).get("result") or {}).get("structuredContent") or {} - body = (payload.get("result") or {}) - except Exception: - payload, 
body = {}, {} - - abq = None - try: - abq = body.get("answers_by_query") - except Exception: - abq = None - if isinstance(abq, list) and abq: - _map: Dict[str, Any] = {} - by_idx = (len(abq) >= len(unique_q)) - for i, entry in enumerate(abq): - try: - qv = entry.get("query") - qk = None - if isinstance(qv, list) and qv: - qk = str(qv[0]) - elif isinstance(qv, str): - qk = qv - except Exception: - qk = None - entry_key = qk if qk else (unique_q[i] if by_idx and i < len(unique_q) else None) - if not entry_key: - continue - per = _copy.deepcopy(agg_res) if _copy else json.loads(json.dumps(agg_res)) - try: - per_body = (per.get("result") or {}).get("structuredContent", {}).get("result", {}) - except Exception: - per_body = None - try: - ans_i = str(entry.get("answer") or "") - cits_i = entry.get("citations") or [] - if per_body is not None: - per_body["answer"] = ans_i - per_body["citations"] = cits_i - per_body["query"] = [entry_key] - except Exception: - pass - _map[str(entry_key)] = per - for uq in unique_q: - if str(uq) in _map: - results_by_q[uq] = _map[str(uq)] - remaining = [uq for uq in unique_q if uq not in results_by_q] - else: - remaining = list(unique_q) - - if remaining: - for uq in remaining: - args_i = dict(forward) - args_i["query"] = uq - try: - results_by_q[uq] = self._call(base_url, "context_answer", args_i, timeout=120.0) - except Exception as e: - errors_by_q[uq] = e - calls += len(remaining) - except Exception as e: - for uq in unique_q: - errors_by_q[uq] = e - calls = 1 - else: - args1 = dict(forward) - args1["query"] = unique_q[0] if unique_q else "" - try: - results_by_q[args1["query"]] = self._call(base_url, "context_answer", args1, timeout=120.0) - except Exception as e: - errors_by_q[args1["query"]] = e - calls = 1 - - elapsed_ms = int((time.time() - started) * 1000) - try: - print(json.dumps({ - "router": { - "batch_flushed": True, - "n_items": len(items), - "unique_q": len(unique_q), - "calls": int(calls), - "elapsed_ms": elapsed_ms, - 
"ok": (len(errors_by_q) == 0), - } - }), file=sys.stderr) - except Exception: - pass - - for it in items: - q = it.get("query") or "" - it["result"] = results_by_q.get(q) - it["error"] = errors_by_q.get(q) - ev = it.get("event") - try: - if hasattr(ev, "set"): - ev.set() - except Exception: - pass - with self._lock: - gg = self._groups.get(key) - if gg and not gg.get("items"): - self._groups.pop(key, None) - - -# Global client singleton -_BATCH_CLIENT: BatchingContextAnswerClient | None = None - - -def get_batch_client() -> BatchingContextAnswerClient: - """Get or create global batch client.""" - global _BATCH_CLIENT - if _BATCH_CLIENT is None: - _BATCH_CLIENT = BatchingContextAnswerClient() - return _BATCH_CLIENT diff --git a/scripts/mcp_router/cli.py b/scripts/mcp_router/cli.py deleted file mode 100644 index 50d63d55..00000000 --- a/scripts/mcp_router/cli.py +++ /dev/null @@ -1,305 +0,0 @@ -#!/usr/bin/env python3 -""" -mcp_router/cli.py - CLI entrypoint for MCP router. - -Usage: - python -m scripts.mcp_router --plan "How do I ...?" - python -m scripts.mcp_router --run "What is hybrid search?" 
-""" -from __future__ import annotations - -import argparse -import json -import re -import sys -import time -from typing import Any, Dict, List - -from .config import HTTP_URL_INDEXER, scratchpad_ttl_sec, divergence_thresholds -from .client import call_tool_http, is_failure_response, discover_tool_endpoints -from .scratchpad import ( - load_scratchpad, - save_scratchpad, - looks_like_repeat, - looks_like_expand, -) -from .validation import ( - is_result_good, - extract_metric_from_resp, - material_drop, -) -from .batching import get_batch_client -from .planning import build_plan -from .config import divergence_is_fatal_for - - -def main(argv: List[str] | None = None) -> int: - """Main CLI entrypoint.""" - if argv is None: - argv = sys.argv[1:] - - ap = argparse.ArgumentParser() - ap.add_argument("query", help="User query to route") - ap.add_argument("--plan", action="store_true", help="Only print routing plan (no execution)") - ap.add_argument("--run", action="store_true", help="Execute the routed tool(s) over HTTP") - ap.add_argument("--timeout", type=float, default=180.0, help="HTTP timeout for tool calls") - args = ap.parse_args(argv) - - plan = build_plan(args.query) - print(json.dumps({"router": {"url": HTTP_URL_INDEXER, "plan": plan}}, indent=2)) - - if args.plan and not args.run: - return 0 - - # Load scratchpad for prior context - sp = {} - fresh = False - prior_answer = None - prior_citations = None - prior_paths = None - try: - sp = load_scratchpad() - ts = float(sp.get("timestamp") or 0.0) - fresh = bool(ts and (time.time() - ts) <= scratchpad_ttl_sec()) - if fresh: - prior_answer = sp.get("last_answer") - prior_citations = sp.get("last_citations") - prior_paths = sp.get("last_paths") - except Exception: - pass - - # Execute sequentially until one succeeds - last_err = None - last = None - tool_servers = discover_tool_endpoints() - mem_snippets: list[str] = list(sp.get("mem_snippets") or []) if fresh else [] - batch_client = get_batch_client() - - for 
idx, (tool, targs) in enumerate(plan): - base_url = tool_servers.get(tool, HTTP_URL_INDEXER) - - # Skip memory.find if we already have fresh snippets and this is a repeat/expand - if (tool.lower().endswith("find") or tool.lower() in {"find", "memory.find"}) and mem_snippets and fresh and (looks_like_repeat(args.query) or looks_like_expand(args.query)): - try: - print(json.dumps({"tool": tool, "skipped": "scratchpad_fresh"})) - except Exception: - pass - continue - - # Augment answer queries with context - if tool in {"context_answer", "context_answer_compat"} and (mem_snippets or (fresh and (prior_answer or prior_citations or prior_paths))): - try: - tq = str((targs or {}).get("query") or args.query) - sections = [tq] - if mem_snippets: - bullets = [] - for s in mem_snippets[:3]: - ss = re.sub(r"\s+", " ", str(s)).strip() - if len(ss) > 200: - ss = ss[:197] + "..." - bullets.append(f"- {ss}") - sections.append("Memory context:\n" + "\n".join(bullets)) - if fresh and (looks_like_expand(args.query) or looks_like_repeat(args.query)): - if isinstance(prior_answer, str) and prior_answer.strip(): - pa = re.sub(r"\s+", " ", prior_answer).strip() - if len(pa) > 400: - pa = pa[:397] + "..." 
- sections.append("Prior summary:\n" + pa) - paths_list = [] - if isinstance(prior_paths, list) and prior_paths: - paths_list = [str(p) for p in prior_paths[:5]] - elif isinstance(prior_citations, list) and prior_citations: - uniq = [] - for c in prior_citations: - if isinstance(c, dict) and c.get("path") and c["path"] not in uniq: - uniq.append(c["path"]) - paths_list = uniq[:5] - if paths_list: - sections.append("Citations context:\n" + "\n".join(f"- {p}" for p in paths_list)) - aug = "\n\n".join(sections) - targs = {**(targs or {}), "query": aug} - except Exception: - pass - - try: - if tool in {"context_answer", "context_answer_compat"}: - res = batch_client.call_or_enqueue(base_url, tool, targs, timeout=args.timeout) - else: - res = call_tool_http(base_url, tool, targs, timeout=args.timeout) - print(json.dumps({"tool": tool, "result": res}, indent=2)) - last = res - - # Capture memory snippets - try: - if tool.lower().endswith("find") or tool.lower() in {"find", "memory.find"}: - r = res.get("result") or {} - items = [] - sc = r.get("structuredContent") - if isinstance(sc, dict): - rs0 = sc.get("result") or sc - if isinstance(rs0, dict): - items = rs0.get("results") or rs0.get("hits") or [] - if not items: - content = r.get("content") - if isinstance(content, list): - for c in content: - if not isinstance(c, dict): - continue - if "json" in c: - j = c.get("json") - if isinstance(j, (dict, list)): - container = j.get("result") if isinstance(j, dict) and "result" in j else j - if isinstance(container, dict): - items = container.get("results") or container.get("hits") or [] - if items: - break - if c.get("type") == "text": - ttxt = c.get("text") - if isinstance(ttxt, str) and ttxt.strip(): - try: - j = json.loads(ttxt) - except Exception: - continue - container = j.get("result") if isinstance(j, dict) and "result" in j else j - if isinstance(container, dict): - items = container.get("results") or container.get("hits") or [] - if items: - break - for it in items: 
- if isinstance(it, dict): - txt = it.get("information") or it.get("content") or it.get("text") - if isinstance(txt, str) and txt.strip(): - mem_snippets.append(txt.strip()) - except Exception: - pass - - # Determine if we should treat this step as terminal - has_future_answer = any(tn in {"context_answer", "context_answer_compat"} for (tn, _) in plan[idx + 1:]) - if (not is_failure_response(res)) and is_result_good(tool, res): - if tool.lower() in {"find", "memory.find"} and has_future_answer: - continue - - # Persist scratchpad - try: - last_filters: Dict[str, Any] = {} - for (tn, ta) in plan: - if tn == "repo_search" or tn.startswith("search_"): - if isinstance(ta, dict): - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - if ta.get(k) not in (None, ""): - last_filters[k] = ta.get(k) - break - - last_answer_text = None - last_citations_list = None - last_paths_list: list[str] | None = None - if tool in {"context_answer", "context_answer_compat"}: - try: - r0 = res.get("result") or {} - sc0 = r0.get("structuredContent") or {} - rs0 = sc0.get("result") or sc0 - if isinstance(rs0, dict): - ans0 = rs0.get("answer") - if isinstance(ans0, str): - last_answer_text = ans0 - cites0 = rs0.get("citations") - if isinstance(cites0, list): - last_citations_list = cites0 - uniqp: list[str] = [] - for c in cites0: - if isinstance(c, dict) and c.get("path") and c["path"] not in uniqp: - uniqp.append(c["path"]) - last_paths_list = uniqp - except Exception: - pass - - # Divergence detection - divergence_should_abort = False - last_metrics_prev = {} - try: - last_metrics_prev = sp.get("last_metrics") or {} - if not isinstance(last_metrics_prev, dict): - last_metrics_prev = {} - except Exception: - last_metrics_prev = {} - metric = extract_metric_from_resp(tool, res) - last_metrics_map = dict(last_metrics_prev) - if metric is not None: - mname, mval = metric - prev_val = None - try: - prev_val = last_metrics_prev.get(tool, {}).get(mname) - if prev_val is 
not None: - prev_val = float(prev_val) - except Exception: - prev_val = None - drop_frac, min_base = divergence_thresholds() - if material_drop(prev_val, float(mval), drop_frac, min_base): - fatal = divergence_is_fatal_for(tool) - try: - print(json.dumps({ - "divergence": { - "tool": tool, - "metric": mname, - "previous": prev_val, - "current": float(mval), - "drop_frac": drop_frac, - "fatal": fatal, - } - })) - except Exception: - pass - if fatal: - divergence_should_abort = True - try: - last_metrics_map.setdefault(tool, {})[mname] = float(mval) - except Exception: - pass - else: - last_metrics_map = last_metrics_prev - - success_criteria = { - "context_answer": {"expected_fields": ["answer"], "min_citations": 0}, - "context_answer_compat": {"expected_fields": ["answer"], "min_citations": 0}, - "repo_search": {"min_results": 1}, - "search_config_for": {"min_results": 1}, - "search_tests_for": {"min_results": 1}, - "search_callers_for": {"min_results": 1}, - "search_importers_for": {"min_results": 1}, - "find": {"min_results": 1}, - } - sp = { - "last_query": args.query, - "last_plan": plan, - "last_filters": last_filters or None, - "mem_snippets": mem_snippets[:5], - "last_answer": last_answer_text, - "last_citations": last_citations_list, - "last_paths": last_paths_list, - "success_criteria": success_criteria, - "last_metrics": last_metrics_map, - "timestamp": time.time(), - } - save_scratchpad(sp) - except Exception: - pass - - if divergence_should_abort: - continue - - return 0 - except Exception as e: - last_err = e - try: - print(json.dumps({"tool": tool, "server": base_url, "error": str(e)}), file=sys.stderr) - except Exception: - pass - continue - - if last_err: - print(f"Router: all attempts failed: {last_err}", file=sys.stderr) - return 1 if last is not None else 2 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/scripts/mcp_router/client.py b/scripts/mcp_router/client.py deleted file mode 100644 index 88503ef6..00000000 --- 
a/scripts/mcp_router/client.py +++ /dev/null @@ -1,318 +0,0 @@ -""" -mcp_router/client.py - HTTP/MCP client helpers. -""" -from __future__ import annotations - -import json -import os -import time -from typing import Any, Dict, List, Tuple -from urllib import request - -from .config import ( - HTTP_URL_INDEXER, - HTTP_URL_MEMORY, - HEALTH_PORT_INDEXER, - HEALTH_PORT_MEMORY, - cache_ttl_sec, -) - -# Caches -_TOOL_ENDPOINTS_CACHE_MAP: Dict[str, str] = {} -_TOOL_ENDPOINTS_CACHE_TS: float = 0.0 -_TOOLS_DESCR_CACHE: Dict[str, list] = {} -_TOOLS_DESCR_TS: Dict[str, float] = {} - - -def _post_raw(url: str, payload: Dict[str, Any], headers: Dict[str, str], timeout: float = 60.0) -> Tuple[Dict[str, str], bytes]: - req = request.Request(url, method="POST") - for k, v in headers.items(): - req.add_header(k, v) - data = json.dumps(payload).encode("utf-8") - with request.urlopen(req, data=data, timeout=timeout) as resp: - body = resp.read() - hdrs = {k.lower(): v for k, v in resp.headers.items()} - return hdrs, body - - -def _post_raw_retry(url: str, payload: Dict[str, Any], headers: Dict[str, str], - timeout: float = 60.0, retries: int = 2, backoff: float = 0.5) -> Tuple[Dict[str, str], bytes]: - last_exc: Exception | None = None - for i in range(max(0, retries) + 1): - try: - return _post_raw(url, payload, headers, timeout=timeout) - except Exception as e: - last_exc = e - if i < retries: - try: - time.sleep(backoff * (2 ** i)) - except Exception: - pass - else: - raise last_exc - - -def _parse_stream_or_json(body: bytes) -> Dict[str, Any]: - txt = body.decode("utf-8", errors="ignore") - if "data:" in txt and ("event:" in txt or txt.strip().startswith("data:")): - last = None - for line in txt.splitlines(): - if line.startswith("data:"): - last = line[len("data:"):].strip() - if last: - try: - return json.loads(last) - except Exception: - pass - return json.loads(txt) - - -def _filter_args(d: Dict[str, Any]) -> Dict[str, Any]: - """Remove None/empty values from args dict.""" 
- return {k: v for k, v in d.items() if v not in (None, "")} - - -def _mcp_handshake(base_url: str, timeout: float = 30.0) -> Dict[str, str]: - """Perform MCP handshake and return headers with session ID.""" - headers = { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - } - init_payload = { - "jsonrpc": "2.0", - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "router", "version": "0.1.0"}, - }, - "id": 1, - } - hdrs, body = _post_raw_retry(base_url, init_payload, headers, timeout=timeout) - sid = hdrs.get("mcp-session-id") or hdrs.get("Mcp-Session-Id") - if not sid: - try: - j = _parse_stream_or_json(body) - sid = j.get("sessionId") - except Exception: - sid = None - if sid: - headers["Mcp-Session-Id"] = sid - try: - _post_raw_retry(base_url, {"jsonrpc": "2.0", "method": "notifications/initialized"}, headers, timeout=timeout) - except Exception: - pass - return headers - - -def _extract_iserror_text(resp: Dict[str, Any]) -> str | None: - try: - r = resp.get("result") or {} - if isinstance(r, dict) and r.get("isError"): - content = r.get("content") - if isinstance(content, list) and content and isinstance(content[0], dict): - if content[0].get("type") == "text": - return content[0].get("text") - except Exception: - pass - return None - - -def call_tool_http(base_url: str, tool_name: str, args: Dict[str, Any], timeout: float = 120.0) -> Dict[str, Any]: - """Call an MCP tool over HTTP.""" - headers = _mcp_handshake(base_url, timeout=min(timeout, 30.0)) - - def _do_call(arguments: Dict[str, Any]) -> Dict[str, Any]: - payload = { - "jsonrpc": "2.0", - "id": "router-1", - "method": "tools/call", - "params": { - "name": tool_name, - "arguments": arguments, - }, - } - _, body = _post_raw_retry(base_url, payload, headers, timeout=timeout) - return _parse_stream_or_json(body) - - args1 = _filter_args(args) - if tool_name.endswith("_compat"): - resp = 
_do_call({"arguments": args1}) - else: - resp = _do_call(args1) - - def _get_structured_error(r: Dict[str, Any]) -> str | None: - try: - rr = r.get("result") or {} - sc = rr.get("structuredContent") or {} - rs = sc.get("result") or {} - err = rs.get("error") - if isinstance(err, str): - return err - except Exception: - pass - return None - - msg = _extract_iserror_text(resp) - serr = _get_structured_error(resp) - if msg: - low = msg.lower() - if ("kwargs" in low) and ("field required" in low or "missing" in low): - return _do_call({"kwargs": args1}) - if ("arguments" in low) and ("field required" in low or "missing" in low): - return _do_call({"arguments": args1}) - if (serr and serr.strip().lower() == "query required") and ("query" in args1 or "queries" in args1): - resp4 = _do_call({"kwargs": args1}) - serr2 = _get_structured_error(resp4) - if not (serr2 and serr2.strip().lower() == "query required"): - return resp4 - resp5 = _do_call({"arguments": {"kwargs": args1}}) - serr3 = _get_structured_error(resp5) - if not (serr3 and serr3.strip().lower() == "query required"): - return resp5 - return _do_call({"arguments": args1}) - return resp - - -def is_failure_response(resp: Dict[str, Any]) -> bool: - """Check if response indicates a failure.""" - try: - r = resp.get("result") or {} - if r.get("isError") is True: - return True - sc = r.get("structuredContent") or {} - rs = sc.get("result") or {} - if isinstance(rs, dict) and isinstance(rs.get("error"), str): - return True - except Exception: - return False - return False - - -def _tools_describe_from_health(base_url: str, timeout: float = 3.0) -> list[dict]: - """Fetch tool descriptors from health /tools endpoint.""" - try: - import urllib.request - if base_url == HTTP_URL_INDEXER: - url = f"http://localhost:{HEALTH_PORT_INDEXER}/tools" - elif base_url == HTTP_URL_MEMORY: - url = f"http://localhost:{HEALTH_PORT_MEMORY}/tools" - else: - return [] - with urllib.request.urlopen(url, timeout=timeout) as r: - if 
getattr(r, "status", 200) != 200: - return [] - body = r.read() - j = _parse_stream_or_json(body) - tools = (j.get("tools") if isinstance(j, dict) else None) or [] - out = [] - for t in tools: - if not isinstance(t, dict): - continue - nm = t.get("name") - if not nm: - continue - out.append({"name": nm, "description": (t.get("description") or "").strip()}) - return out - except Exception: - return [] - - -def _mcp_tools_list(base_url: str, timeout: float = 30.0) -> List[str]: - """Get list of tool names from MCP server.""" - try: - headers = _mcp_handshake(base_url, timeout=min(timeout, 15.0)) - payload = {"jsonrpc": "2.0", "id": "router-list", "method": "tools/list"} - _, body = _post_raw_retry(base_url, payload, headers, timeout=timeout) - j = _parse_stream_or_json(body) - tools = ((j.get("result") or {}).get("tools") or []) - names: List[str] = [] - for t in tools: - try: - n = t.get("name") if isinstance(t, dict) else None - if isinstance(n, str) and n: - names.append(n) - except Exception: - continue - return names - except Exception: - return [] - - -def _mcp_tools_describe(base_url: str, timeout: float = 20.0) -> list[dict]: - """Return tool dicts from tools/list.""" - try: - headers = _mcp_handshake(base_url, timeout=min(timeout, 10.0)) - payload = {"jsonrpc": "2.0", "id": "router-list2", "method": "tools/list"} - _, body = _post_raw_retry(base_url, payload, headers, timeout=timeout) - j = _parse_stream_or_json(body) - tools = ((j.get("result") or {}).get("tools") or []) - out = [] - for t in tools: - if not isinstance(t, dict): - continue - name = (t.get("name") or "").strip() - if not name: - continue - out.append(t) - return out - except Exception: - return [] - - -def tools_describe_cached(base_url: str, allow_network: bool = True, timeout: float = 20.0) -> list[dict]: - """Get tool descriptions with caching.""" - now = time.time() - ts = _TOOLS_DESCR_TS.get(base_url, 0.0) - if base_url in _TOOLS_DESCR_CACHE and (now - ts) <= cache_ttl_sec(): - return 
_TOOLS_DESCR_CACHE[base_url] - if not allow_network: - return _TOOLS_DESCR_CACHE.get(base_url, []) - desc = _tools_describe_from_health(base_url, timeout=min(timeout, 3.0)) or _mcp_tools_describe(base_url, timeout=timeout) - _TOOLS_DESCR_CACHE[base_url] = desc - _TOOLS_DESCR_TS[base_url] = now - return desc - - -def default_tool_endpoints() -> Dict[str, str]: - """Return default tool -> endpoint mapping.""" - idx = HTTP_URL_INDEXER - mem = HTTP_URL_MEMORY - mapping: Dict[str, str] = {} - for n in [ - "repo_search", "context_answer", "context_answer_compat", "expand_query", - "search_tests_for", "search_config_for", "search_callers_for", "search_importers_for", - "qdrant_index_root", "qdrant_prune", "qdrant_status", "qdrant_list", - "workspace_info", "list_workspaces", "change_history_for_path", "code_search", "context_search", - ]: - mapping[n] = idx - mapping["store"] = mem - mapping["find"] = mem - return mapping - - -def discover_tool_endpoints(force: bool = False, allow_network: bool = True) -> Dict[str, str]: - """Discover tool -> endpoint mapping from servers.""" - global _TOOL_ENDPOINTS_CACHE_TS, _TOOL_ENDPOINTS_CACHE_MAP - now = time.time() - ttl = cache_ttl_sec() - if not force and _TOOL_ENDPOINTS_CACHE_MAP and (now - _TOOL_ENDPOINTS_CACHE_TS) <= ttl: - return _TOOL_ENDPOINTS_CACHE_MAP - if not allow_network: - return _TOOL_ENDPOINTS_CACHE_MAP or default_tool_endpoints() - mapping: Dict[str, str] = {} - idx_desc = tools_describe_cached(HTTP_URL_INDEXER, allow_network=allow_network) - for t in idx_desc: - n = t.get("name") if isinstance(t, dict) else None - if n: - mapping[n] = HTTP_URL_INDEXER - mem_desc = tools_describe_cached(HTTP_URL_MEMORY, allow_network=allow_network) - for t in mem_desc: - n = t.get("name") if isinstance(t, dict) else None - if n and n not in mapping: - mapping[n] = HTTP_URL_MEMORY - if mapping: - _TOOL_ENDPOINTS_CACHE_MAP.clear() - _TOOL_ENDPOINTS_CACHE_MAP.update(mapping) - _TOOL_ENDPOINTS_CACHE_TS = now - return mapping or 
(_TOOL_ENDPOINTS_CACHE_MAP or default_tool_endpoints()) diff --git a/scripts/mcp_router/config.py b/scripts/mcp_router/config.py deleted file mode 100644 index b67b1282..00000000 --- a/scripts/mcp_router/config.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -mcp_router/config.py - Shared configuration and constants. -""" -from __future__ import annotations - -import os - -# HTTP endpoints -HTTP_URL_INDEXER = os.environ.get("MCP_INDEXER_HTTP_URL", "http://localhost:8003/mcp").rstrip("/") -HTTP_URL_MEMORY = os.environ.get("MCP_MEMORY_HTTP_URL", "http://localhost:8002/mcp").rstrip("/") -DEFAULT_HTTP_URL = HTTP_URL_INDEXER - -# Health ports -try: - HEALTH_PORT_INDEXER = int(os.environ.get("FASTMCP_INDEXER_HTTP_HEALTH_PORT", "18003") or 18003) -except (ValueError, TypeError): - HEALTH_PORT_INDEXER = 18003 - -try: - HEALTH_PORT_MEMORY = int(os.environ.get("FASTMCP_HTTP_HEALTH_PORT", "18002") or 18002) -except (ValueError, TypeError): - HEALTH_PORT_MEMORY = 18002 - - -def cache_ttl_sec() -> int: - try: - return int(os.environ.get("ROUTER_TOOLS_CACHE_TTL_SEC", "60") or 60) - except Exception: - return 60 - - -def scratchpad_ttl_sec() -> int: - try: - return int(os.environ.get("ROUTER_SCRATCHPAD_TTL_SEC", "300") or 300) - except Exception: - return 300 - - -def divergence_thresholds() -> tuple[float, int]: - try: - drop_frac = float(os.environ.get("ROUTER_DIVERGENCE_DROP_FRAC", "0.5") or 0.5) - except Exception: - drop_frac = 0.5 - try: - min_base = int(os.environ.get("ROUTER_DIVERGENCE_MIN_BASE", "3") or 3) - except Exception: - min_base = 3 - return drop_frac, min_base - - -def divergence_is_fatal_for(tool: str) -> bool: - try: - s = (os.environ.get("ROUTER_DIVERGENCE_FATAL_TOOLS", "") or "").strip() - if not s: - return False - low = s.lower() - if low in {"*", "all", "1", "true"}: - return True - names = {t.strip().lower() for t in s.split(",") if t.strip()} - return tool.strip().lower() in names - except Exception: - return False - - -# Language set for hint parsing -LANGS = { 
- "python", "typescript", "javascript", "go", "java", "rust", "kotlin", - "c++", "cpp", "csharp", "c#", "ruby", "php", "scala", "swift", "bash", "shell" -} diff --git a/scripts/mcp_router/hints.py b/scripts/mcp_router/hints.py deleted file mode 100644 index 18c46052..00000000 --- a/scripts/mcp_router/hints.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -mcp_router/hints.py - Query hint parsing and tool selection. -""" -from __future__ import annotations - -import re -from typing import Any, Dict, List, Tuple - -from .config import LANGS, HTTP_URL_INDEXER -from .client import tools_describe_cached -from .intent import _cosine, _embed_texts - - -def parse_repo_hints(q: str) -> Dict[str, Any]: - """Extract light filters from the query: language, under, symbol, ext, path_glob, not_glob.""" - s = q.strip() - low = s.lower() - out: Dict[str, Any] = {} - - # language - for lang in sorted(LANGS, key=len, reverse=True): - if re.search(rf"\b{re.escape(lang)}\b", low): - out["language"] = {"javascript": "js", "typescript": "ts", "c++": "cpp", "c#": "csharp"}.get(lang, lang) - break - - # under / in folder - m_under = re.search(r"\bunder\s+([\w./-]+)", low) - m_in = re.search(r"\b(?:in|inside)\s+([\w./-]+)", low) - m = m_under or m_in - if m: - cand = m.group(1) - if len(cand) >= 2 and cand not in LANGS: - out["under"] = cand - - # symbol-like tokens - m2 = re.search(r"([A-Za-z_][A-Za-z0-9_]*\s*\(\))|([A-Za-z_][\w]*\.[A-Za-z_][\w]*)|([A-Za-z_][\w]*::[A-Za-z_][\w]*)", s) - if m2: - sym = m2.group(0) - sym = re.sub(r"\s*\(\)\s*$", "", sym) - out["symbol"] = sym - - # file extension - m3 = re.search(r"\.(py|ts|tsx|js|jsx|go|java|rs|kt|rb|php|scala|swift)$", s) - if m3: - out["ext"] = m3.group(1) - - # glob inclusions - globs: List[str] = [] - if re.search(r"\bonly\b", low): - m_glob = re.search(r"\*\.[A-Za-z0-9]+", s) - if m_glob: - globs.append("**/" + m_glob.group(0)) - if "python" in low and "*.py" not in " ".join(globs): - globs.append("**/*.py") - if globs: - out["path_glob"] = 
globs - - # exclusions - not_glob: List[str] = [] - for ex in ["vendor", "node_modules", "dist", "build", "tests", "__pycache__"]: - if re.search(rf"\bexclude\s+{re.escape(ex)}\b", low): - not_glob.append(f"**/{ex}/**") - if not_glob: - out["not_glob"] = not_glob - - return out - - -def clean_query_and_dsl(q: str) -> Tuple[str, Dict[str, Any]]: - """Strip DSL tokens from query and return (clean_query, dsl_filters).""" - try: - from scripts.hybrid_search import parse_query_dsl - clean, extracted = parse_query_dsl([q]) - return (clean[0] if clean else ""), (extracted or {}) - except Exception: - return q, {} - - -def _signature_text(t: dict) -> str: - """Build signature text for tool similarity matching.""" - name = (t.get("name") or "").strip() - desc = (t.get("description") or "").strip() - params = [] - try: - schema = t.get("inputSchema") or {} - props = (schema.get("properties") or {}) if isinstance(schema, dict) else {} - params = [k for k in props.keys()] - except Exception: - params = [] - ptxt = (" params:" + ",".join(params)) if params else "" - return (name + "\n" + desc + ptxt).strip() - - -def select_best_search_tool_by_signature(q: str, tool_dict: dict[str, str], allow_network: bool = True) -> str | None: - """Select best matching search tool based on signature similarity.""" - candidates = [n for n in tool_dict.keys() if n == "repo_search" or n.startswith("search_")] - if not candidates: - return None - - per_server: dict[str, list[dict]] = {} - for base in set(tool_dict[t] for t in candidates): - try: - per_server[base] = tools_describe_cached(base, allow_network=allow_network) - except Exception: - per_server[base] = [] - - sig_map: dict[str, str] = {} - for tname in candidates: - base = tool_dict.get(tname) - descs = per_server.get(base, []) - obj = None - for td in descs: - if (td.get("name") or "").strip() == tname: - obj = td - break - sig_map[tname] = _signature_text(obj or {"name": tname, "description": ""}) - - texts = [q] + [sig_map[n] for n 
in candidates] - vecs = _embed_texts(texts) - if not vecs or len(vecs) < 1 + len(candidates): - return None - - qv = vecs[0] - scores: list[tuple[str, float]] = [] - for i, name in enumerate(candidates): - sv = vecs[1 + i] - scores.append((name, _cosine(qv, sv))) - scores.sort(key=lambda x: x[1], reverse=True) - - best, best_s = scores[0] - repo_s = next((s for n, s in scores if n == "repo_search"), None) - margin = 0.02 - if best == "repo_search" or repo_s is None: - return best - if best != "repo_search" and best_s >= (repo_s + margin): - return best - return "repo_search" diff --git a/scripts/mcp_router/intent.py b/scripts/mcp_router/intent.py deleted file mode 100644 index 7f96543f..00000000 --- a/scripts/mcp_router/intent.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -mcp_router/intent.py - Intent classification (rules + ML). -""" -from __future__ import annotations - -import json -import os -import re -import sys -from typing import Any, Dict, List - -# Intent constants -INTENT_ANSWER = "answer" -INTENT_SEARCH = "search" -INTENT_SEARCH_TESTS = "search_tests" -INTENT_SEARCH_CONFIG = "search_config" -INTENT_SEARCH_CALLERS = "search_callers" -INTENT_SEARCH_IMPORTERS = "search_importers" -INTENT_MEMORY_STORE = "memory_store" -INTENT_MEMORY_FIND = "memory_find" -INTENT_SYMBOL_GRAPH = "symbol_graph" -INTENT_INDEX = "index" -INTENT_PRUNE = "prune" -INTENT_STATUS = "status" -INTENT_LIST = "list" - -# Debug state -_LAST_INTENT_DEBUG: Dict[str, Any] = {} - - -def get_last_intent_debug() -> Dict[str, Any]: - """Get the last intent debug info.""" - return _LAST_INTENT_DEBUG - - -def _classify_intent_rules(q: str) -> str | None: - s = q.lower() - # Admin / maintenance first - if any(w in s for w in ["reindex", "reset", "recreate", "index now", "fresh index"]): - return INTENT_INDEX - if any(w in s for w in ["prune", "pruning", "cleanup", "clean up"]): - return INTENT_PRUNE - if any(w in s for w in ["status", "health", "points", "stats"]): - return INTENT_STATUS - if any(w in s 
for w in ["list collections", "collections", "list qdrant"]): - return INTENT_LIST - - # Search importers - check BEFORE tests to avoid "import" in test queries - if any(w in s for w in ["import", "imports", "importers", "who imports", "imports this", "importing modules", "files that import"]): - # Make sure it's not about "important" or similar - if not any(w in s for w in ["important", "importance"]): - return INTENT_SEARCH_IMPORTERS - - # Intent wrappers - if any(w in s for w in ["tests", "pytest", "unit test", "test file", "where are tests"]): - return INTENT_SEARCH_TESTS - - # Memory intents - be more specific to avoid false positives on "memory store implementation" - # Check for actual user-intent memory storage, not code references - memory_store_triggers = [ - "remember this", "save memory", "store memory", "remember that", - "save preference", "remember preference", "store a note", "save a note", "remember note" - ] - # IMPORTANT: "memory store" as a phrase often refers to code, not user intent - if any(w in s for w in memory_store_triggers): - # Exclude if it looks like a code search (has "implementation", "code", "function", etc) - if not any(exc in s for exc in ["implementation", "code", "function", "class", "module", "file", "search for"]): - return INTENT_MEMORY_STORE - if any(w in s for w in [ - "find memory", "recall", "retrieve memory", "memory search", "what did we save", - "recall notes", "find notes", "retrieve notes" - ]): - return INTENT_MEMORY_FIND - - # Symbol graph for callers - check BEFORE config to avoid false positives - if any(w in s for w in ["who calls", "callers of", "call sites", "function calls"]): - return INTENT_SYMBOL_GRAPH - if re.search(r"calls?\s+(the\s+)?\w+\s*(function|method)?", s): - return INTENT_SYMBOL_GRAPH - - # Config search - after callers check - if any(w in s for w in ["config", "yaml", "toml", "ini", "settings file", "configuration"]): - return INTENT_SEARCH_CONFIG - - # Fallback callers intent (used by 
search_callers_for) - if any(w in s for w in ["used by", "usage sites", "references this function"]): - return INTENT_SEARCH_CALLERS - - # Q&A-like prompts - if re.match(r"^(what|how|why|explain|describe|summarize)(\b|\s)", s): - return INTENT_ANSWER - if any(w in s for w in ["recap", "design doc", "architecture", "adr", "retrospective", "postmortem", "summary of", "summarize the design"]): - return INTENT_ANSWER - return None - - -def _intent_prototypes() -> Dict[str, List[str]]: - return { - INTENT_ANSWER: [ - "explain, describe, summarize, recap, design, architecture, ADR, why/how", - "summarize design decisions and architecture rationale", - ], - INTENT_SEARCH: [ - "find code references, search repository, locate files, find implementation", - "code search in repo, general lookup, search for implementation", - "find module, search function, locate class definition", - "search for memory store implementation", # Explicit example - ], - INTENT_MEMORY_STORE: [ - "remember this preference, save this note for later, store this memory", - "save my preference, remember that for next time", - # NOT: search for, find, implementation, code - ], - INTENT_MEMORY_FIND: [ - "what did we save, recall saved notes, retrieve memory, find my saved notes", - ], - INTENT_SEARCH_TESTS: [ - "find unit tests, test files, pytest, testing modules", - ], - INTENT_SEARCH_CONFIG: [ - "config files, configuration changes, yaml toml ini settings", - ], - INTENT_SEARCH_CALLERS: [ - "who calls this function, callers, usage sites, where is it used", - ], - INTENT_SEARCH_IMPORTERS: [ - "who imports this module, importers, importing modules, files that import", - ], - INTENT_SYMBOL_GRAPH: [ - "who calls this function, callers of, call graph, symbol callers", - ], - } - - -def _cosine(a: list[float], b: list[float]) -> float: - """Lightweight cosine similarity.""" - try: - s = 0.0 - na = 0.0 - nb = 0.0 - for i in range(min(len(a), len(b))): - va = float(a[i]) - vb = float(b[i]) - s += va * vb - na 
+= va * va - nb += vb * vb - na = (na or 1.0) ** 0.5 - nb = (nb or 1.0) ** 0.5 - return s / (na * nb) - except Exception: - return 0.0 - - -def _embed_texts(texts: list[str]) -> list[list[float]]: - """Embed texts using available embedding model.""" - if not texts: - return [] - - # Try centralized embedder factory first - try: - from scripts.embedder import get_embedding_model - model_name = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") - em = get_embedding_model(model_name) - raw = list(em.embed(texts)) - return [v.tolist() if hasattr(v, "tolist") else list(v) for v in raw] - except ImportError: - pass - - # Try fastembed directly - try: - from fastembed import TextEmbedding - model_name = os.environ.get("EMBEDDING_MODEL", "BAAI/bge-base-en-v1.5") - em = TextEmbedding(model_name=model_name) - raw = list(em.embed(texts)) - return [v.tolist() if hasattr(v, "tolist") else list(v) for v in raw] - except Exception: - pass - - # Fallback to lexical - try: - from scripts.utils import lex_hash_vector_text - return [lex_hash_vector_text(t, dim=4096) for t in texts] - except Exception: - return [[float(len(t))] for t in texts] - - -def _classify_intent_ml(q: str) -> str: - global _LAST_INTENT_DEBUG - protos = _intent_prototypes() - labels = list(protos.keys()) - texts = [q] + ["\n".join(protos[l]) for l in labels] - vecs = _embed_texts(texts) - if not vecs or len(vecs) < len(texts): - _LAST_INTENT_DEBUG = { - "strategy": "ml", - "intent": INTENT_SEARCH, - "confidence": 0.0, - "query": q, - "top_candidate": INTENT_SEARCH, - "top_score": 0.0, - "threshold": 0.25, - "candidates": [], - "reason": "embed_failed", - } - return INTENT_SEARCH - qv = vecs[0] - sims = [] - for i, lab in enumerate(labels): - sims.append((lab, _cosine(qv, vecs[1 + i]))) - sims.sort(key=lambda x: x[1], reverse=True) - top, score = sims[0] - picked = top if score >= 0.25 else INTENT_SEARCH - _LAST_INTENT_DEBUG = { - "strategy": "ml", - "intent": picked, - "confidence": float(score), - 
"query": q, - "top_candidate": top, - "top_score": float(score), - "threshold": 0.25, - "candidates": [(name, float(val)) for name, val in sims[:5]], - "fallback": picked == INTENT_SEARCH and top != INTENT_SEARCH, - } - return picked - - -def classify_intent(q: str) -> str: - """Classify user query into an intent.""" - global _LAST_INTENT_DEBUG - ruled = _classify_intent_rules(q) - if ruled is not None: - _LAST_INTENT_DEBUG = { - "strategy": "rules", - "intent": ruled, - "confidence": 1.0, - "query": q, - } - return ruled - picked = _classify_intent_ml(q) - try: - if os.environ.get("DEBUG_ROUTER") and isinstance(_LAST_INTENT_DEBUG, dict): - if _LAST_INTENT_DEBUG.get("fallback"): - print(json.dumps({"router": {"intent_fallback": _LAST_INTENT_DEBUG}}), file=sys.stderr) - except Exception: - pass - return picked diff --git a/scripts/mcp_router/memory.py b/scripts/mcp_router/memory.py deleted file mode 100644 index cb22f872..00000000 --- a/scripts/mcp_router/memory.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -mcp_router/memory.py - Memory store payload parsing. 
-""" -from __future__ import annotations - -import re -from typing import Any, Dict, Tuple - -_MEMORY_TRIGGER_RE = re.compile( - r"^(?:remember(?:\s+(?:this|that|me|to))?|save\s+memory|store\s+memory)\s*[:,\-]?\s*", - re.IGNORECASE, -) -_MEMORY_INTENT_SPLIT_RE = re.compile( - r"\b(?:then|and|also)\s+(?:reindex|index|recreate|prune|clean\s+up)\b", - re.IGNORECASE, -) -_MEMORY_META_KEYS = {"priority", "tag", "tags", "topic", "category", "owner"} - - -def parse_memory_store_payload(q: str) -> Tuple[str, Dict[str, Any]]: - """Parse memory store command, extracting content and metadata.""" - raw = str(q or "").strip() - if not raw: - return "", {} - cleaned = _MEMORY_TRIGGER_RE.sub("", raw, count=1).lstrip() - meta: Dict[str, Any] = {} - - def _assign_meta(key: str, value: str) -> None: - k = key.lower() - v = value.strip().strip(" \t\r\n,;.") - if not v: - return - if k in {"tag", "tags"}: - tags = [t.strip() for t in re.split(r"[,\s/]+", v) if t.strip()] - if tags: - meta["tags"] = tags - else: - meta[k] = v - - if cleaned.startswith("["): - m = re.match(r"\[([^\]]+)\]\s*(.*)", cleaned, flags=re.S) - if m: - meta_block = m.group(1) - cleaned = m.group(2) - for key, val in re.findall(r"(\w+)\s*=\s*([^\s,;]+(?:,[^\s,;]+)*)", meta_block): - if key.strip().lower() in _MEMORY_META_KEYS: - _assign_meta(key, val) - - while True: - m = re.match( - r"^(?P(?:priority|tag|tags|topic|category|owner))\s*=\s*(?P[^\s;:]+)\s*[,;:]?\s*(?P.*)$", - cleaned, - flags=re.IGNORECASE | re.S, - ) - if not m: - break - _assign_meta(m.group("key"), m.group("val")) - cleaned = m.group("rest") - - cleaned = cleaned.lstrip(":- ").lstrip() - - split = _MEMORY_INTENT_SPLIT_RE.search(cleaned) - if split: - cleaned = cleaned[: split.start()].rstrip(" ,;.") - - cleaned = cleaned.strip().strip('"').strip() - if not cleaned: - cleaned = raw - return cleaned, meta diff --git a/scripts/mcp_router/planning.py b/scripts/mcp_router/planning.py deleted file mode 100644 index fcc1b56a..00000000 --- 
a/scripts/mcp_router/planning.py +++ /dev/null @@ -1,312 +0,0 @@ -""" -mcp_router/planning.py - Tool planning and selection. -""" -from __future__ import annotations - -import os -from typing import Any, Dict, List, Tuple - -from .config import HTTP_URL_INDEXER -from .intent import ( - classify_intent, - INTENT_ANSWER, - INTENT_SEARCH, - INTENT_SEARCH_TESTS, - INTENT_SEARCH_CONFIG, - INTENT_SEARCH_CALLERS, - INTENT_SEARCH_IMPORTERS, - INTENT_MEMORY_STORE, - INTENT_MEMORY_FIND, - INTENT_SYMBOL_GRAPH, - INTENT_INDEX, - INTENT_PRUNE, - INTENT_STATUS, - INTENT_LIST, -) -from .memory import parse_memory_store_payload -from .hints import parse_repo_hints, clean_query_and_dsl, select_best_search_tool_by_signature -from .scratchpad import load_scratchpad, looks_like_repeat, looks_like_same_filters -from .client import discover_tool_endpoints - - -def build_plan(q: str) -> List[Tuple[str, Dict[str, Any]]]: - """Build execution plan for a query.""" - intent = classify_intent(q) - include_snippet = str(os.environ.get("ROUTER_INCLUDE_SNIPPET", "1")).lower() in {"1", "true", "yes", "on"} - search_limit = int(os.environ.get("ROUTER_SEARCH_LIMIT", "8") or 8) - max_tokens_env = os.environ.get("ROUTER_MAX_TOKENS", "").strip() - - def _reuse_last_filters(args: Dict[str, Any]) -> None: - try: - if looks_like_same_filters(q): - sp = load_scratchpad() - lf = sp.get("last_filters") if isinstance(sp, dict) else None - if isinstance(lf, dict): - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - if k not in args and lf.get(k) not in (None, ""): - args[k] = lf.get(k) - except Exception: - pass - - # Repeat/redo handling - try: - if looks_like_repeat(q): - sp = load_scratchpad() - lp = sp.get("last_plan") - if isinstance(lp, list) and lp: - norm: list[tuple] = [] - for it in lp: - if isinstance(it, (list, tuple)) and len(it) == 2: - norm.append((it[0], it[1])) - if norm: - return norm - except Exception: - pass - - # Multi-intent: memory store + reindex - lowq = 
q.lower() - if any(w in lowq for w in ["remember this", "store memory", "save memory", "remember that"]) and any(w in lowq for w in ["reindex", "index now", "recreate", "fresh index"]): - idx_args: Dict[str, Any] = {} - if any(w in lowq for w in ["recreate", "fresh", "from scratch", "fresh index"]): - idx_args["recreate"] = True - info, meta = parse_memory_store_payload(q) - store_args: Dict[str, Any] = {"information": info or q.strip()} - if meta: - allowed = {"priority", "tags", "topic", "category", "owner"} - cleaned = {k: v for k, v in meta.items() if k in allowed and v not in (None, "", [])} - if cleaned: - store_args["metadata"] = cleaned - return [("store", store_args), ("qdrant_index_root", idx_args)] - - if intent == INTENT_INDEX: - recreate = True if any(w in q.lower() for w in ["recreate", "fresh", "from scratch"]) else None - args = {} - if recreate is True: - args["recreate"] = True - return [("qdrant_index_root", args)] - - if intent == INTENT_PRUNE: - return [("qdrant_prune", {})] - - if intent == INTENT_STATUS: - return [("qdrant_status", {})] - - if intent == INTENT_LIST: - return [("qdrant_list", {})] - - if intent == INTENT_SEARCH: - hints = parse_repo_hints(q) - clean_q, dsl_filters = clean_query_and_dsl(q) - args = {"query": clean_q} - if search_limit: - args["limit"] = search_limit - if include_snippet: - args["include_snippet"] = True - _reuse_last_filters(args) - - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = dsl_filters.get(k) - if v not in (None, "") and k not in args: - args[k] = v - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = hints.get(k) - if v not in (None, "") and k not in args: - args[k] = v - try: - tool_servers = discover_tool_endpoints(allow_network=False) - picked = select_best_search_tool_by_signature(q, tool_servers, allow_network=False) or "repo_search" - except Exception: - picked = "repo_search" - return [(picked, args)] - - if intent == 
INTENT_SEARCH_TESTS: - hints = parse_repo_hints(q) - clean_q, dsl_filters = clean_query_and_dsl(q) - args = {"query": clean_q} - if search_limit: - args["limit"] = search_limit - if include_snippet: - args["include_snippet"] = True - _reuse_last_filters(args) - - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = dsl_filters.get(k) - if v not in (None, "") and k not in args: - args[k] = v - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = hints.get(k) - if v not in (None, "") and k not in args: - args[k] = v - return [("search_tests_for", args)] - - if intent == INTENT_SEARCH_CONFIG: - hints = parse_repo_hints(q) - clean_q, dsl_filters = clean_query_and_dsl(q) - args = {"query": clean_q} - if search_limit: - args["limit"] = search_limit - if include_snippet: - args["include_snippet"] = True - _reuse_last_filters(args) - - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = dsl_filters.get(k) - if v not in (None, "") and k not in args: - args[k] = v - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = hints.get(k) - if v not in (None, "") and k not in args: - args[k] = v - return [("search_config_for", args)] - - if intent == INTENT_MEMORY_STORE: - info, meta = parse_memory_store_payload(q) - payload: Dict[str, Any] = {"information": info or q.strip()} - if meta: - allowed = {"priority", "tags", "topic", "category", "owner"} - cleaned = {k: v for k, v in meta.items() if k in allowed and v not in (None, "", [])} - if cleaned: - payload["metadata"] = cleaned - return [("store", payload)] - - if intent == INTENT_MEMORY_FIND: - args = {"query": q} - if search_limit: - args["limit"] = max(5, search_limit) - return [("find", args)] - - if intent == INTENT_SYMBOL_GRAPH: - hints = parse_repo_hints(q) - clean_q, dsl_filters = clean_query_and_dsl(q) - args = {"symbol": clean_q, "query_type": "callers"} - if search_limit: - args["limit"] = search_limit - for k 
in ("language", "under"): - v = dsl_filters.get(k) or hints.get(k) - if v not in (None, ""): - args[k] = v - return [("symbol_graph", args)] - - if intent == INTENT_SEARCH_CALLERS: - hints = parse_repo_hints(q) - clean_q, dsl_filters = clean_query_and_dsl(q) - args = {"query": clean_q} - if search_limit: - args["limit"] = search_limit - _reuse_last_filters(args) - - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = dsl_filters.get(k) - if v not in (None, "") and k not in args: - args[k] = v - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = hints.get(k) - if v not in (None, "") and k not in args: - args[k] = v - return [("search_callers_for", args)] - - if intent == INTENT_SEARCH_IMPORTERS: - hints = parse_repo_hints(q) - clean_q, dsl_filters = clean_query_and_dsl(q) - args = {"query": clean_q} - if search_limit: - args["limit"] = search_limit - _reuse_last_filters(args) - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = dsl_filters.get(k) - if v not in (None, "") and k not in args: - args[k] = v - for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = hints.get(k) - if v not in (None, "") and k not in args: - args[k] = v - return [("search_importers_for", args)] - - if intent == INTENT_ANSWER: - def _looks_like_design_recap(s: str) -> bool: - low = s.lower() - return any(w in low for w in ["recap", "design doc", "architecture", "adr", "retrospective", "postmortem"]) and any(w in low for w in ["design", "summary", "recap", "explain"]) - - args: Dict[str, Any] = {"query": q} - if max_tokens_env: - try: - mt = int(max_tokens_env) - if mt > 0: - args["max_tokens"] = mt - except Exception: - pass - - hints = parse_repo_hints(q) - lowq = q.lower() - if "router" in lowq: - router_globs = ["**/mcp_router.py", "**/*router*.py"] - if not hints.get("path_glob"): - hints["path_glob"] = router_globs - if not hints.get("language"): - hints["language"] = "python" 
- for k in ("language", "under", "symbol", "ext", "path_glob", "not_glob"): - v = hints.get(k) - if v not in (None, ""): - args[k] = v - - plan: List[Tuple[str, Dict[str, Any]]] = [] - if _looks_like_design_recap(q): - plan.append(("find", {"query": q, "limit": 3})) - plan.extend([ - ("context_answer_compat", dict(args)), - ("context_answer", dict(args)), - ("repo_search", {**{k: v for k, v in args.items() if k != "max_tokens"}, "limit": max(5, search_limit)}), - ]) - return plan - - # Fallback - return [("repo_search", {"query": q, "limit": search_limit})] - - -async def route_query(query: str) -> Dict[str, Any]: - """ - Route a query to the appropriate tool. - - Returns dict with: - - tool: Selected tool name - - confidence: Routing confidence (0.0-1.0) - - intent: Classified intent - - args: Tool arguments - """ - intent = classify_intent(query) - plan = build_plan(query) - - if not plan: - return { - "tool": "repo_search", - "confidence": 0.3, - "intent": "fallback", - "args": {"query": query}, - } - - # First tool in plan is the primary selection - tool_name, tool_args = plan[0] - - # Map intent to confidence (higher for more specific intents) - intent_confidence = { - INTENT_ANSWER: 0.9, - INTENT_SEARCH: 0.7, - INTENT_SEARCH_TESTS: 0.85, - INTENT_SEARCH_CONFIG: 0.85, - INTENT_SEARCH_CALLERS: 0.85, - INTENT_SEARCH_IMPORTERS: 0.85, - INTENT_SYMBOL_GRAPH: 0.9, - INTENT_MEMORY_STORE: 0.9, - INTENT_MEMORY_FIND: 0.9, - INTENT_INDEX: 0.95, - INTENT_PRUNE: 0.95, - INTENT_STATUS: 0.95, - INTENT_LIST: 0.95, - } - - return { - "tool": tool_name, - "confidence": intent_confidence.get(intent, 0.5), - "intent": intent, - "args": tool_args, - } diff --git a/scripts/mcp_router/scratchpad.py b/scripts/mcp_router/scratchpad.py deleted file mode 100644 index 462ac2fa..00000000 --- a/scripts/mcp_router/scratchpad.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -mcp_router/scratchpad.py - Persistent scratchpad for context preservation. 
-""" -from __future__ import annotations - -import json -import os -import time -from typing import Any, Dict - -from .config import scratchpad_ttl_sec - - -def scratchpad_path() -> str: - """Get scratchpad file path.""" - base = os.path.join(os.getcwd(), ".codebase") - try: - os.makedirs(base, exist_ok=True) - except Exception: - pass - return os.path.join(base, "router_scratchpad.json") - - -def load_scratchpad() -> Dict[str, Any]: - """Load scratchpad with TTL handling.""" - import sys - p = scratchpad_path() - try: - with open(p, "r", encoding="utf-8") as f: - j = json.load(f) - if isinstance(j, dict): - try: - ts = float(j.get("timestamp") or 0.0) - except Exception: - ts = 0.0 - ttl = scratchpad_ttl_sec() - if ts and ttl >= 0 and (time.time() - ts) > ttl: - stale_keys = ( - "last_plan", - "last_filters", - "mem_snippets", - "last_answer", - "last_citations", - "last_paths", - "last_metrics", - ) - removed = False - for stale_key in stale_keys: - if stale_key in j: - j.pop(stale_key, None) - removed = True - if removed: - j["timestamp"] = 0.0 - try: - print( - json.dumps({ - "router": { - "scratchpad": "stale_cleared", - "age_sec": round(time.time() - ts, 2), - } - }), - file=sys.stderr, - ) - except Exception: - pass - return j - except Exception: - pass - return {} - - -def save_scratchpad(d: Dict[str, Any]) -> None: - """Save scratchpad atomically.""" - p = scratchpad_path() - tmp = p + ".tmp" - try: - with open(tmp, "w", encoding="utf-8") as f: - json.dump(d, f) - try: - f.flush() - os.fsync(f.fileno()) - except Exception: - pass - os.replace(tmp, p) - except Exception: - try: - if os.path.exists(tmp): - os.unlink(tmp) - except Exception: - pass - - -def looks_like_repeat(q: str) -> bool: - """Check if query looks like a repeat request.""" - s = q.strip().lower() - pats = [ - "repeat", "again", "same thing", "do that again", "rerun", "run it again", "same as before", - ] - return any(p in s for p in pats) - - -def looks_like_same_filters(q: str) -> bool: - 
"""Check if query asks to reuse filters.""" - s = q.strip().lower() - return any(p in s for p in ["same filters", "reuse filters", "previous filters"]) - - -def looks_like_expand(q: str) -> bool: - """Check if query asks for expansion.""" - s = q.strip().lower() - pats = [ - "expand on", "expand that", "expand the summary", "elaborate", - "more detail", "more details", "go deeper", "add details", - ] - return any(p in s for p in pats) diff --git a/scripts/mcp_router/validation.py b/scripts/mcp_router/validation.py deleted file mode 100644 index 52457fb6..00000000 --- a/scripts/mcp_router/validation.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -mcp_router/validation.py - Response validation and metric extraction. -""" -from __future__ import annotations - -from typing import Any, Dict - -from .client import is_failure_response -from .config import divergence_thresholds, divergence_is_fatal_for - - -def is_result_good(tool: str, resp: Dict[str, Any]) -> bool: - """Check if result is good enough to stop the plan.""" - try: - r = resp.get("result") or {} - sc = r.get("structuredContent") or {} - rs = sc.get("result") or {} - - if tool in {"context_answer", "context_answer_compat"}: - ans = rs.get("answer") if isinstance(rs, dict) else None - if isinstance(ans, str): - s = ans.strip() - if s and not any(p in s.lower() for p in [ - "insufficient context", "not enough context", "no relevant", "don't know", "cannot answer" - ]): - return True - cites = rs.get("citations") if isinstance(rs, dict) else None - if isinstance(cites, list) and len(cites) > 0: - return True - return False - - if tool.startswith("search_") or tool == "repo_search": - total = rs.get("total") if isinstance(rs, dict) else None - if isinstance(total, int) and total > 0: - return True - results = rs.get("results") if isinstance(rs, dict) else None - if isinstance(results, list) and len(results) > 0: - return True - return False - - return not is_failure_response(resp) - except Exception: - return not 
is_failure_response(resp) - - -def extract_metric_from_resp(tool: str, resp: Dict[str, Any]) -> tuple[str, float] | None: - """Extract metric for divergence detection.""" - try: - r = resp.get("result") or {} - sc = r.get("structuredContent") or {} - rs = sc.get("result") or {} - - if tool in {"repo_search", "code_search", "context_search", "search_tests_for", "search_config_for", "search_callers_for", "search_importers_for"}: - tot = rs.get("total") - if isinstance(tot, (int, float)): - return ("total_results", float(tot)) - results = rs.get("results") - if isinstance(results, list): - return ("total_results", float(len(results))) - return None - - if tool in {"context_answer", "context_answer_compat"}: - cites = rs.get("citations") - if isinstance(cites, list): - return ("citations", float(len(cites))) - return ("citations", 0.0) - - if tool == "qdrant_status": - cnt = rs.get("count") - if isinstance(cnt, (int, float)): - return ("points", float(cnt)) - return None - except Exception: - return None - return None - - -def material_drop(prev: float | None, curr: float, drop_frac: float, min_base: int) -> bool: - """Check if there's a material drop in metrics.""" - try: - if prev is None: - return False - if prev < float(min_base): - return False - return curr < (float(prev) * float(drop_frac)) - except Exception: - return False diff --git a/scripts/mcp_toon.py b/scripts/mcp_toon.py deleted file mode 100644 index 3795c9c3..00000000 --- a/scripts/mcp_toon.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. See scripts/mcp/toon.py""" -from scripts.mcp_impl.toon import * - diff --git a/scripts/mcp_utils.py b/scripts/mcp_utils.py deleted file mode 100644 index 9357a351..00000000 --- a/scripts/mcp_utils.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. 
See scripts/mcp/utils.py""" -from scripts.mcp_impl.utils import * - diff --git a/scripts/mcp_workspace.py b/scripts/mcp_workspace.py deleted file mode 100644 index d8242bcb..00000000 --- a/scripts/mcp_workspace.py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/bin/env python3 -"""Shim for backward compatibility. See scripts/mcp/workspace.py""" -from scripts.mcp_impl.workspace import * - diff --git a/scripts/memory_backup.py b/scripts/memory_backup.py index 410ed90a..1518d550 100644 --- a/scripts/memory_backup.py +++ b/scripts/memory_backup.py @@ -19,17 +19,8 @@ from typing import List, Dict, Any, Optional from pathlib import Path -# Add project root to path for imports -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -try: - from qdrant_client import QdrantClient - from qdrant_client.models import Filter, FieldCondition, MatchValue -except ImportError: - print("ERROR: qdrant-client not installed. Install with: pip install qdrant-client") - sys.exit(1) +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, FieldCondition, MatchValue def get_qdrant_client() -> QdrantClient: @@ -316,4 +307,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/memory_restore.py b/scripts/memory_restore.py index c8a7789f..20840be1 100644 --- a/scripts/memory_restore.py +++ b/scripts/memory_restore.py @@ -20,29 +20,10 @@ from typing import List, Dict, Any, Optional from pathlib import Path -# Add project root to path for imports -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -try: - from qdrant_client import QdrantClient - from qdrant_client.models import VectorParams, Distance, HnswConfigDiff -except ImportError as e: - print(f"ERROR: Missing required dependency: {e}") - print("Install with: pip install qdrant-client fastembed") - sys.exit(1) - -# Use embedder 
factory for Qwen3 support; fallback to direct fastembed -try: - from scripts.embedder import get_embedding_model as _get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False - try: - from fastembed import TextEmbedding - except ImportError: - TextEmbedding = None # type: ignore +from qdrant_client import QdrantClient +from qdrant_client.models import VectorParams, Distance, HnswConfigDiff + +from scripts.embedder import get_embedding_model as _get_embedding_model def get_qdrant_client() -> QdrantClient: @@ -55,18 +36,7 @@ def get_qdrant_client() -> QdrantClient: def get_embedding_model(model_name: str): """Initialize embedding model with Qwen3 support via embedder factory.""" - # Try centralized embedder factory first (supports Qwen3 feature flag) - if _EMBEDDER_FACTORY: - return _get_embedding_model(model_name) - # Fallback to direct fastembed - if TextEmbedding is not None: - try: - return TextEmbedding(model_name=model_name) - except Exception as e: - raise RuntimeError(f"Failed to load embedding model '{model_name}': {e}") - raise RuntimeError( - "No embedding model available. Install fastembed: pip install fastembed" - ) + return _get_embedding_model(model_name) def ensure_collection_exists( @@ -424,4 +394,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/path_scope.py b/scripts/path_scope.py new file mode 100644 index 00000000..2150926c --- /dev/null +++ b/scripts/path_scope.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +""" +Shared helpers for user-facing path scoping (`under`) across search tools. + +`under` is treated as a recursive subtree scope from the user's workspace +perspective (for example: "space" matches ".../space/**"). 
+""" + +from __future__ import annotations + +import os +import re +from functools import lru_cache +from typing import Any, Mapping, Optional, Set + +_MULTI_SLASH_RE = re.compile(r"/+") + + +def _normalize_path_token(value: Any) -> str: + s = str(value or "").strip().replace("\\", "/") + if not s: + return "" + s = _MULTI_SLASH_RE.sub("/", s) + # Normalize common "file://" style inputs. + if s.startswith("file://"): + s = s[7:] + return s.strip("/") + + +def _normalize_repo_hint(repo_hint: Any) -> str: + r = _normalize_path_token(repo_hint) + if not r: + return "" + return r.split("/")[-1] + + +def _repo_root_hint() -> str: + """Best-effort repository root (directory containing scripts/).""" + try: + return os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + except Exception: + return "" + + +def _maybe_expand_from_cwd(token: str) -> str: + """Recover under values that were relativized from the current subdirectory.""" + s = str(token or "").strip().strip("/") + if not s or "/" in s: + return s + try: + root = _repo_root_hint() + if not root: + return s + cwd = os.path.abspath(os.getcwd()) + if not (cwd == root or cwd.startswith(root + os.sep)): + return s + rel_cwd = os.path.relpath(cwd, root).replace("\\", "/").strip("/") + if not rel_cwd: + return s + rebased = f"{rel_cwd}/{s}" + rebased_path = os.path.join(root, *rebased.split("/")) + top_level_path = os.path.join(root, s) + if os.path.exists(rebased_path) and not os.path.exists(top_level_path): + return rebased + except Exception: + pass + return s + + +@lru_cache(maxsize=256) +def _unique_segment_path(root: str, segment: str) -> str: + """Return unique repo-relative directory path for a segment, else empty.""" + if not root or not segment: + return "" + top = os.path.join(root, segment) + if os.path.exists(top): + return "" + matches: list[str] = [] + skip = { + ".git", + ".codebase", + "__pycache__", + ".venv", + "node_modules", + } + try: + for dirpath, dirnames, _filenames in os.walk(root): + 
dirnames[:] = [d for d in dirnames if d not in skip and not d.startswith(".")] + if segment in dirnames: + rel = os.path.relpath(os.path.join(dirpath, segment), root).replace("\\", "/") + matches.append(rel.strip("/")) + if len(matches) > 1: + return "" + except Exception: + return "" + return matches[0] if len(matches) == 1 else "" + + +def _maybe_expand_unique_segment(token: str) -> str: + """Resolve single-segment under values to a unique subtree when possible.""" + s = str(token or "").strip().strip("/") + if not s or "/" in s: + return s + root = _repo_root_hint() + if not root: + return s + found = _unique_segment_path(root, s) + return found or s + + +def normalize_under(under: Optional[str]) -> Optional[str]: + """Normalize user-provided `under` into a comparable path token.""" + s = _normalize_path_token(under) + if not s or s in {".", "work"}: + return None + # Accept absolute-style workspace prefixes while preserving user-facing scope. + if s.startswith("work/"): + s = s[len("work/") :] + s = _maybe_expand_from_cwd(s) + s = _maybe_expand_unique_segment(s) + if not s or s in {".", "work"}: + return None + return s + + +def _path_forms(path: Any, repo_hint: Any = None) -> Set[str]: + """Generate comparable path forms from a path-like value.""" + p = _normalize_path_token(path) + if not p: + return set() + + forms: Set[str] = {p} + + repo = _normalize_repo_hint(repo_hint) + + if p.startswith("work/"): + rest = p[len("work/") :] + if rest: + forms.add(rest) + if "/" in rest and repo: + head, tail = rest.split("/", 1) + if head.casefold() == repo.casefold() and tail: + forms.add(tail) + + if repo: + def _cf_to_orig_idx(orig: str, cf_index: int) -> int: + if cf_index <= 0: + return 0 + acc = 0 + for i, ch in enumerate(orig): + nxt = acc + len(ch.casefold()) + if nxt > cf_index: + return i + acc = nxt + return len(orig) + + repo_cf = repo.casefold() + repo_prefix_cf = repo_cf + "/" + marker_cf = "/" + repo_cf + "/" + for f in list(forms): + f_cf = f.casefold() 
+ if f_cf.startswith(repo_prefix_cf): + forms.add(f[len(repo) + 1 :]) + idx = f_cf.find(marker_cf) + if idx >= 0: + tail_start = _cf_to_orig_idx(f, idx + len(marker_cf)) + tail = f[tail_start:] + if tail: + forms.add(tail) + + return {x for x in forms if x} + + +def metadata_path_forms(metadata: Mapping[str, Any]) -> Set[str]: + """Collect path forms from a metadata payload.""" + repo_hint = metadata.get("repo") + forms: Set[str] = set() + for key in ( + "repo_rel_path", + "path", + "container_path", + "host_path", + "path_prefix", + "file_path", + "rel_path", + "client_path", + ): + v = metadata.get(key) + if v: + forms.update(_path_forms(v, repo_hint=repo_hint)) + return forms + + +def metadata_matches_under(metadata: Mapping[str, Any], under: Optional[str]) -> bool: + """Return True when metadata falls under the requested subtree scope.""" + norm_under = normalize_under(under) + if not norm_under: + return True + + repo_hint = metadata.get("repo") + under_forms = _path_forms(norm_under, repo_hint=repo_hint) + under_forms.add(norm_under) + if not norm_under.startswith("work/"): + under_forms.add("work/" + norm_under) + + under_forms_l = {u.casefold() for u in under_forms if u} + if not under_forms_l: + return True + + has_repo_hint = bool(str(repo_hint or "").strip()) + + for cand in metadata_path_forms(metadata): + cand_forms = {cand} + # Compatibility fallback for points that only store /work//... paths + # but do not carry metadata.repo (older/benchmark/custom payloads). 
+ if not has_repo_hint: + c0 = cand.strip("/") + if c0.startswith("work/"): + rest = c0[len("work/") :] + if "/" in rest: + _head, tail = rest.split("/", 1) + if tail: + cand_forms.add(tail) + + for cf in cand_forms: + c = cf.casefold() + for u in under_forms_l: + if c == u or c.startswith(u + "/"): + return True + return False + + +def path_matches_under(path: Any, under: Optional[str], repo_hint: Any = None) -> bool: + """Path-only convenience wrapper for `under` subtree matching.""" + md = {"path": path} + if repo_hint: + md["repo"] = repo_hint + return metadata_matches_under(md, under) diff --git a/scripts/pattern_detection/search.py b/scripts/pattern_detection/search.py index a3dd5e23..d86b12fa 100644 --- a/scripts/pattern_detection/search.py +++ b/scripts/pattern_detection/search.py @@ -72,21 +72,19 @@ def _format_pattern_results_as_toon( response: Dict[str, Any], compact: bool = False, ) -> Dict[str, Any]: - """Convert pattern search response to TOON format. + """Add a TOON render while preserving structured pattern results. 
Args: response: Pattern search response dict with 'results' key compact: If True, use minimal fields only Returns: - Modified response with TOON-encoded results + Modified response with TOON-encoded text """ try: results = response.get("results", []) if isinstance(results, list): - # Encode results to TOON format - toon_results = encode_pattern_results(results, compact=compact) - response["results"] = toon_results + response["text"] = encode_pattern_results(results, compact=compact) response["output_format"] = "toon" return response except Exception as e: diff --git a/scripts/progressive_train.py b/scripts/progressive_train.py index cbf76383..68623c46 100644 --- a/scripts/progressive_train.py +++ b/scripts/progressive_train.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """Progressive training evaluation - measures quality at checkpoints.""" -import sys, os -sys.path.insert(0, '.') +import os -from scripts.rerank_eval import get_candidates, rerank_learning, rerank_onnx, DEFAULT_EVAL_QUERIES -from scripts.rerank_recursive import rerank_with_learning +from scripts.rerank_tools.eval import get_candidates, rerank_learning, rerank_onnx, DEFAULT_EVAL_QUERIES +from scripts.rerank_recursive.recursive import rerank_with_learning from scripts.learning_reranker_worker import CollectionLearner import numpy as np @@ -88,4 +87,3 @@ def main(): if __name__ == '__main__': main() - diff --git a/scripts/prune.py b/scripts/prune.py index 5e2f14fb..2f628598 100755 --- a/scripts/prune.py +++ b/scripts/prune.py @@ -2,15 +2,37 @@ import os import hashlib from pathlib import Path -from typing import Tuple +from typing import Tuple, Any from qdrant_client import QdrantClient, models +try: + from scripts.ingest.graph_edges import ( + delete_edges_by_path as _shared_delete_graph_edges_by_path, + get_graph_collection_name as _shared_graph_collection_name, + ) +except Exception: + _shared_delete_graph_edges_by_path = None # type: ignore[assignment] + _shared_graph_collection_name = None # type: 
ignore[assignment] COLLECTION = os.environ.get("COLLECTION_NAME", "codebase") QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333") API_KEY = os.environ.get("QDRANT_API_KEY") ROOT = Path(os.environ.get("PRUNE_ROOT", ".")).resolve() -GRAPH_COLLECTION = os.environ.get("GRAPH_COLLECTION_NAME", f"{COLLECTION}_graph") +GRAPH_COLLECTION = ( + _shared_graph_collection_name(COLLECTION) + if _shared_graph_collection_name is not None + else f"{COLLECTION}_graph" +) + + +def _norm_path(path_str: Any) -> str: + if not path_str: + return "" + try: + normalized = os.path.normpath(str(path_str)) + except Exception: + normalized = str(path_str) + return normalized.replace("\\", "/") def sha1_file(path: Path) -> str: @@ -39,50 +61,103 @@ def delete_by_path(client: QdrantClient, path_str: str) -> int: return 0 -def delete_graph_edges_by_path(client: QdrantClient, path_str: str) -> int: +def delete_graph_edges_by_path(client: QdrantClient, path_str: str, repo: str | None = None) -> int: """Best-effort deletion for graph-edge collections (if present). Some deployments store symbol-graph edges in a separate Qdrant collection - (commonly `${COLLECTION}_graph`). Those points may reference a file path as - either caller or callee; delete both to prevent stale graph results. + (commonly `${COLLECTION}_graph`). On this branch, edge docs are file-level and + reference a file path as `caller_path`. """ if not path_str: return 0 + path_str = _norm_path(path_str) - flt = models.Filter( - should=[ - models.FieldCondition( - key="caller_path", match=models.MatchValue(value=path_str) - ), - models.FieldCondition( - key="callee_path", match=models.MatchValue(value=path_str) - ), - ] - ) + # Canonical path: shared graph-edge deleter against _graph. 
+ if _shared_delete_graph_edges_by_path is None: + return 0 try: - res = client.delete( + return int( + _shared_delete_graph_edges_by_path( + client, + COLLECTION, + caller_path=path_str, + repo=repo, + ) + or 0 + ) + except Exception: + return 0 + + +def _graph_collection_exists(client: QdrantClient) -> bool: + try: + client.get_collection(collection_name=GRAPH_COLLECTION) + return True + except Exception: + return False + + +def _delete_graph_points_by_ids(client: QdrantClient, ids: list[Any]) -> int: + if not ids: + return 0 + try: + from qdrant_client import models as qmodels + client.delete( collection_name=GRAPH_COLLECTION, - points_selector=models.FilterSelector(filter=flt), + points_selector=qmodels.PointIdsList(points=ids), ) - # Qdrant responses vary by client version; return 1 as "success" when count isn't available. - deleted_count = None - result_attr = getattr(res, "result", None) - if isinstance(result_attr, dict): - v = result_attr.get("deleted") - if isinstance(v, int): - deleted_count = v - if deleted_count is None: - v = getattr(res, "deleted", None) - if isinstance(v, int): - deleted_count = v - if deleted_count is None: - deleted_count = 1 - return deleted_count + return len(ids) except Exception: - # Non-fatal: graph collection may not exist in this deployment. 
return 0 +def delete_orphan_graph_edges(client: QdrantClient, valid_paths: set[str]) -> int: + """Delete graph-edge points whose `caller_path` no longer exists in base collection.""" + if not _graph_collection_exists(client): + return 0 + + removed = 0 + next_page = None + pending_ids: list[Any] = [] + batch_size = 256 + + while True: + try: + points, next_page = client.scroll( + collection_name=GRAPH_COLLECTION, + with_payload=True, + with_vectors=False, + limit=512, + offset=next_page, + scroll_filter=None, + ) + except Exception: + break + + if not points: + break + + for p in points: + payload = p.payload or {} + caller_path = _norm_path(payload.get("caller_path")) + if not caller_path: + continue + if caller_path in valid_paths: + continue + pending_ids.append(p.id) + if len(pending_ids) >= batch_size: + removed += _delete_graph_points_by_ids(client, pending_ids) + pending_ids = [] + + if next_page is None: + break + + if pending_ids: + removed += _delete_graph_points_by_ids(client, pending_ids) + + return removed + + def main(): client = QdrantClient(url=QDRANT_URL, api_key=API_KEY or None) @@ -90,6 +165,7 @@ def main(): removed_missing = 0 removed_mismatch = 0 removed_graph_edges = 0 + removed_orphan_graph_edges = 0 next_page = None while True: @@ -106,9 +182,9 @@ def main(): md = (p.payload or {}).get("metadata") or {} path_str = md.get("path") file_hash = md.get("file_hash") - if not path_str or path_str in seen: + norm_path = _norm_path(path_str) + if not norm_path or norm_path in seen: continue - seen.add(path_str) abs_path = ( ROOT / Path(path_str).relative_to("/work") if path_str.startswith("/work/") @@ -116,23 +192,37 @@ def main(): ) if not abs_path.exists(): removed_missing += delete_by_path(client, path_str) - removed_graph_edges += delete_graph_edges_by_path(client, path_str) + deleted = delete_graph_edges_by_path(client, path_str, md.get("repo")) + if deleted == 0: + # Repo tags can drift across ingestion modes; fall back to path-only delete. 
+ deleted = delete_graph_edges_by_path(client, path_str, None) + removed_graph_edges += deleted print(f"[prune] removed missing file points: {path_str}") continue current_hash = sha1_file(abs_path) if file_hash and current_hash and current_hash != file_hash: removed_mismatch += delete_by_path(client, path_str) - removed_graph_edges += delete_graph_edges_by_path(client, path_str) + deleted = delete_graph_edges_by_path(client, path_str, md.get("repo")) + if deleted == 0: + deleted = delete_graph_edges_by_path(client, path_str, None) + removed_graph_edges += deleted print(f"[prune] removed outdated points (hash mismatch): {path_str}") + continue + + seen.add(norm_path) if next_page is None: break + # Secondary pass: if base points were manually deleted, remove orphan `_graph` edges. + removed_orphan_graph_edges = delete_orphan_graph_edges(client, seen) + print( "Prune complete. " f"removed_missing={removed_missing}, " f"removed_mismatch={removed_mismatch}, " - f"removed_graph_edges={removed_graph_edges}" + f"removed_graph_edges={removed_graph_edges}, " + f"removed_orphan_graph_edges={removed_orphan_graph_edges}" ) diff --git a/scripts/pseudo_config.py b/scripts/pseudo_config.py new file mode 100644 index 00000000..0f7b1f0b --- /dev/null +++ b/scripts/pseudo_config.py @@ -0,0 +1,52 @@ +"""Shared configuration helpers for pseudo/tags generation. + +This keeps env semantics consistent across: +- watcher (watch_index / watch_index_core) +- indexing CLI (scripts/ingest/cli.py) + +Policy: +- PSEUDO_BACKFILL_ENABLED controls whether the async backfill worker is enabled. +- PSEUDO_DEFER_TO_WORKER controls *foreground vs background* behavior only. + Deferral is only effective when the worker is enabled; otherwise we keep inline + pseudo generation ON to avoid silently dropping pseudo/tags. 
+""" + +from __future__ import annotations + +import os +from typing import Optional + + +def _parse_env_bool(value: Optional[str], *, default: bool = False) -> bool: + if value is None: + return default + v = str(value).strip().lower() + if not v: + return default + return v in {"1", "true", "yes", "on"} + + +def env_bool(key: str, *, default: bool = False) -> bool: + """Read a boolean env var using consistent truthy parsing.""" + return _parse_env_bool(os.environ.get(key), default=default) + + +def effective_defer_to_worker(*, defer_to_worker: bool, backfill_enabled: bool) -> bool: + """Whether we should disable inline pseudo/tags generation.""" + return bool(defer_to_worker and backfill_enabled) + + +def effective_pseudo_mode(*, defer_to_worker: bool, backfill_enabled: bool) -> str: + """Return pseudo_mode ('off'|'full') for indexing pipeline.""" + return "off" if effective_defer_to_worker( + defer_to_worker=defer_to_worker, + backfill_enabled=backfill_enabled, + ) else "full" + + +__all__ = [ + "env_bool", + "effective_defer_to_worker", + "effective_pseudo_mode", +] + diff --git a/scripts/qdrant_client_manager.py b/scripts/qdrant_client_manager.py index bcb4adbe..198b0cca 100644 --- a/scripts/qdrant_client_manager.py +++ b/scripts/qdrant_client_manager.py @@ -8,9 +8,19 @@ import threading import time import weakref -from typing import Optional, Dict, List +from typing import Any, Optional, Dict, List, TYPE_CHECKING from contextlib import contextmanager -from qdrant_client import QdrantClient + +if TYPE_CHECKING: + from qdrant_client import QdrantClient +else: + QdrantClient = Any + + +def _new_qdrant_client(url: str, api_key: Optional[str] = None) -> QdrantClient: + from qdrant_client import QdrantClient as _QdrantClient + + return _QdrantClient(url=url, api_key=api_key if api_key else None) # Connection pool implementation @@ -44,7 +54,7 @@ def get_client(self, url: str, api_key: Optional[str] = None) -> QdrantClient: # No suitable client found, create a new 
one if self._created_count < self.max_size: - client = QdrantClient(url=url, api_key=api_key) + client = _new_qdrant_client(url, api_key) pool_entry = { 'client': client, 'url': url, @@ -60,7 +70,7 @@ def get_client(self, url: str, api_key: Optional[str] = None) -> QdrantClient: else: # Pool is full, create a temporary client (not pooled) self._misses += 1 - return QdrantClient(url=url, api_key=api_key) + return _new_qdrant_client(url, api_key) def return_client(self, client: QdrantClient): """Return a client to the pool.""" @@ -166,13 +176,13 @@ def get_qdrant_client( # Fallback to singleton pattern for backward compatibility if force_new: - return QdrantClient(url=url, api_key=api_key if api_key else None) + return _new_qdrant_client(url, api_key) global _client with _client_lock: if _client is None: - _client = QdrantClient(url=url, api_key=api_key if api_key else None) + _client = _new_qdrant_client(url, api_key) return _client diff --git a/scripts/query_named_vector.py b/scripts/query_named_vector.py index d87ede8b..cfb64608 100644 --- a/scripts/query_named_vector.py +++ b/scripts/query_named_vector.py @@ -1,21 +1,8 @@ #!/usr/bin/env python3 import os -import sys -from pathlib import Path from qdrant_client import QdrantClient -# Ensure scripts is importable -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -# Use embedder factory for Qwen3 support -try: - from scripts.embedder import get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False - from fastembed import TextEmbedding +from scripts.embedder import get_embedding_model from scripts.utils import sanitize_vector_name QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333") @@ -24,10 +11,7 @@ VEC_NAME = os.environ.get("VECTOR_NAME") or sanitize_vector_name(MODEL) client = QdrantClient(url=QDRANT_URL) -if _EMBEDDER_FACTORY: - emb = get_embedding_model(MODEL) -else: - emb = 
TextEmbedding(model_name=MODEL) +emb = get_embedding_model(MODEL) q = "function that chunks code lines with overlap for semantic indexing" vec = next(emb.embed([q])) res = client.search( diff --git a/scripts/remote_upload_client.py b/scripts/remote_upload_client.py index fcc7d6ba..fbfacd97 100644 --- a/scripts/remote_upload_client.py +++ b/scripts/remote_upload_client.py @@ -46,6 +46,48 @@ # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +_git_history_skip_log_key: Optional[str] = None + + +def _is_usable_delta_status(status: Any) -> bool: + if not isinstance(status, dict): + return False + state = str(status.get("status") or "").strip().lower() + return ( + bool(status.get("success")) and + "workspace_path" in status and + "collection_name" in status and + state in {"ready", "processing", "completed"} + ) + + +def _server_status_error_message(status: Any) -> str: + if isinstance(status, dict): + error = status.get("error") + if isinstance(error, dict): + msg = str(error.get("message") or "").strip() + if msg: + return msg + state = str(status.get("status") or "").strip() + if state: + return f"Server status is {state}" + return "Invalid server status response" + + +def _env_flag(name: str, default: bool) -> bool: + raw = os.environ.get(name) + if raw is None: + return default + return str(raw).strip().lower() in {"1", "true", "yes", "on"} + + +def _log_git_history_skip_once(reason: str, key: str) -> None: + global _git_history_skip_log_key + marker = f"{reason}:{key}" + if _git_history_skip_log_key == marker: + return + _git_history_skip_log_key = marker + logger.info("[git_history] skip (%s): %s", reason, key) DEFAULT_MAX_TEMP_CLEAN_ATTEMPTS = 3 DEFAULT_TEMP_CLEAN_SLEEP = 1.0 @@ -54,13 +96,17 @@ from scripts.workspace_state import ( get_cached_file_hash, set_cached_file_hash, - get_collection_name, _extract_repo_name_from_path, remove_cached_file, ) +from scripts.ingest.config import CODE_EXTS, EXTENSIONLESS_FILES 
-# Import existing hash function -import scripts.ingest_code as idx + +def hash_id(text: str, path: str, start: int, end: int) -> int: + h = hashlib.sha1( + f"{path}:{start}-{end}\n{text}".encode("utf-8", errors="ignore") + ).hexdigest() + return int(h[:16], 16) def _cache_missing_stats(file_hashes: Dict[str, Any]) -> Tuple[bool, int, int]: @@ -134,6 +180,24 @@ def _compute_logical_repo_id(workspace_path: str) -> str: return f"{prefix}{h}" +def _derive_metadata_root(workspace_path: str) -> Path: + """Infer host-side metadata root that corresponds to container `/work`.""" + try: + p = Path(workspace_path).resolve() + except Exception: + p = Path(workspace_path) + + if p.name == "dev-workspace": + return p.parent + if p.parent.name == "dev-workspace": + return p.parent.parent + if (p / ".codebase").exists(): + return p + if (p.parent / ".codebase").exists(): + return p.parent + return p.parent + + def _redact_emails(text: str) -> str: """Redact email addresses from commit messages for privacy.""" try: @@ -167,10 +231,12 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str } if max_commits <= 0: + _log_git_history_skip_once("disabled", f"max_commits={max_commits}") return None root = _find_git_root(Path(workspace_path)) if not root: + _log_git_history_skip_once("no_repo", workspace_path) return None # Git history cache: avoid emitting identical manifests when HEAD/settings are unchanged @@ -204,6 +270,7 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str cache = {} if current_head and cache.get("last_head") == current_head and cache.get("max_commits") == max_commits and str(cache.get("since") or "") == since: + _log_git_history_skip_once("cache_hit", f"head={current_head[:10]} since={since or '-'} max={max_commits}") return None base_head = "" @@ -254,12 +321,20 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str errors="replace", ) if proc.returncode != 0 or not 
proc.stdout.strip(): + _log_git_history_skip_once( + "rev_list_empty", + f"head={current_head[:10] if current_head else '-'} rc={proc.returncode}", + ) return None commits = [l.strip() for l in proc.stdout.splitlines() if l.strip()] except Exception: return None if not commits: + _log_git_history_skip_once( + "no_commits", + f"head={current_head[:10] if current_head else '-'}", + ) return None if len(commits) > max_commits: commits = commits[:max_commits] @@ -333,6 +408,10 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str continue if not records: + _log_git_history_skip_once( + "no_records", + f"commits={len(commits)} head={current_head[:10] if current_head else '-'}", + ) return None try: @@ -352,6 +431,14 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str "since": since, "commits": records, } + logger.info( + "[git_history] prepared manifest mode=%s commits=%d head=%s prev=%s base=%s", + manifest["mode"], + len(records), + (current_head[:10] if current_head else "-"), + (prev_head[:10] if prev_head else "-"), + (base_head[:10] if base_head else "-"), + ) # Update git history cache with the HEAD and settings used for this manifest try: @@ -370,7 +457,12 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str return manifest -def _load_local_cache_file_hashes(workspace_path: str, repo_name: Optional[str]) -> Dict[str, str]: +def _load_local_cache_file_hashes( + workspace_path: str, + repo_name: Optional[str], + *, + metadata_root: Optional[str] = None, +) -> Dict[str, str]: """Best-effort read of the local cache.json file_hashes map. This mirrors the layout used by workspace_state without introducing new @@ -378,7 +470,13 @@ def _load_local_cache_file_hashes(workspace_path: str, repo_name: Optional[str]) lookups still go through get_cached_file_hash. 
""" try: - base = Path(os.environ.get("WORKSPACE_PATH") or workspace_path).resolve() + base = Path( + metadata_root + or os.environ.get("CTXCE_METADATA_ROOT") + or os.environ.get("WATCH_ROOT") + or os.environ.get("WORKSPACE_PATH") + or workspace_path + ).resolve() multi_repo = os.environ.get("MULTI_REPO_MODE", "0").strip().lower() in {"1", "true", "yes", "on"} if multi_repo and repo_name: cache_path = base / ".codebase" / "repos" / repo_name / "cache.json" @@ -418,6 +516,21 @@ def _load_local_cache_file_hashes(workspace_path: str, repo_name: Optional[str]) return {} +def get_all_cached_paths( + repo_name: Optional[str] = None, + metadata_root: Optional[str] = None, +) -> List[str]: + """Return cached file paths from the local workspace cache.""" + effective_workspace = os.environ.get("WORKSPACE_PATH") or os.getcwd() + return list( + _load_local_cache_file_hashes( + effective_workspace, + repo_name, + metadata_root=metadata_root, + ).keys() + ) + + class RemoteUploadClient: """Client for uploading delta bundles to remote server.""" @@ -451,29 +564,24 @@ def _translate_to_container_path(self, host_path: str) -> str: return host_path.replace('\\', '/').replace(':', '') - def __init__(self, upload_endpoint: str, workspace_path: str, collection_name: str, + def __init__(self, upload_endpoint: str, workspace_path: str, collection_name: Optional[str] = None, max_retries: int = 3, timeout: int = 30, metadata_path: Optional[str] = None, logical_repo_id: Optional[str] = None): """Initialize remote upload client.""" self.upload_endpoint = upload_endpoint.rstrip('/') self.workspace_path = workspace_path + self.metadata_root = str(_derive_metadata_root(workspace_path)) self.collection_name = collection_name self.max_retries = max_retries self.timeout = timeout self.temp_dir = None self.logical_repo_id = logical_repo_id - # Set environment variables for cache functions - os.environ["WORKSPACE_PATH"] = workspace_path + from scripts.workspace_state import 
_extract_repo_name_from_path - # Get repo name for cache operations - try: - from scripts.workspace_state import _extract_repo_name_from_path - self.repo_name = _extract_repo_name_from_path(workspace_path) - # Fallback to directory name if repo detection fails (for non-git repos) - if not self.repo_name: - self.repo_name = Path(workspace_path).name - except ImportError: + self.repo_name = _extract_repo_name_from_path(workspace_path) + # Fallback to directory name if repo detection fails (for non-git repos) + if not self.repo_name: self.repo_name = Path(workspace_path).name # In-memory stat cache to avoid rehashing unchanged files on every watch iteration @@ -485,6 +593,173 @@ def __init__(self, upload_endpoint: str, workspace_path: str, collection_name: s adapter = HTTPAdapter(max_retries=retry_strategy) self.session.mount("http://", adapter) self.session.mount("https://", adapter) + self.last_upload_result: Dict[str, Any] = {"outcome": "idle"} + self._last_plan_payload: Optional[Dict[str, Any]] = None + + def _get_cached_file_hash(self, file_path: str) -> str: + try: + return get_cached_file_hash( + file_path, + self.repo_name, + metadata_root=self.metadata_root, + ) + except TypeError: + # Support monkeypatched test doubles that don't accept metadata_root. + return get_cached_file_hash(file_path, self.repo_name) + + def _set_cached_file_hash(self, file_path: str, file_hash: str) -> None: + try: + set_cached_file_hash( + file_path, + file_hash, + self.repo_name, + metadata_root=self.metadata_root, + ) + except TypeError: + # Support monkeypatched test doubles that don't accept metadata_root. + set_cached_file_hash(file_path, file_hash, self.repo_name) + + def _remove_cached_file(self, file_path: str) -> None: + try: + remove_cached_file( + file_path, + self.repo_name, + metadata_root=self.metadata_root, + ) + except TypeError: + # Support monkeypatched test doubles that don't accept metadata_root. 
+ remove_cached_file(file_path, self.repo_name) + + def _get_all_cached_paths(self) -> List[str]: + try: + return get_all_cached_paths( + self.repo_name, + metadata_root=self.metadata_root, + ) + except TypeError: + # Support monkeypatched test doubles that don't accept metadata_root. + return get_all_cached_paths(self.repo_name) + + def _set_last_upload_result(self, outcome: str, **details: Any) -> Dict[str, Any]: + result: Dict[str, Any] = {"outcome": outcome} + result.update(details) + self.last_upload_result = result + return result + + def log_watch_upload_result(self) -> None: + outcome = str((self.last_upload_result or {}).get("outcome") or "") + if outcome == "skipped_by_plan": + logger.info("[watch] No upload needed after plan") + elif outcome == "queued": + logger.info("[watch] Upload request accepted; server processing asynchronously") + elif outcome == "uploaded_async": + processed = (self.last_upload_result or {}).get("processed_operations") + logger.info("[watch] Upload processed asynchronously: %s", processed or {}) + elif outcome == "uploaded": + logger.info("[watch] Successfully uploaded changes") + elif outcome == "no_changes": + logger.info("[watch] No meaningful changes to upload") + else: + logger.info("[watch] Upload handling completed") + + def _finalize_successful_changes(self, changes: Dict[str, List]) -> None: + for path in changes.get("created", []): + try: + abs_path = str(path.resolve()) + current_hash = hashlib.sha1(path.read_bytes()).hexdigest() + self._set_cached_file_hash(abs_path, current_hash) + stat = path.stat() + self._stat_cache[abs_path] = ( + getattr(stat, "st_mtime_ns", int(stat.st_mtime * 1e9)), + stat.st_size, + ) + except Exception: + continue + for path in changes.get("updated", []): + try: + abs_path = str(path.resolve()) + current_hash = hashlib.sha1(path.read_bytes()).hexdigest() + self._set_cached_file_hash(abs_path, current_hash) + stat = path.stat() + self._stat_cache[abs_path] = ( + getattr(stat, "st_mtime_ns", 
int(stat.st_mtime * 1e9)), + stat.st_size, + ) + except Exception: + continue + for path in changes.get("deleted", []): + try: + abs_path = str(path.resolve()) + self._remove_cached_file(abs_path) + self._stat_cache.pop(abs_path, None) + except Exception: + continue + for source_path, dest_path in changes.get("moved", []): + try: + source_abs_path = str(source_path.resolve()) + self._remove_cached_file(source_abs_path) + self._stat_cache.pop(source_abs_path, None) + except Exception: + continue + try: + dest_abs_path = str(dest_path.resolve()) + current_hash = hashlib.sha1(dest_path.read_bytes()).hexdigest() + self._set_cached_file_hash(dest_abs_path, current_hash) + stat = dest_path.stat() + self._stat_cache[dest_abs_path] = ( + getattr(stat, "st_mtime_ns", int(stat.st_mtime * 1e9)), + stat.st_size, + ) + except Exception: + continue + + def _await_async_upload_result( + self, + bundle_id: Optional[str], + sequence_number: Optional[int], + ) -> Optional[Dict[str, Any]]: + try: + max_wait = float(os.environ.get("CTXCE_REMOTE_UPLOAD_STATUS_WAIT_SECS", "5")) + except Exception: + max_wait = 5.0 + if max_wait <= 0: + return None + + try: + poll_interval = float(os.environ.get("CTXCE_REMOTE_UPLOAD_STATUS_POLL_INTERVAL_SECS", "1")) + except Exception: + poll_interval = 1.0 + poll_interval = max(0.1, poll_interval) + + deadline = time.time() + max_wait + while time.time() < deadline: + status = self.get_server_status() + if not status.get("success"): + return None + server_info = status.get("server_info", {}) if isinstance(status, dict) else {} + last_bundle_id = server_info.get("last_bundle_id") + last_upload_status = server_info.get("last_upload_status") + last_sequence = status.get("last_sequence") + bundle_matches = bool(bundle_id) and last_bundle_id == bundle_id + sequence_matches = sequence_number is not None and last_sequence == sequence_number + if bundle_matches or sequence_matches: + if last_upload_status == "completed": + return { + "outcome": 
"uploaded_async", + "bundle_id": last_bundle_id or bundle_id, + "sequence_number": last_sequence if last_sequence is not None else sequence_number, + "processed_operations": server_info.get("last_processed_operations"), + "processing_time_ms": server_info.get("last_processing_time_ms"), + } + if last_upload_status in ("failed", "error"): + return { + "outcome": "failed", + "bundle_id": last_bundle_id or bundle_id, + "sequence_number": last_sequence if last_sequence is not None else sequence_number, + "error": server_info.get("last_error"), + } + time.sleep(poll_interval) + return None def __enter__(self): """Context manager entry.""" @@ -505,7 +780,7 @@ def get_mapping_summary(self) -> Dict[str, Any]: container_path = self._translate_to_container_path(self.workspace_path) return { "repo_name": self.repo_name, - "collection_name": self.collection_name, + "collection_name": self.collection_name or "", "source_path": self.workspace_path, "container_path": container_path, "upload_endpoint": self.upload_endpoint, @@ -520,6 +795,65 @@ def log_mapping_summary(self) -> None: logger.info(f" source_path: {info['source_path']}") logger.info(f" container_path: {info['container_path']}") + def _excluded_dirnames(self) -> frozenset: + # Keep in sync with standalone_upload_client exclusions. + # NOTE: This caches the exclusion set per RemoteUploadClient instance. + # Runtime changes to DEV_REMOTE_MODE/REMOTE_UPLOAD_MODE won't be reflected + # until a new client is created (typically via process restart), which is + # acceptable for the upload client use case. 
+ cached = getattr(self, "_excluded_dirnames_cache", None) + if cached is not None: + return cached + excluded = { + "node_modules", "vendor", "dist", "build", "target", "out", + ".git", ".hg", ".svn", ".vscode", ".idea", ".venv", "venv", + "__pycache__", ".pytest_cache", ".mypy_cache", ".cache", + ".context-engine", ".context-engine-uploader", ".codebase", + } + dev_remote = os.environ.get("DEV_REMOTE_MODE") == "1" or os.environ.get("REMOTE_UPLOAD_MODE") == "development" + if dev_remote: + excluded.add("dev-workspace") + cached = frozenset(excluded) + self._excluded_dirnames_cache = cached + return cached + + def _is_ignored_path(self, path: Path) -> bool: + """Return True when path is outside workspace or under excluded dirs.""" + try: + workspace_root = Path(self.workspace_path).resolve() + rel = path.resolve().relative_to(workspace_root) + except Exception: + return True + + dir_parts = set(rel.parts[:-1]) if len(rel.parts) > 1 else set() + if dir_parts & self._excluded_dirnames(): + return True + # Ignore hidden directories anywhere under the workspace, but allow + # extensionless dotfiles like `.gitignore` that we explicitly support. 
+ if any(p.startswith(".") for p in rel.parts[:-1]): + return True + try: + extensionless = set((EXTENSIONLESS_FILES or {}).keys()) + except Exception: + extensionless = set() + if rel.name.startswith(".") and rel.name.lower() not in extensionless: + return True + return False + + def _is_watchable_path(self, path: Path) -> bool: + """Return True when a filesystem event path is eligible for upload processing.""" + if self._is_ignored_path(path): + return False + suffix = path.suffix.lower() + if CODE_EXTS.get(suffix, "unknown") != "unknown": + return True + name = path.name.lower() + try: + extensionless_names = {k.lower() for k in (EXTENSIONLESS_FILES or {}).keys()} + except Exception: + extensionless_names = set() + return name in extensionless_names or name.startswith("dockerfile") + def _get_temp_bundle_dir(self) -> Path: """Get or create temporary directory for bundle creation.""" if not self.temp_dir: @@ -547,6 +881,19 @@ def detect_file_changes(self, changed_paths: List[Path]) -> Dict[str, List]: } for path in changed_paths: + if self._is_ignored_path(path): + try: + abs_path = str(path.resolve()) + except Exception: + continue + cached_hash = self._get_cached_file_hash(abs_path) + if cached_hash: + changes["deleted"].append(path) + try: + self._stat_cache.pop(abs_path, None) + except Exception: + pass + continue # Resolve to an absolute path for stable cache keys try: abs_path = str(path.resolve()) @@ -554,7 +901,7 @@ def detect_file_changes(self, changed_paths: List[Path]) -> Dict[str, List]: # Skip paths that cannot be resolved continue - cached_hash = get_cached_file_hash(abs_path, self.repo_name) + cached_hash = self._get_cached_file_hash(abs_path) if not path.exists(): # File was deleted @@ -610,8 +957,6 @@ def detect_file_changes(self, changed_paths: List[Path]) -> Dict[str, List]: self._stat_cache[abs_path] = (getattr(stat, "st_mtime_ns", int(stat.st_mtime * 1e9)), stat.st_size) except Exception: pass - set_cached_file_hash(abs_path, current_hash, 
self.repo_name) - # Detect moves by looking for files with same content hash # but different paths (requires additional tracking) changes["moved"] = self._detect_moves(changes["created"], changes["deleted"]) @@ -636,7 +981,7 @@ def _detect_moves(self, created_files: List[Path], deleted_files: List[Path]) -> for deleted_path in deleted_files: try: # Try to get cached hash first, fallback to file content - cached_hash = get_cached_file_hash(str(deleted_path), self.repo_name) + cached_hash = self._get_cached_file_hash(str(deleted_path)) if cached_hash: deleted_hashes[cached_hash] = deleted_path continue @@ -720,7 +1065,7 @@ def create_delta_bundle( # Get file info stat = path.stat() - language = idx.CODE_EXTS.get(path.suffix.lower(), "unknown") + language = CODE_EXTS.get(path.suffix.lower(), "unknown") operation = { "operation": "created", @@ -729,7 +1074,7 @@ def create_delta_bundle( "absolute_path": str(path.resolve()), "size_bytes": stat.st_size, "content_hash": content_hash, - "file_hash": f"sha1:{idx.hash_id(content.decode('utf-8', errors='ignore'), rel_path, 1, len(content.splitlines()))}", + "file_hash": f"sha1:{hash_id(content.decode('utf-8', errors='ignore'), rel_path, 1, len(content.splitlines()))}", "modified_time": datetime.fromtimestamp(stat.st_mtime).isoformat(), "language": language } @@ -749,7 +1094,7 @@ def create_delta_bundle( content = f.read() file_hash = hashlib.sha1(content).hexdigest() content_hash = f"sha1:{file_hash}" - previous_hash = get_cached_file_hash(str(path.resolve()), self.repo_name) + previous_hash = self._get_cached_file_hash(str(path.resolve())) # Write file to bundle bundle_file_path = files_dir / "updated" / rel_path @@ -758,7 +1103,7 @@ def create_delta_bundle( # Get file info stat = path.stat() - language = idx.CODE_EXTS.get(path.suffix.lower(), "unknown") + language = CODE_EXTS.get(path.suffix.lower(), "unknown") operation = { "operation": "updated", @@ -768,7 +1113,7 @@ def create_delta_bundle( "size_bytes": stat.st_size, 
"content_hash": content_hash, "previous_hash": f"sha1:{previous_hash}" if previous_hash else None, - "file_hash": f"sha1:{idx.hash_id(content.decode('utf-8', errors='ignore'), rel_path, 1, len(content.splitlines()))}", + "file_hash": f"sha1:{hash_id(content.decode('utf-8', errors='ignore'), rel_path, 1, len(content.splitlines()))}", "modified_time": datetime.fromtimestamp(stat.st_mtime).isoformat(), "language": language } @@ -797,7 +1142,7 @@ def create_delta_bundle( # Get file info stat = dest_path.stat() - language = idx.CODE_EXTS.get(dest_path.suffix.lower(), "unknown") + language = CODE_EXTS.get(dest_path.suffix.lower(), "unknown") operation = { "operation": "moved", @@ -809,7 +1154,7 @@ def create_delta_bundle( "source_absolute_path": str(source_path.resolve()), "size_bytes": stat.st_size, "content_hash": content_hash, - "file_hash": f"sha1:{idx.hash_id(content.decode('utf-8', errors='ignore'), dest_rel_path, 1, len(content.splitlines()))}", + "file_hash": f"sha1:{hash_id(content.decode('utf-8', errors='ignore'), dest_rel_path, 1, len(content.splitlines()))}", "modified_time": datetime.fromtimestamp(stat.st_mtime).isoformat(), "language": language } @@ -825,7 +1170,7 @@ def create_delta_bundle( for path in changes["deleted"]: rel_path = path.relative_to(Path(self.workspace_path)).as_posix() try: - previous_hash = get_cached_file_hash(str(path.resolve()), self.repo_name) + previous_hash = self._get_cached_file_hash(str(path.resolve())) operation = { "operation": "deleted", @@ -835,17 +1180,10 @@ def create_delta_bundle( "previous_hash": f"sha1:{previous_hash}" if previous_hash else None, "file_hash": None, "modified_time": datetime.now().isoformat(), - "language": idx.CODE_EXTS.get(path.suffix.lower(), "unknown") + "language": CODE_EXTS.get(path.suffix.lower(), "unknown") } operations.append(operation) - # Once a delete operation has been recorded, drop the cache entry - # so subsequent scans do not keep re-reporting the same deletion. 
- try: - remove_cached_file(str(path.resolve()), self.repo_name) - except Exception: - pass - except Exception as e: print(f"[bundle_create] Error processing deleted file {path}: {e}") continue @@ -855,7 +1193,6 @@ def create_delta_bundle( "version": "1.0", "bundle_id": bundle_id, "workspace_path": self.workspace_path, - "collection_name": self.collection_name, "created_at": created_at, # CLI is stateless - server handles sequence numbers "sequence_number": None, # Server will assign @@ -909,6 +1246,301 @@ def create_delta_bundle( return str(bundle_path), manifest + def _build_plan_payload(self, changes: Dict[str, List]) -> Dict[str, Any]: + created_at = datetime.now().isoformat() + bundle_id = str(uuid.uuid4()) + operations: List[Dict[str, Any]] = [] + file_hashes: Dict[str, str] = {} + total_size = 0 + + for path in changes["created"]: + rel_path = path.relative_to(Path(self.workspace_path)).as_posix() + try: + content = path.read_bytes() + file_hash = hashlib.sha1(content).hexdigest() + stat = path.stat() + operations.append( + { + "operation": "created", + "path": rel_path, + "size_bytes": stat.st_size, + "content_hash": f"sha1:{file_hash}", + "language": CODE_EXTS.get(path.suffix.lower(), "unknown"), + } + ) + file_hashes[rel_path] = f"sha1:{file_hash}" + total_size += stat.st_size + except Exception as e: + logger.warning("[remote_upload] Failed to prepare created plan entry for %s: %s", path, e) + + for path in changes["updated"]: + rel_path = path.relative_to(Path(self.workspace_path)).as_posix() + try: + content = path.read_bytes() + file_hash = hashlib.sha1(content).hexdigest() + stat = path.stat() + previous_hash = self._get_cached_file_hash(str(path.resolve())) + operations.append( + { + "operation": "updated", + "path": rel_path, + "size_bytes": stat.st_size, + "content_hash": f"sha1:{file_hash}", + "previous_hash": f"sha1:{previous_hash}" if previous_hash else None, + "language": CODE_EXTS.get(path.suffix.lower(), "unknown"), + } + ) + 
file_hashes[rel_path] = f"sha1:{file_hash}" + total_size += stat.st_size + except Exception as e: + logger.warning("[remote_upload] Failed to prepare updated plan entry for %s: %s", path, e) + + for source_path, dest_path in changes["moved"]: + dest_rel_path = dest_path.relative_to(Path(self.workspace_path)).as_posix() + source_rel_path = source_path.relative_to(Path(self.workspace_path)).as_posix() + try: + content = dest_path.read_bytes() + file_hash = hashlib.sha1(content).hexdigest() + stat = dest_path.stat() + operations.append( + { + "operation": "moved", + "path": dest_rel_path, + "source_path": source_rel_path, + "size_bytes": stat.st_size, + "content_hash": f"sha1:{file_hash}", + "language": CODE_EXTS.get(dest_path.suffix.lower(), "unknown"), + } + ) + file_hashes[dest_rel_path] = f"sha1:{file_hash}" + total_size += stat.st_size + except Exception as e: + logger.warning( + "[remote_upload] Failed to prepare moved plan entry for %s -> %s: %s", + source_path, + dest_path, + e, + ) + + for path in changes["deleted"]: + rel_path = path.relative_to(Path(self.workspace_path)).as_posix() + try: + previous_hash = self._get_cached_file_hash(str(path.resolve())) + operations.append( + { + "operation": "deleted", + "path": rel_path, + "previous_hash": f"sha1:{previous_hash}" if previous_hash else None, + "language": CODE_EXTS.get(path.suffix.lower(), "unknown"), + } + ) + except Exception as e: + logger.warning("[remote_upload] Failed to prepare deleted plan entry for %s: %s", path, e) + + manifest = { + "version": "1.0", + "bundle_id": bundle_id, + "workspace_path": self.workspace_path, + "created_at": created_at, + "sequence_number": None, + "parent_sequence": None, + "operations": { + "created": len(changes["created"]), + "updated": len(changes["updated"]), + "deleted": len(changes["deleted"]), + "moved": len(changes["moved"]), + }, + "total_files": len(operations), + "total_size_bytes": total_size, + "compression": "gzip", + "encoding": "utf-8", + } + return { + 
"manifest": manifest, + "operations": operations, + "file_hashes": file_hashes, + } + + def _plan_delta_upload(self, changes: Dict[str, List]) -> Optional[Dict[str, Any]]: + if not _env_flag("CTXCE_REMOTE_UPLOAD_PLAN_ENABLED", True): + return None + try: + payload = self._build_plan_payload(changes) + self._last_plan_payload = payload + data = { + "workspace_path": self._translate_to_container_path(self.workspace_path), + "source_path": self.workspace_path, + "logical_repo_id": _compute_logical_repo_id(self.workspace_path), + "manifest": payload["manifest"], + "operations": payload["operations"], + "file_hashes": payload["file_hashes"], + } + sess = get_auth_session(self.upload_endpoint) + if sess: + data["session"] = sess + if getattr(self, "logical_repo_id", None): + data["logical_repo_id"] = self.logical_repo_id + + response = self.session.post( + f"{self.upload_endpoint}/api/v1/delta/plan", + json=data, + timeout=min(self.timeout, 60), + ) + if response.status_code in {404, 405}: + logger.info("[remote_upload] Plan endpoint unavailable; falling back to full bundle upload") + return None + response.raise_for_status() + body = response.json() + if not body.get("success", False): + logger.warning("[remote_upload] Plan request failed; falling back: %s", body.get("error")) + return None + return body + except Exception as e: + logger.warning("[remote_upload] Plan request failed; falling back to full bundle upload: %s", e) + return None + + def _build_apply_only_payload(self, changes: Dict[str, List], plan: Dict[str, Any]) -> Dict[str, Any]: + payload = self._last_plan_payload or self._build_plan_payload(changes) + needed = plan.get("needed_files", {}) if isinstance(plan, dict) else {} + created_needed = set(needed.get("created", []) or []) + updated_needed = set(needed.get("updated", []) or []) + moved_needed = set(needed.get("moved", []) or []) + + # Check if ALL operations are hash-matched (nothing needs content at all) + # This happens when all needed_files lists 
are empty and there are no actual changes requiring content + has_changes_needing_content = bool(created_needed or updated_needed or moved_needed) + has_deletes = bool(changes.get("deleted", [])) + + # Only skip apply-only if there are NO operations needing content AND NO deletes + if not has_changes_needing_content and not has_deletes: + return { + "manifest": payload.get("manifest", {}), + "operations": [], + "file_hashes": {}, + } + + filtered_ops: List[Dict[str, Any]] = [] + filtered_hashes: Dict[str, str] = {} + for operation in payload.get("operations", []): + op_type = str(operation.get("operation") or "") + rel_path = str(operation.get("path") or "") + # Determine if this operation needs content (only those skip filtered_hashes) + needs_content = ( + (op_type == "created" and rel_path in created_needed) + or (op_type == "updated" and rel_path in updated_needed) + or (op_type == "moved" and rel_path in moved_needed) + ) + if needs_content: + # Skip operations that need content - they'll be uploaded separately + continue + # IMPORTANT: server-side apply_delta_operations() only accepts "deleted" and "moved" + # operations. Hash-matched "created" and "updated" operations must NOT be routed + # through apply_ops since the server will reject them. 
+ if op_type not in {"deleted", "moved"}: + continue + # Preserve all other operations so server advances state + filtered_ops.append(operation) + # Include hash for non-deleted operations + if op_type != "deleted": + hash_value = payload.get("file_hashes", {}).get(rel_path) + if hash_value: + filtered_hashes[rel_path] = hash_value + return { + "manifest": payload.get("manifest", {}), + "operations": filtered_ops, + "file_hashes": filtered_hashes, + } + + def _apply_operations_without_content(self, changes: Dict[str, List], plan: Dict[str, Any]) -> Optional[bool]: + payload = self._build_apply_only_payload(changes, plan) + operations = payload.get("operations", []) + if not operations: + return None + try: + data = { + "workspace_path": self._translate_to_container_path(self.workspace_path), + "source_path": self.workspace_path, + "logical_repo_id": _compute_logical_repo_id(self.workspace_path), + "manifest": payload["manifest"], + "operations": operations, + "file_hashes": payload["file_hashes"], + } + sess = get_auth_session(self.upload_endpoint) + if sess: + data["session"] = sess + if getattr(self, "logical_repo_id", None): + data["logical_repo_id"] = self.logical_repo_id + + logger.info( + "[remote_upload] Applying metadata-only operations without bundle: deleted=%s moved=%s", + sum(1 for op in operations if op.get("operation") == "deleted"), + sum(1 for op in operations if op.get("operation") == "moved"), + ) + response = self.session.post( + f"{self.upload_endpoint}/api/v1/delta/apply_ops", + json=data, + timeout=min(self.timeout, 60), + ) + if response.status_code in {404, 405}: + logger.info("[remote_upload] apply_ops endpoint unavailable; falling back to bundle upload") + return None + response.raise_for_status() + body = response.json() + if not body.get("success", False): + logger.warning("[remote_upload] apply_ops failed; falling back to bundle upload: %s", body.get("error")) + return None + # Only finalize changes that were actually processed by the 
server + # apply_delta_operations only handles deleted/moved operations + processed_ops = body.get("processed_operations") or {} + applied_changes = { + "deleted": changes.get("deleted", []), + "moved": changes.get("moved", []), + "created": [], + "updated": [], + } + self._finalize_successful_changes(applied_changes) + self._set_last_upload_result( + "uploaded", + bundle_id=body.get("bundle_id"), + sequence_number=body.get("sequence_number"), + processed_operations=processed_ops, + ) + logger.info( + "[remote_upload] Metadata-only operations applied: %s", + processed_ops, + ) + return True + except Exception as e: + logger.warning("[remote_upload] apply_ops failed; falling back to bundle upload: %s", e) + return None + + def _filter_changes_by_plan(self, changes: Dict[str, List], plan: Dict[str, Any]) -> Dict[str, List]: + needed = plan.get("needed_files", {}) if isinstance(plan, dict) else {} + created_needed = set(needed.get("created", []) or []) + updated_needed = set(needed.get("updated", []) or []) + moved_needed = set(needed.get("moved", []) or []) + + filtered_created = [ + path for path in changes["created"] + if path.relative_to(Path(self.workspace_path)).as_posix() in created_needed + ] + filtered_updated = [ + path for path in changes["updated"] + if path.relative_to(Path(self.workspace_path)).as_posix() in updated_needed + ] + filtered_moved = [ + (source_path, dest_path) + for source_path, dest_path in changes["moved"] + if dest_path.relative_to(Path(self.workspace_path)).as_posix() in moved_needed + ] + return { + "created": filtered_created, + "updated": filtered_updated, + "deleted": list(changes["deleted"]), + "moved": filtered_moved, + "unchanged": [], + } + def upload_bundle(self, bundle_path: str, manifest: Dict[str, Any]) -> Dict[str, Any]: """ Upload delta bundle to remote server with exponential backoff retry. 
@@ -942,13 +1574,11 @@ def upload_bundle(self, bundle_path: str, manifest: Dict[str, Any]) -> Dict[str, } data = { "workspace_path": self._translate_to_container_path(self.workspace_path), - "collection_name": self.collection_name, "sequence_number": manifest.get("sequence_number"), "force": False, "source_path": self.workspace_path, "logical_repo_id": _compute_logical_repo_id(self.workspace_path), } - sess = get_auth_session(self.upload_endpoint) if sess: data["session"] = sess @@ -1156,7 +1786,16 @@ def get_server_status(self) -> Dict[str, Any]: ) if response.status_code == 200: - return response.json() + payload = response.json() + if not isinstance(payload, dict): + return { + "success": False, + "error": { + "code": "STATUS_INVALID", + "message": "Invalid status response payload", + }, + } + return {"success": True, **payload} # Handle error response error_msg = f"Status check failed with HTTP {response.status_code}" @@ -1180,6 +1819,93 @@ def has_meaningful_changes(self, changes: Dict[str, List]) -> bool: total_changes = sum(len(files) for op, files in changes.items() if op != "unchanged") return total_changes > 0 + def _collect_force_cleanup_paths(self) -> List[Path]: + """ + Return ignored paths that force mode should actively delete remotely. + + In dev-remote mode, dev-workspace is intentionally ignored during upload + scans to avoid recursive dogfooding. If that tree already exists on the + remote side from an older buggy upload, force mode should remove it even + when the local cache does not contain those paths. 
+ """ + cleanup_paths: List[Path] = [] + if "dev-workspace" not in self._excluded_dirnames(): + return cleanup_paths + + dev_root = Path(self.workspace_path) / "dev-workspace" + if not dev_root.exists(): + return cleanup_paths + + for root, dirnames, filenames in os.walk(dev_root): + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + for filename in filenames: + path = Path(root) / filename + try: + if path.is_file(): + cleanup_paths.append(path) + except Exception: + continue + return cleanup_paths + + def build_force_changes(self, all_files: List[Path]) -> Dict[str, List]: + """ + Build force-upload changes while still cleaning stale cached paths. + + Force mode should re-upload every currently managed file, but it must also + emit deletes for files that only exist in the local cache now, including + paths that are ignored under the current client policy such as + dev-workspace in dev-remote mode. + """ + created_files: List[Path] = [] + path_map: Dict[Path, Path] = {} + for path in all_files: + if self._is_ignored_path(path): + continue + try: + resolved = path.resolve() + except Exception: + continue + created_files.append(path) + path_map[resolved] = path + + for cached_abs in self._get_all_cached_paths(): + try: + cached_path = Path(cached_abs) + resolved = cached_path.resolve() + except Exception: + continue + if resolved not in path_map: + path_map[resolved] = cached_path + + force_cleanup_paths = self._collect_force_cleanup_paths() + for cleanup_path in force_cleanup_paths: + try: + resolved = cleanup_path.resolve() + except Exception: + continue + if resolved not in path_map: + path_map[resolved] = cleanup_path + + probed = self.detect_file_changes(list(path_map.values())) + deleted_by_resolved: Dict[Path, Path] = {} + for deleted_path in probed.get("deleted", []): + try: + deleted_by_resolved[deleted_path.resolve()] = deleted_path + except Exception: + continue + for cleanup_path in force_cleanup_paths: + try: + 
deleted_by_resolved.setdefault(cleanup_path.resolve(), cleanup_path) + except Exception: + continue + return { + "created": created_files, + "updated": [], + "deleted": list(deleted_by_resolved.values()), + "moved": [], + "unchanged": [], + } + def upload_git_history_only(self, git_history: Dict[str, Any]) -> bool: try: empty_changes = { @@ -1224,10 +1950,13 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: # Validate input if not changes: logger.info("[remote_upload] No changes provided") + self._set_last_upload_result("no_changes") return True + if not self.has_meaningful_changes(changes): logger.info("[remote_upload] No meaningful changes detected, skipping upload") + self._set_last_upload_result("no_changes") return True # Log change summary @@ -1236,10 +1965,44 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: f"{len(changes['created'])} created, {len(changes['updated'])} updated, " f"{len(changes['deleted'])} deleted, {len(changes['moved'])} moved") + planned_changes = changes + plan = self._plan_delta_upload(changes) + if plan: + preview = plan.get("operation_counts_preview", {}) + logger.info( + "[remote_upload] Plan preview: needed created=%s updated=%s deleted=%s moved=%s " + "skipped_hash_match=%s needed_bytes=%s", + preview.get("created", 0), + preview.get("updated", 0), + preview.get("deleted", 0), + preview.get("moved", 0), + preview.get("skipped_hash_match", 0), + plan.get("needed_size_bytes", 0), + ) + planned_changes = self._filter_changes_by_plan(changes, plan) + has_content_work = bool( + planned_changes.get("created") + or planned_changes.get("updated") + or planned_changes.get("moved") + ) + if not has_content_work: + apply_only_result = self._apply_operations_without_content(changes, plan) + if apply_only_result is True: + return True + if not self.has_meaningful_changes(planned_changes): + logger.info("[remote_upload] Plan found no upload work; skipping bundle upload") + 
self._finalize_successful_changes(changes) + self._set_last_upload_result( + "skipped_by_plan", + plan_preview=preview, + needed_size_bytes=plan.get("needed_size_bytes", 0), + ) + return True + # Create delta bundle bundle_path = None try: - bundle_path, manifest = self.create_delta_bundle(changes) + bundle_path, manifest = self.create_delta_bundle(planned_changes) logger.info(f"[remote_upload] Created delta bundle: {manifest['bundle_id']} " f"(size: {manifest['total_size_bytes']} bytes)") @@ -1251,6 +2014,7 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: logger.error(f"[remote_upload] Error creating delta bundle: {e}") # Clean up any temporary files on failure self.cleanup() + self._set_last_upload_result("failed", stage="bundle_creation", error=str(e)) return False # Upload bundle with retry logic @@ -1258,9 +2022,84 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: response = self.upload_bundle(bundle_path, manifest) if response.get("success", False): - processed_ops = response.get('processed_operations', {}) - logger.info(f"[remote_upload] Successfully uploaded bundle {manifest['bundle_id']}") - logger.info(f"[remote_upload] Processed operations: {processed_ops}") + async_failed = False + async_pending = False + processed_ops = response.get("processed_operations") + if processed_ops is None: + logger.info( + "[remote_upload] Bundle %s accepted by server; processing asynchronously (sequence=%s)", + manifest["bundle_id"], + response.get("sequence_number"), + ) + self._set_last_upload_result( + "queued", + bundle_id=manifest["bundle_id"], + sequence_number=response.get("sequence_number"), + ) + async_result = self._await_async_upload_result( + manifest["bundle_id"], + response.get("sequence_number"), + ) + if async_result is None: + # Server accepted the bundle but status is still pending. 
+ async_pending = True + logger.warning( + "[remote_upload] Async upload timed out awaiting server response for bundle %s", + manifest["bundle_id"], + ) + else: + self.last_upload_result = async_result + outcome = str(async_result.get("outcome") or "") + if outcome == "uploaded_async": + self._finalize_successful_changes(planned_changes) + logger.info( + "[remote_upload] Async processing completed for bundle %s: %s", + manifest["bundle_id"], + async_result.get("processed_operations") or {}, + ) + elif outcome == "failed": + async_failed = True + logger.error( + "[remote_upload] Async processing failed for bundle %s: %s", + manifest["bundle_id"], + async_result.get("error"), + ) + self._set_last_upload_result( + "failed", + stage="async_processing", + bundle_id=async_result.get("bundle_id") or manifest["bundle_id"], + sequence_number=async_result.get("sequence_number") or response.get("sequence_number"), + error=async_result.get("error"), + ) + else: + async_pending = True + # Keep queued state for non-terminal async outcomes. 
+ self._set_last_upload_result( + "queued", + bundle_id=async_result.get("bundle_id") or manifest["bundle_id"], + sequence_number=async_result.get("sequence_number") or response.get("sequence_number"), + ) + logger.warning( + "[remote_upload] Async upload still pending for bundle %s (sequence=%s, outcome=%s)", + manifest["bundle_id"], + response.get("sequence_number"), + outcome or "", + ) + else: + logger.info(f"[remote_upload] Successfully uploaded bundle {manifest['bundle_id']}") + logger.info(f"[remote_upload] Processed operations: {processed_ops}") + self._finalize_successful_changes(planned_changes) + self._set_last_upload_result( + "uploaded", + bundle_id=manifest["bundle_id"], + sequence_number=response.get("sequence_number"), + processed_operations=processed_ops, + ) + if async_pending: + logger.info( + "[remote_upload] Bundle %s accepted and queued; deferring local finalization", + manifest["bundle_id"], + ) # Clean up temporary bundle after successful upload try: @@ -1272,20 +2111,24 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: except Exception as cleanup_error: logger.warning(f"[remote_upload] Failed to cleanup bundle {bundle_path}: {cleanup_error}") - return True + return not async_failed else: error_msg = response.get('error', {}).get('message', 'Unknown upload error') logger.error(f"[remote_upload] Upload failed: {error_msg}") + self._set_last_upload_result("failed", stage="upload", error=error_msg) return False except Exception as e: logger.error(f"[remote_upload] Error uploading bundle: {e}") + self._set_last_upload_result("failed", stage="upload", error=str(e)) return False except Exception as e: logger.error(f"[remote_upload] Unexpected error in process_changes_and_upload: {e}") + self._set_last_upload_result("failed", stage="unexpected", error=str(e)) return False + def get_all_code_files(self) -> List[Path]: """Get all code files in the workspace.""" files: List[Path] = [] @@ -1295,28 +2138,31 @@ def 
get_all_code_files(self) -> List[Path]: return files # Single walk with early pruning similar to standalone client - ext_suffixes = {str(ext).lower() for ext in idx.CODE_EXTS if str(ext).startswith('.')} - name_matches = {str(ext) for ext in idx.CODE_EXTS if not str(ext).startswith('.')} - dev_remote = os.environ.get("DEV_REMOTE_MODE") == "1" or os.environ.get("REMOTE_UPLOAD_MODE") == "development" - excluded = { - "node_modules", "vendor", "dist", "build", "target", "out", - ".git", ".hg", ".svn", ".vscode", ".idea", ".venv", "venv", - "__pycache__", ".pytest_cache", ".mypy_cache", ".cache", - ".context-engine", ".context-engine-uploader", ".codebase" - } - if dev_remote: - excluded.add("dev-workspace") + ext_suffixes = {str(ext).lower() for ext in CODE_EXTS if str(ext).startswith('.')} + try: + extensionless_names = {k.lower() for k in (EXTENSIONLESS_FILES or {}).keys()} + except Exception: + extensionless_names = set() + excluded = self._excluded_dirnames() seen = set() for root, dirnames, filenames in os.walk(workspace_path): dirnames[:] = [d for d in dirnames if d not in excluded and not d.startswith('.')] for filename in filenames: - if filename.startswith('.'): + # Allow dotfiles that are in EXTENSIONLESS_FILES (e.g., .gitignore) + fname_lower = filename.lower() + if filename.startswith('.') and fname_lower not in extensionless_names: continue candidate = Path(root) / filename + if self._is_ignored_path(candidate): + continue suffix = candidate.suffix.lower() - if filename in name_matches or suffix in ext_suffixes: + if ( + suffix in ext_suffixes + or fname_lower in extensionless_names + or fname_lower.startswith("dockerfile") + ): resolved = candidate.resolve() if resolved not in seen: seen.add(resolved) @@ -1368,13 +2214,13 @@ def on_any_event(self, event): # Always check src_path src_path = Path(event.src_path) - if idx.CODE_EXTS.get(src_path.suffix.lower(), "unknown") != "unknown": + if self.client._is_watchable_path(src_path): 
paths_to_process.append(src_path) # For FileMovedEvent, also process the destination path if hasattr(event, 'dest_path') and event.dest_path: dest_path = Path(event.dest_path) - if idx.CODE_EXTS.get(dest_path.suffix.lower(), "unknown") != "unknown": + if self.client._is_watchable_path(dest_path): paths_to_process.append(dest_path) if not paths_to_process: @@ -1395,6 +2241,8 @@ def on_any_event(self, event): def _process_pending_changes(self): """Process accumulated changes after debounce period.""" with self._lock: + # Timer fired; allow a new debounce to be armed while we process. + self._debounce_timer = None # Prevent re-entrancy if self._processing: return @@ -1406,19 +2254,21 @@ def _process_pending_changes(self): check_deletions = self._check_for_deletions self._check_for_deletions = False + upload_succeeded = False try: # Only include cached paths when deletion-related events occurred if check_deletions: cached_file_hashes = _load_local_cache_file_hashes( self.client.workspace_path, - self.client.repo_name + self.client.repo_name, + metadata_root=self.client.metadata_root, ) - all_paths = list(set(pending + [ - Path(p) for p in cached_file_hashes.keys() - ])) + cached_paths = [Path(p) for p in cached_file_hashes.keys()] + all_paths = list(set(pending + cached_paths)) else: all_paths = pending + changes = self.client.detect_file_changes(all_paths) meaningful_changes = ( len(changes.get("created", [])) + @@ -1431,7 +2281,8 @@ def _process_pending_changes(self): logger.info(f"[watch] Detected {meaningful_changes} changes: { {k: len(v) for k, v in changes.items() if k != 'unchanged'} }") success = self.client.process_changes_and_upload(changes) if success: - logger.info("[watch] Successfully uploaded changes") + self.client.log_watch_upload_result() + upload_succeeded = True else: logger.error("[watch] Failed to upload changes") else: @@ -1447,14 +2298,29 @@ def _process_pending_changes(self): success = self.client.upload_git_history_only(git_history) if 
success: logger.info("[watch] Successfully uploaded git history metadata") + upload_succeeded = True else: logger.error("[watch] Failed to upload git history metadata") + else: + upload_succeeded = True # No changes to process except Exception as e: logger.error(f"[watch] Error processing changes: {e}") finally: # Clear processing flag even if an error occurred with self._lock: self._processing = False + # Re-queue pending paths if upload failed + if not upload_succeeded and pending: + # Merge pending paths back into _pending_paths + for p in pending: + self._pending_paths.add(p) + # Arm next pass if there are pending paths + if self._pending_paths and self._debounce_timer is None: + self._debounce_timer = threading.Timer( + self.debounce_seconds, + self._process_pending_changes, + ) + self._debounce_timer.start() observer = Observer() @@ -1504,7 +2370,11 @@ def _watch_loop_polling(self, interval: int = 5): path_map[resolved] = p # Include any paths that are only present in the local cache (deleted files) - cached_file_hashes = _load_local_cache_file_hashes(self.workspace_path, self.repo_name) + cached_file_hashes = _load_local_cache_file_hashes( + self.workspace_path, + self.repo_name, + metadata_root=self.metadata_root, + ) for cached_abs in cached_file_hashes.keys(): try: cached_path = Path(cached_abs) @@ -1526,7 +2396,7 @@ def _watch_loop_polling(self, interval: int = 5): success = self.process_changes_and_upload(changes) if success: - logger.info(f"[watch] Successfully uploaded changes") + self.log_watch_upload_result() else: logger.error(f"[watch] Failed to upload changes") else: @@ -1584,80 +2454,7 @@ def process_and_upload_changes(self, changed_paths: List[Path]) -> bool: except Exception as e: logger.error(f"[remote_upload] Error detecting file changes: {e}") return False - - if not self.has_meaningful_changes(changes): - logger.info("[remote_upload] No meaningful changes detected, skipping upload") - return True - - # Log change summary - total_changes = 
sum(len(files) for op, files in changes.items() if op != "unchanged") - logger.info(f"[remote_upload] Detected {total_changes} meaningful changes: " - f"{len(changes['created'])} created, {len(changes['updated'])} updated, " - f"{len(changes['deleted'])} deleted, {len(changes['moved'])} moved") - - # Create delta bundle - bundle_path = None - try: - bundle_path, manifest = self.create_delta_bundle(changes) - logger.info(f"[remote_upload] Created delta bundle: {manifest['bundle_id']} " - f"(size: {manifest['total_size_bytes']} bytes)") - - # Validate bundle was created successfully - if not bundle_path or not os.path.exists(bundle_path): - raise RuntimeError(f"Failed to create bundle at {bundle_path}") - - except Exception as e: - logger.error(f"[remote_upload] Error creating delta bundle: {e}") - # Clean up any temporary files on failure - self.cleanup() - return False - - # Upload bundle with retry logic - try: - response = self.upload_bundle(bundle_path, manifest) - - if response.get("success", False): - processed_ops = response.get('processed_operations', {}) - logger.info(f"[remote_upload] Successfully uploaded bundle {manifest['bundle_id']}") - logger.info(f"[remote_upload] Processed operations: {processed_ops}") - - # Clean up temporary bundle after successful upload - try: - if os.path.exists(bundle_path): - os.remove(bundle_path) - logger.debug(f"[remote_upload] Cleaned up temporary bundle: {bundle_path}") - # Also clean up the entire temp directory if this is the last bundle - self.cleanup() - except Exception as cleanup_error: - logger.warning(f"[remote_upload] Failed to cleanup bundle {bundle_path}: {cleanup_error}") - - return True - else: - error = response.get("error", {}) - error_code = error.get("code", "UNKNOWN") - error_msg = error.get("message", "Unknown error") - - logger.error(f"[remote_upload] Upload failed: {error_msg}") - - # Handle specific error types - # CLI is stateless - server handles sequence management - if error_code in 
["BUNDLE_TOO_LARGE", "BUNDLE_NOT_FOUND"]: - # These are unrecoverable errors - logger.error(f"[remote_upload] Unrecoverable error ({error_code}): {error_msg}") - return False - elif error_code in ["TIMEOUT_ERROR", "CONNECTION_ERROR", "NETWORK_ERROR"]: - # These might be temporary, suggest fallback - logger.warning(f"[remote_upload] Network-related error ({error_code}): {error_msg}") - logger.warning("[remote_upload] Consider falling back to local mode if this persists") - return False - else: - # Other errors - logger.error(f"[remote_upload] Upload error ({error_code}): {error_msg}") - return False - - except Exception as e: - logger.error(f"[remote_upload] Unexpected error during upload: {e}") - return False + return self.process_changes_and_upload(changes) except Exception as e: logger.error(f"[remote_upload] Critical error in process_and_upload_changes: {e}") @@ -1688,7 +2485,7 @@ def _cleanup_dir_with_retries(path: Optional[str]) -> None: logger.debug(f"[remote_upload] Last cleanup error for {path}: {last_error}") -def get_remote_config(cli_path: Optional[str] = None) -> Dict[str, str]: +def get_remote_config(cli_path: Optional[str] = None) -> Dict[str, Any]: """Get remote upload configuration from environment variables and command-line arguments.""" # Use command-line path if provided, otherwise fall back to environment variables if cli_path: @@ -1698,17 +2495,10 @@ def get_remote_config(cli_path: Optional[str] = None) -> Dict[str, str]: logical_repo_id = _compute_logical_repo_id(workspace_path) - # Use auto-generated collection name based on repo name - repo_name = _extract_repo_name_from_path(workspace_path) - # Fallback to directory name if repo detection fails - if not repo_name: - repo_name = Path(workspace_path).name - collection_name = get_collection_name(repo_name) - return { "upload_endpoint": os.environ.get("REMOTE_UPLOAD_ENDPOINT", "http://localhost:8080"), "workspace_path": workspace_path, - "collection_name": collection_name, + "collection_name": 
None, "logical_repo_id": logical_repo_id, # Use higher, more robust defaults but still allow env overrides "max_retries": int(os.environ.get("REMOTE_UPLOAD_MAX_RETRIES", "5")), @@ -1816,7 +2606,7 @@ def main(): config["timeout"] = args.timeout logger.info(f"Workspace path: {config['workspace_path']}") - logger.info(f"Collection name: {config['collection_name']}") + logger.info(f"Collection name: {config['collection_name'] or ''}") logger.info(f"Upload endpoint: {config['upload_endpoint']}") if args.show_mapping: @@ -1850,15 +2640,8 @@ def main(): # Test server connection first logger.info("Checking server status...") status = client.get_server_status() - is_success = ( - isinstance(status, dict) and - 'workspace_path' in status and - 'collection_name' in status and - status.get('status') == 'ready' - ) - if not is_success: - error = status.get("error", {}) - logger.error(f"Cannot connect to server: {error.get('message', 'Unknown error')}") + if not _is_usable_delta_status(status): + logger.error("Cannot connect to server: %s", _server_status_error_message(status)) return 1 logger.info("Server connection successful") @@ -1894,16 +2677,8 @@ def main(): # Test server connection logger.info("Checking server status...") status = client.get_server_status() - # For delta endpoint, success is indicated by having expected fields (not a "success" boolean) - is_success = ( - isinstance(status, dict) and - 'workspace_path' in status and - 'collection_name' in status and - status.get('status') == 'ready' - ) - if not is_success: - error = status.get("error", {}) - logger.error(f"Cannot connect to server: {error.get('message', 'Unknown error')}") + if not _is_usable_delta_status(status): + logger.error("Cannot connect to server: %s", _server_status_error_message(status)) return 1 logger.info("Server connection successful") @@ -1912,14 +2687,7 @@ def main(): logger.info("Scanning repository for files...") workspace_path = Path(config['workspace_path']) - # Find all files in the 
repository - all_files = [] - for file_path in workspace_path.rglob('*'): - if file_path.is_file() and not file_path.name.startswith('.'): - rel_path = file_path.relative_to(workspace_path) - # Skip .codebase directory and other metadata - if not str(rel_path).startswith('.codebase'): - all_files.append(file_path) + all_files = client.get_all_code_files() logger.info(f"Found {len(all_files)} files to upload") @@ -1929,8 +2697,7 @@ def main(): # Detect changes (treat all files as changes for initial upload) if args.force: - # Force mode: treat all files as created - changes = {"created": all_files, "updated": [], "deleted": [], "moved": [], "unchanged": []} + changes = client.build_force_changes(all_files) else: changes = client.detect_file_changes(all_files) @@ -1945,8 +2712,19 @@ def main(): success = client.process_changes_and_upload(changes) if success: - logger.info("Repository upload completed successfully!") - logger.info(f"Collection name: {config['collection_name']}") + outcome = str((client.last_upload_result or {}).get("outcome") or "") + if outcome == "skipped_by_plan": + logger.info("No upload needed after plan") + elif outcome == "queued": + logger.info("Repository upload request accepted; server processing asynchronously") + elif outcome == "uploaded_async": + logger.info( + "Repository upload processed asynchronously: %s", + (client.last_upload_result or {}).get("processed_operations") or {}, + ) + else: + logger.info("Repository upload completed successfully!") + logger.info(f"Collection name: {config['collection_name'] or ''}") logger.info(f"Files uploaded: {len(all_files)}") else: logger.error("Repository upload failed!") diff --git a/scripts/rerank_ab_test.py b/scripts/rerank_ab_test.py deleted file mode 100644 index bf592065..00000000 --- a/scripts/rerank_ab_test.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python3 -"""Backward-compatibility shim. 
See scripts/rerank_tools/ab_test.py""" -import sys -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from scripts.rerank_tools.ab_test import * - -if __name__ == "__main__": - simulate_ab_test(n_sessions=100, n_queries_per_session=5) diff --git a/scripts/rerank_eval.py b/scripts/rerank_eval.py deleted file mode 100644 index ef810c82..00000000 --- a/scripts/rerank_eval.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python3 -"""Backward-compatibility shim. See scripts/rerank_tools/eval.py""" -import sys -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from scripts.rerank_tools.eval import * - -if __name__ == "__main__": - main() diff --git a/scripts/rerank_events.py b/scripts/rerank_events.py deleted file mode 100644 index f98916af..00000000 --- a/scripts/rerank_events.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python3 -"""Backward-compatibility shim. See scripts/rerank_tools/events.py""" -import sys -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from scripts.rerank_tools.events import * diff --git a/scripts/rerank_local.py b/scripts/rerank_local.py deleted file mode 100644 index 8a56e7d9..00000000 --- a/scripts/rerank_local.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python3 -"""Backward-compatibility shim. See scripts/rerank_tools/local.py""" -import sys -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from scripts.rerank_tools.local import * - -if __name__ == "__main__": - main() diff --git a/scripts/rerank_query.py b/scripts/rerank_query.py deleted file mode 100644 index f20936bf..00000000 --- a/scripts/rerank_query.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python3 -"""Backward-compatibility shim. 
See scripts/rerank_tools/query.py""" -import sys -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from scripts.rerank_tools.query import * - -if __name__ == "__main__": - main() diff --git a/scripts/rerank_real_benchmark.py b/scripts/rerank_real_benchmark.py deleted file mode 100644 index fd380df4..00000000 --- a/scripts/rerank_real_benchmark.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python3 -"""Backward-compatibility shim. See scripts/rerank_tools/benchmark.py""" -import sys -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from scripts.rerank_tools.benchmark import * - -if __name__ == "__main__": - run_real_benchmark() diff --git a/scripts/rerank_recursive.py b/scripts/rerank_recursive.py deleted file mode 100644 index 879cb53b..00000000 --- a/scripts/rerank_recursive.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python3 -""" -Backwards-compatibility shim for rerank_recursive. - -This file re-exports all symbols from the scripts.rerank_recursive package -to maintain backwards compatibility for existing imports. - -All new code should import from scripts.rerank_recursive (the package) directly. 
- -Usage (both work): - # Old style (still works) - from scripts.rerank_recursive import RecursiveReranker, TinyScorer - - # New style (preferred) - from scripts.rerank_recursive import RecursiveReranker, TinyScorer -""" -from __future__ import annotations - -# Re-export everything from the package -from scripts.rerank_recursive import ( - # State - RefinementState, - # Core classes - TinyScorer, - LatentRefiner, - VICReg, - LearnedProjection, - LearnedHybridWeights, - QueryExpander, - ConfidenceEstimator, - RecursiveReranker, - ONNXRecursiveReranker, - SessionAwareReranker, - # Utilities - _COMMON_TOKENS, - _split_identifier, - _normalize_token, - _tokenize_for_fname_boost, - _candidate_path_for_fname_boost, - _compute_fname_boost, - _cache_key, - _get_cached_embedding, - _cache_embedding, - # Functions - rerank_recursive, - rerank_recursive_inprocess, - rerank_with_learning, - rerank_with_session, - get_recursive_reranker, - _get_learning_reranker, - # Constants - HAS_ONNX, -) - -__all__ = [ - # State - "RefinementState", - # Core classes - "TinyScorer", - "LatentRefiner", - "VICReg", - "LearnedProjection", - "LearnedHybridWeights", - "QueryExpander", - "ConfidenceEstimator", - "RecursiveReranker", - "ONNXRecursiveReranker", - "SessionAwareReranker", - # Utilities - "_COMMON_TOKENS", - "_split_identifier", - "_normalize_token", - "_tokenize_for_fname_boost", - "_candidate_path_for_fname_boost", - "_compute_fname_boost", - "_cache_key", - "_get_cached_embedding", - "_cache_embedding", - # Functions - "rerank_recursive", - "rerank_recursive_inprocess", - "rerank_with_learning", - "rerank_with_session", - "get_recursive_reranker", - "_get_learning_reranker", - # Constants - "HAS_ONNX", -] diff --git a/scripts/rerank_recursive/__init__.py b/scripts/rerank_recursive/__init__.py index 23d96ac9..877e2ba6 100644 --- a/scripts/rerank_recursive/__init__.py +++ b/scripts/rerank_recursive/__init__.py @@ -1,115 +1,9 @@ -""" -Recursive Reranker Package - TRM-inspired iterative 
refinement for code search. - -This package provides modular components for recursive reranking: - -Core Components: -- TinyScorer: 2-layer MLP for scoring query-document pairs -- LatentRefiner: Refines latent state based on current results -- RecursiveReranker: Main reranking pipeline +"""Recursive reranking components. -Regularization: -- VICReg: Variance-Invariance-Covariance regularization +Import concrete modules directly, for example: -Learnable Components: -- LearnedProjection: Learnable embedding projection -- LearnedHybridWeights: Learns dense vs. lexical balance -- QueryExpander: Learns query expansions from usage - -Utilities: -- RefinementState: Dataclass for latent state -- ConfidenceEstimator: Early stopping logic + from scripts.rerank_recursive.recursive import RecursiveReranker + from scripts.rerank_recursive.utils import _compute_fname_boost """ -from __future__ import annotations - -# State dataclass -from scripts.rerank_recursive.state import RefinementState - -# Utilities -from scripts.rerank_recursive.utils import ( - _COMMON_TOKENS, - _split_identifier, - _normalize_token, - _tokenize_for_fname_boost, - _candidate_path_for_fname_boost, - _compute_fname_boost, - _cache_key, - _get_cached_embedding, - _cache_embedding, -) - -# Core scorer and refiner -from scripts.rerank_recursive.scorer import TinyScorer -from scripts.rerank_recursive.refiner import LatentRefiner - -# Regularization -from scripts.rerank_recursive.vicreg import VICReg - -# Learnable components -from scripts.rerank_recursive.projection import LearnedProjection -from scripts.rerank_recursive.hybrid_weights import LearnedHybridWeights -from scripts.rerank_recursive.expander import QueryExpander - -# Early stopping -from scripts.rerank_recursive.confidence import ConfidenceEstimator - -# Alpha scheduling -from scripts.rerank_recursive.alpha_scheduler import ( - CosineAlphaScheduler, - LearnedAlphaWeights, -) - -# Main rerankers and functions -from scripts.rerank_recursive.recursive 
import ( - RecursiveReranker, - ONNXRecursiveReranker, - FastEmbedRecursiveReranker, - SessionAwareReranker, - rerank_recursive, - rerank_recursive_inprocess, - rerank_with_learning, - rerank_with_session, - get_recursive_reranker, - _get_learning_reranker, - HAS_ONNX, - HAS_RERANKER_FACTORY, -) -__all__ = [ - # State - "RefinementState", - # Core classes - "TinyScorer", - "LatentRefiner", - "VICReg", - "LearnedProjection", - "LearnedHybridWeights", - "QueryExpander", - "ConfidenceEstimator", - "CosineAlphaScheduler", - "LearnedAlphaWeights", - "RecursiveReranker", - "ONNXRecursiveReranker", - "FastEmbedRecursiveReranker", - "SessionAwareReranker", - # Utilities - "_COMMON_TOKENS", - "_split_identifier", - "_normalize_token", - "_tokenize_for_fname_boost", - "_candidate_path_for_fname_boost", - "_compute_fname_boost", - "_cache_key", - "_get_cached_embedding", - "_cache_embedding", - # Functions - "rerank_recursive", - "rerank_recursive_inprocess", - "rerank_with_learning", - "rerank_with_session", - "get_recursive_reranker", - "_get_learning_reranker", - # Constants - "HAS_ONNX", - "HAS_RERANKER_FACTORY", -] +__all__ = [] diff --git a/scripts/rerank_recursive/recursive.py b/scripts/rerank_recursive/recursive.py index 71bbee7c..c17f41dc 100644 --- a/scripts/rerank_recursive/recursive.py +++ b/scripts/rerank_recursive/recursive.py @@ -9,43 +9,28 @@ import os import threading import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import numpy as np -# Safe ONNX imports -try: - import onnxruntime as ort - from tokenizers import Tokenizer +HAS_ONNX = False +HAS_RERANKER_FACTORY = True +_ONNX_RUNTIME: Optional[Tuple[Any, Any]] = None + + +def _get_onnx_runtime() -> Optional[Tuple[Any, Any]]: + global HAS_ONNX, _ONNX_RUNTIME + if _ONNX_RUNTIME is not None: + return _ONNX_RUNTIME + try: + import onnxruntime as ort + from tokenizers import Tokenizer + except ImportError: + HAS_ONNX = False + return None HAS_ONNX = 
True -except ImportError: - ort = None - Tokenizer = None - HAS_ONNX = False - -# Use centralized reranker factory (supports FastEmbed + ONNX backends) -try: - from scripts.reranker import ( - get_reranker_model as _get_reranker_model, - rerank_pairs as _rerank_pairs, - is_reranker_available as _is_reranker_available, - RERANKER_MODEL, - ) - HAS_RERANKER_FACTORY = True -except ImportError: - HAS_RERANKER_FACTORY = False - _get_reranker_model = None - _rerank_pairs = None - _is_reranker_available = None - RERANKER_MODEL = None - -# Legacy: direct FastEmbed imports (fallback when factory unavailable) -try: - from fastembed.rerank.cross_encoder import TextCrossEncoder - HAS_FASTEMBED_RERANK = True -except ImportError: - TextCrossEncoder = None - HAS_FASTEMBED_RERANK = False + _ONNX_RUNTIME = (ort, Tokenizer) + return _ONNX_RUNTIME from scripts.rerank_recursive.state import RefinementState from scripts.rerank_recursive.scorer import TinyScorer @@ -361,8 +346,10 @@ def __init__( def _get_onnx_session(self): if self._session is not None: return self._session, self._tokenizer - if not HAS_ONNX or not self.onnx_path or not self.tokenizer_path: + runtime = _get_onnx_runtime() + if runtime is None or not self.onnx_path or not self.tokenizer_path: return None, None + ort, Tokenizer = runtime with self._onnx_lock: if self._session is not None: return self._session, self._tokenizer @@ -532,22 +519,22 @@ def _get_model(self): """Get cached reranker model from factory.""" if self._reranker_model is not None: return self._reranker_model - if not HAS_RERANKER_FACTORY or _get_reranker_model is None: - return None with self._model_lock: if self._reranker_model is not None: return self._reranker_model - self._reranker_model = _get_reranker_model() + from scripts.reranker import get_reranker_model + self._reranker_model = get_reranker_model() return self._reranker_model def _factory_score(self, query: str, docs: List[str]) -> Optional[np.ndarray]: """Score documents using reranker 
factory.""" model = self._get_model() - if model is None or _rerank_pairs is None: + if model is None: return None try: + from scripts.reranker import rerank_pairs pairs = [(query, doc) for doc in docs] - scores = _rerank_pairs(pairs, model=model) + scores = rerank_pairs(pairs, model=model) return np.array(scores, dtype=np.float32) except Exception: return None @@ -845,34 +832,25 @@ def rerank_with_learning( if learn_from_onnx and candidates: teacher_scores = None if str(os.environ.get("RERANK_TEACHER_INLINE", "")).strip().lower() in {"1", "true", "yes", "on"}: + from scripts.rerank_tools.local import rerank_local + try: - from scripts.rerank_local import rerank_local - except ImportError: - try: - from rerank_local import rerank_local - except ImportError: - rerank_local = None - if rerank_local is not None: - try: - pairs = [] - for c in candidates: - doc = c.get("code") or c.get("snippet") or "" - if not doc: - parts = [] - if c.get("symbol"): - parts.append(str(c["symbol"])) - if c.get("path"): - parts.append(str(c["path"])) - doc = " ".join(parts) if parts else "empty" - pairs.append((query, doc[:1000])) - teacher_scores = rerank_local(pairs) - except Exception: - teacher_scores = None + pairs = [] + for c in candidates: + doc = c.get("code") or c.get("snippet") or "" + if not doc: + parts = [] + if c.get("symbol"): + parts.append(str(c["symbol"])) + if c.get("path"): + parts.append(str(c["path"])) + doc = " ".join(parts) if parts else "empty" + pairs.append((query, doc[:1000])) + teacher_scores = rerank_local(pairs) + except Exception: + teacher_scores = None try: - try: - from rerank_events import log_training_event - except ImportError: - from scripts.rerank_events import log_training_event + from scripts.rerank_tools.events import log_training_event log_training_event( query=query, candidates=candidates, @@ -899,14 +877,15 @@ def get_recursive_reranker(n_iterations: int = 3, **kwargs) -> RecursiveReranker Backwards compatible: existing ONNX configs 
continue to work. """ # Priority 1: Use factory if RERANKER_MODEL is set - if HAS_RERANKER_FACTORY and _is_reranker_available is not None: - if _is_reranker_available(): + if HAS_RERANKER_FACTORY: + from scripts.reranker import is_reranker_available + if is_reranker_available(): return FastEmbedRecursiveReranker(n_iterations=n_iterations, **kwargs) # Priority 2: Legacy ONNX path (backwards compatibility) onnx_path = os.environ.get("RERANKER_ONNX_PATH", "") tokenizer_path = os.environ.get("RERANKER_TOKENIZER_PATH", "") - if HAS_ONNX and onnx_path and tokenizer_path: + if onnx_path and tokenizer_path and _get_onnx_runtime() is not None: return ONNXRecursiveReranker(n_iterations=n_iterations, **kwargs) # Priority 3: Base reranker (no neural scoring) diff --git a/scripts/rerank_tools/ab_test.py b/scripts/rerank_tools/ab_test.py index b4c328e2..955aaf3b 100644 --- a/scripts/rerank_tools/ab_test.py +++ b/scripts/rerank_tools/ab_test.py @@ -10,7 +10,7 @@ Usage: # In your search pipeline: - from scripts.rerank_ab_test import ABTestManager, RerankerVariant + from scripts.rerank_tools.ab_test import ABTestManager, RerankerVariant ab = ABTestManager() variant = ab.get_variant(session_id="user_123") @@ -179,33 +179,17 @@ def baseline_rerank(query, candidates, **kwargs): VariantType.BASELINE, baseline_rerank ) - # Recursive reranker - try: - try: - from scripts.rerank_recursive import rerank_recursive - except ImportError: - from rerank_recursive import rerank_recursive + from scripts.rerank_recursive.recursive import rerank_recursive + from scripts.rerank_tools.local import rerank_in_process - self._variant_impls[VariantType.RECURSIVE] = RerankerVariant( - VariantType.RECURSIVE, - lambda q, c, **kw: rerank_recursive(q, c, n_iterations=3) - ) - except ImportError: - pass - - # ONNX reranker - try: - try: - from scripts.rerank_local import rerank_in_process - except ImportError: - from rerank_local import rerank_in_process - - self._variant_impls[VariantType.ONNX] = 
RerankerVariant( - VariantType.ONNX, - lambda q, c, **kw: rerank_in_process(q, c, limit=len(c)) - ) - except ImportError: - pass + self._variant_impls[VariantType.RECURSIVE] = RerankerVariant( + VariantType.RECURSIVE, + lambda q, c, **kw: rerank_recursive(q, c, n_iterations=3) + ) + self._variant_impls[VariantType.ONNX] = RerankerVariant( + VariantType.ONNX, + lambda q, c, **kw: rerank_in_process(q, c, limit=len(c)) + ) def _hash_to_bucket(self, session_id: str) -> float: """Hash session ID to a value in [0, 1) for consistent bucketing.""" diff --git a/scripts/rerank_tools/benchmark.py b/scripts/rerank_tools/benchmark.py index 40bbe973..be4bd527 100644 --- a/scripts/rerank_tools/benchmark.py +++ b/scripts/rerank_tools/benchmark.py @@ -9,21 +9,16 @@ 4. Uses ground truth from ONNX reranker as reference Usage: - python scripts/rerank_real_benchmark.py + python -m scripts.rerank_tools.benchmark """ import os -import sys import time import json import numpy as np -from pathlib import Path from typing import List, Dict, Any, Optional from dataclasses import dataclass, field -# Add project root to path -sys.path.insert(0, str(Path(__file__).parent.parent)) - # Real queries based on actual Context Engine functionality REAL_QUERIES = [ "hybrid search RRF fusion implementation", @@ -118,10 +113,7 @@ def benchmark_baseline(query: str, candidates: List[Dict[str, Any]]) -> RealBenc def benchmark_recursive(query: str, candidates: List[Dict[str, Any]], n_iters: int = 3) -> RealBenchmarkResult: """Benchmark recursive reranker.""" - try: - from scripts.rerank_recursive import RecursiveReranker - except ImportError: - from rerank_recursive import RecursiveReranker + from scripts.rerank_recursive.recursive import RecursiveReranker reranker = RecursiveReranker(n_iterations=n_iters, dim=256) initial_scores = [c.get("score", 0) for c in candidates] @@ -143,10 +135,7 @@ def benchmark_recursive(query: str, candidates: List[Dict[str, Any]], n_iters: i def benchmark_onnx(query: str, 
candidates: List[Dict[str, Any]]) -> Optional[RealBenchmarkResult]: """Benchmark ONNX cross-encoder reranker on pre-fetched candidates.""" try: - try: - from scripts.rerank_local import rerank_local - except ImportError: - from rerank_local import rerank_local + from scripts.rerank_tools.local import rerank_local # Prepare pairs for ONNX reranker pairs = [] @@ -178,10 +167,7 @@ def benchmark_onnx(query: str, candidates: List[Dict[str, Any]]) -> Optional[Rea def benchmark_session_aware(query: str, candidates: List[Dict[str, Any]], session_id: str) -> RealBenchmarkResult: """Benchmark session-aware recursive reranker.""" - try: - from scripts.rerank_recursive import SessionAwareReranker - except ImportError: - from rerank_recursive import SessionAwareReranker + from scripts.rerank_recursive.recursive import SessionAwareReranker reranker = SessionAwareReranker(n_iterations=3, dim=256) initial_scores = [c.get("score", 0) for c in candidates] @@ -208,10 +194,7 @@ def get_learning_reranker(): """Get or create the learning-enabled reranker.""" global _LEARNING_RERANKER if _LEARNING_RERANKER is None: - try: - from scripts.rerank_recursive import RecursiveReranker - except ImportError: - from rerank_recursive import RecursiveReranker + from scripts.rerank_recursive.recursive import RecursiveReranker _LEARNING_RERANKER = RecursiveReranker(n_iterations=3, dim=256) return _LEARNING_RERANKER @@ -361,10 +344,7 @@ def run_real_benchmark(): print(f" ONNX: {onnx_result.latency_ms:.2f}ms") teacher_scores = onnx_result.top_5_scores # Use full scores # Get full ONNX scores for learning - try: - from scripts.rerank_local import rerank_local - except ImportError: - from rerank_local import rerank_local + from scripts.rerank_tools.local import rerank_local pairs = [(query, c.get("code", "") or c.get("snippet", "")) for c in candidates] teacher_scores = rerank_local(pairs) @@ -502,4 +482,3 @@ def run_real_benchmark(): if __name__ == "__main__": run_real_benchmark() - diff --git 
a/scripts/rerank_tools/eval.py b/scripts/rerank_tools/eval.py index 1676b967..f64fe1bf 100644 --- a/scripts/rerank_tools/eval.py +++ b/scripts/rerank_tools/eval.py @@ -6,8 +6,8 @@ Designed for CI/regression testing - deterministic, no sampling. Usage: - python scripts/rerank_eval.py [--queries QUERIES_FILE] [--output OUTPUT_FILE] - python scripts/rerank_eval.py --ablations # Run all ablation modes + python scripts/rerank_tools/eval.py [--queries QUERIES_FILE] [--output OUTPUT_FILE] + python scripts/rerank_tools/eval.py --ablations # Run all ablation modes Metrics reported: - MRR@k (Mean Reciprocal Rank) @@ -19,7 +19,6 @@ import copy import json import os -import sys import time from dataclasses import dataclass, field, asdict from pathlib import Path @@ -27,9 +26,6 @@ import numpy as np -# Add project root to path -sys.path.insert(0, str(Path(__file__).parent.parent)) - # Fixed evaluation queries (deterministic, no sampling) DEFAULT_EVAL_QUERIES = [ "hybrid search RRF fusion implementation", @@ -108,7 +104,7 @@ def get_candidates(query: str, limit: int = 30) -> List[Dict[str, Any]]: def get_onnx_scores(query: str, candidates: List[Dict[str, Any]]) -> Optional[List[float]]: """Get ONNX reranker scores (ground truth).""" try: - from scripts.rerank_local import rerank_local + from scripts.rerank_tools.local import rerank_local pairs = [] for c in candidates: doc_parts = [] @@ -138,7 +134,7 @@ def rerank_recursive( ) -> List[Dict[str, Any]]: """Recursive reranker (no learning).""" try: - from scripts.rerank_recursive import RecursiveReranker + from scripts.rerank_recursive.recursive import RecursiveReranker reranker = RecursiveReranker(n_iterations=n_iterations, dim=256) initial_scores = [c.get("score", 0) for c in candidates] return reranker.rerank(query, candidates, initial_scores) @@ -154,7 +150,7 @@ def rerank_learning( ) -> List[Dict[str, Any]]: """Learning reranker (uses trained weights).""" try: - from scripts.rerank_recursive import rerank_with_learning + from 
scripts.rerank_recursive.recursive import rerank_with_learning return rerank_with_learning( query=query, candidates=candidates, @@ -406,4 +402,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/scripts/rerank_tools/local.py b/scripts/rerank_tools/local.py index e2151791..40cb54ac 100644 --- a/scripts/rerank_tools/local.py +++ b/scripts/rerank_tools/local.py @@ -1,43 +1,27 @@ #!/usr/bin/env python3 import os import argparse -import sys import threading -from pathlib import Path as _P from typing import List, Dict, Any, TYPE_CHECKING -# Ensure project root is on sys.path when run as a script (so 'scripts' package imports work) -_ROOT = _P(__file__).resolve().parent.parent.parent -if str(_ROOT) not in sys.path: - sys.path.insert(0, str(_ROOT)) - from qdrant_client import QdrantClient, models # Import TextEmbedding for type hints (may not be available at runtime with embedder factory) if TYPE_CHECKING: from fastembed import TextEmbedding -# Use embedder factory for Qwen3 support; fallback to direct fastembed -try: - from scripts.embedder import get_embedding_model as _get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False - from fastembed import TextEmbedding +from scripts.embedder import get_embedding_model as _get_embedding_model + +_EMBEDDER_FACTORY = True # Use centralized reranker factory (supports FastEmbed + ONNX backends) -try: - from scripts.reranker import ( - get_reranker_model as _get_reranker_model, - rerank_pairs as _rerank_pairs, - is_reranker_available as _is_reranker_available, - ) - _RERANKER_FACTORY = True -except ImportError: - _RERANKER_FACTORY = False - _get_reranker_model = None - _rerank_pairs = None - _is_reranker_available = None +from scripts.reranker import ( + get_reranker_model as _get_reranker_model, + rerank_pairs as _rerank_pairs, + is_reranker_available as _is_reranker_available, +) + +_RERANKER_FACTORY = True # Legacy ONNX imports (fallback when factory unavailable) try: @@ 
-134,6 +118,10 @@ def _get_rerank_session(): from scripts.utils import sanitize_vector_name as _sanitize_vector_name +from scripts.path_scope import ( + normalize_under as _normalize_under_scope, + metadata_matches_under as _metadata_matches_under, +) def warmup_reranker(): @@ -163,18 +151,14 @@ def _start_background_warmup(): _start_background_warmup() -def _norm_under(u: str | None) -> str | None: - if not u: - return None - u = str(u).strip().replace("\\", "/") - u = "/".join([p for p in u.split("/") if p]) - if not u: - return None - if not u.startswith("/"): - return "/work/" + u - if not u.startswith("/work/"): - return "/work/" + u.lstrip("/") - return u +def _point_matches_under(pt: Any, under: str | None) -> bool: + if not under: + return True + payload = getattr(pt, "payload", None) or {} + md = payload.get("metadata") or {} + if not isinstance(md, dict): + md = {} + return _metadata_matches_under(md, under) def _select_dense_vector_name( @@ -366,18 +350,21 @@ def rerank_in_process( key="metadata.language", match=models.MatchValue(value=language) ) ) - eff_under = _norm_under(under) - if eff_under: - must.append( - models.FieldCondition( - key="metadata.path_prefix", match=models.MatchValue(value=eff_under) - ) - ) + eff_under = _normalize_under_scope(under) flt = models.Filter(must=must) if must else None - pts = dense_results(client, _model, vec_name, query, flt, topk, eff_collection) - if not pts and flt is not None: - pts = dense_results(client, _model, vec_name, query, None, topk, eff_collection) + fetch_topk = max(1, int(topk)) + if eff_under: + try: + under_mult = int(os.environ.get("RERANK_UNDER_FETCH_MULT", "4") or 4) + except Exception: + under_mult = 4 + fetch_topk = max(fetch_topk, int(limit) * max(under_mult, 2), fetch_topk * max(under_mult, 2)) + fetch_topk = min(fetch_topk, 2000) + + pts = dense_results(client, _model, vec_name, query, flt, fetch_topk, eff_collection) + if eff_under and pts: + pts = [pt for pt in pts if 
_point_matches_under(pt, eff_under)] if not pts: return [] @@ -447,19 +434,21 @@ def main(): key="metadata.language", match=models.MatchValue(value=args.language) ) ) - eff_under = _norm_under(args.under) - if eff_under: - must.append( - models.FieldCondition( - key="metadata.path_prefix", match=models.MatchValue(value=eff_under) - ) - ) + eff_under = _normalize_under_scope(args.under) flt = models.Filter(must=must) if must else None - pts = dense_results(client, model, vec_name, args.query, flt, args.topk, eff_collection) - # Fallback: if filtered search yields nothing, retry without filters to avoid empty rerank - if not pts and flt is not None: - pts = dense_results(client, model, vec_name, args.query, None, args.topk, eff_collection) + fetch_topk = max(1, int(args.topk)) + if eff_under: + try: + under_mult = int(os.environ.get("RERANK_UNDER_FETCH_MULT", "4") or 4) + except Exception: + under_mult = 4 + fetch_topk = max(fetch_topk, int(args.limit) * max(under_mult, 2), fetch_topk * max(under_mult, 2)) + fetch_topk = min(fetch_topk, 2000) + + pts = dense_results(client, model, vec_name, args.query, flt, fetch_topk, eff_collection) + if eff_under and pts: + pts = [pt for pt in pts if _point_matches_under(pt, eff_under)] if not pts: return pairs = prepare_pairs(args.query, pts) diff --git a/scripts/rerank_tools/query.py b/scripts/rerank_tools/query.py index 9c3eeee9..9957c694 100644 --- a/scripts/rerank_tools/query.py +++ b/scripts/rerank_tools/query.py @@ -1,27 +1,17 @@ #!/usr/bin/env python3 import os import argparse -import sys from collections import defaultdict from typing import List, Dict, Any -from pathlib import Path from qdrant_client import QdrantClient, models import re -ROOT_DIR = Path(__file__).resolve().parents[2] -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - from scripts.utils import sanitize_vector_name -# Use embedder factory for Qwen3 support; fallback to direct fastembed -try: - from scripts.embedder import 
get_embedding_model as _get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False - from fastembed import TextEmbedding +from scripts.embedder import get_embedding_model as _get_embedding_model + +_EMBEDDER_FACTORY = True # Env configuration diff --git a/scripts/rerank_tools/train.py b/scripts/rerank_tools/train.py index f3bafaf3..71769636 100644 --- a/scripts/rerank_tools/train.py +++ b/scripts/rerank_tools/train.py @@ -10,13 +10,13 @@ Usage: # Generate synthetic training data - python scripts/rerank_train.py --generate-data --output data/rerank_train.jsonl + python scripts/rerank_tools/train.py --generate-data --output data/rerank_train.jsonl # Train the model - python scripts/rerank_train.py --train --data data/rerank_train.jsonl --epochs 100 + python scripts/rerank_tools/train.py --train --data data/rerank_train.jsonl --epochs 100 # Evaluate - python scripts/rerank_train.py --evaluate --data data/rerank_test.jsonl + python scripts/rerank_tools/train.py --evaluate --data data/rerank_test.jsonl """ import os diff --git a/scripts/rerank_train.py b/scripts/rerank_train.py deleted file mode 100644 index a9ef2079..00000000 --- a/scripts/rerank_train.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python3 -"""Backward-compatibility shim. 
See scripts/rerank_tools/train.py""" -import sys -from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from scripts.rerank_tools.train import * - -if __name__ == "__main__": - main() diff --git a/scripts/router_eval.py b/scripts/router_eval.py deleted file mode 100644 index d429a08c..00000000 --- a/scripts/router_eval.py +++ /dev/null @@ -1,385 +0,0 @@ -import argparse -import json, os, threading, time, sys, re, copy -from http.server import HTTPServer, BaseHTTPRequestHandler -from typing import Dict, Any, List, Tuple - -# Simple Mock MCP server for evals -class MockMCPHandler(BaseHTTPRequestHandler): - server_version = "MockMCP/0.1" - - def _send_json(self, obj: Dict[str, Any], session: str | None = None, code: int = 200): - body = json.dumps(obj).encode("utf-8") - self.send_response(code) - self.send_header("Content-Type", "application/json") - if session: - self.send_header("Mcp-Session-Id", session) - self.send_header("Content-Length", str(len(body))) - self.end_headers() - self.wfile.write(body) - - def do_POST(self): # noqa: N802 - raw = self.rfile.read(int(self.headers.get("Content-Length", "0") or 0)) - try: - j = json.loads(raw.decode("utf-8", errors="ignore")) - except Exception: - return self._send_json({"jsonrpc": "2.0", "error": {"message": "bad json"}}, code=400) - method = j.get("method") - if method == "initialize": - # Return session via header; some clients also parse body - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"ok": True, "server": self.server.server_name}}, session="mock-session") - if method == "notifications/initialized": - return self._send_json({"jsonrpc": "2.0", "result": {"ok": True}}) - if method == "tools/list": - # Simulate flakiness once if flagged - if getattr(self.server, "fail_list_once", False) and not getattr(self.server, "_fail_list_consumed", False): - setattr(self.server, "_fail_list_consumed", True) - return self._send_json({"jsonrpc": "2.0", "error": {"message": 
"flaky list"}}, code=500) - tools = getattr(self.server, "tools", []) - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"tools": tools}}) - if method == "tools/call": - params = j.get("params") or {} - name = (params.get("name") or "").strip() - args = params.get("arguments") or {} - # Indexer tools - if name in {"repo_search", "search_config_for", "search_tests_for", "search_callers_for", "search_importers_for"}: - total = int(getattr(self.server, "search_total", 5)) - # Cap returned items to avoid huge payloads; still report full total - shown = max(0, min(total, 3)) - results = [ - {"score": 0.9 - (i * 0.1), "path": f"/work/README_{i}.md", "start_line": 1, "end_line": 2, "snippet": "demo"} - for i in range(shown) - ] - res = { - "result": { - "args": { - "queries": [str(args.get("query") or "")], - "limit": int(args.get("limit") or 8), - "include_snippet": bool(args.get("include_snippet") or False), - "language": str(args.get("language") or ""), - "under": str(args.get("under") or ""), - "symbol": str(args.get("symbol") or ""), - "ext": str(args.get("ext") or ""), - "compact": False, - }, - "total": total, - "results": results, - "ok": True, - "code": 0, - "stdout": "", - "stderr": "", - } - } - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps(res)}], "structuredContent": res, "isError": False}}) - if name == "context_answer_compat": - # Simulate failure if flagged so router should fall back to context_answer - if getattr(self.server, "fail_context_compat", False): - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps({"error": "compat failed"})}], "structuredContent": {"error": "compat failed"}, "isError": True}}) - # Require nested arguments wrapper - if not isinstance(args, dict) or "arguments" not in args: - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": 
[{"type": "text", "text": json.dumps({"error": "compat requires nested arguments"})}], "structuredContent": {"error": "compat requires nested arguments"}, "isError": True}}) - inner = args.get("arguments") or {} - q = str(inner.get("query") or "") - ans = { - "answer": "Short ok." if len(q) < 80 else "Longer answer", - "citations": [{"id": 1, "path": "/work/file.py", "start_line": 1, "end_line": 2}], - "query": [q], - "used": {"gate_first": True, "refrag": True}, - } - res = {"result": ans} - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps(res)}], "structuredContent": res, "isError": False}}) - if name == "context_answer": - q = str(args.get("query") or "") - ans = {"answer": "Ok.", "citations": []} - res = {"result": ans} - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps(res)}], "structuredContent": res, "isError": False}}) - if name in {"qdrant_status", "qdrant_list"}: - res = {"result": {"ok": True}} - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps(res)}], "structuredContent": res, "isError": False}}) - # Memory tools - if name == "store": - res = {"result": {"ok": True}} - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps(res)}], "structuredContent": res, "isError": False}}) - if name == "find": - q = str(args.get("query") or "") - res = {"result": {"ok": True, "results": [{"information": "The MCP indexer uses hybrid search combining dense embeddings and lexical matching with optional reranking", "metadata": {"category": "architecture"}}], "count": 1}} - return self._send_json({"jsonrpc": "2.0", "id": j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps(res)}], "structuredContent": res, "isError": False}}) - return self._send_json({"jsonrpc": "2.0", "id": 
j.get("id"), "result": {"content": [{"type": "text", "text": json.dumps({"error": f"unknown tool {name}"})}], "structuredContent": {"error": f"unknown tool {name}"}, "isError": True}}) - return self._send_json({"jsonrpc": "2.0", "error": {"message": f"unknown method {method}"}}, code=400) - - -def start_mock_server(port: int, tools: List[Dict[str, Any]]) -> Tuple[HTTPServer, threading.Thread]: - httpd = HTTPServer(("localhost", port), MockMCPHandler) - httpd.tools = tools # type: ignore - t = threading.Thread(target=httpd.serve_forever, daemon=True) - t.start() - # Warmup - time.sleep(0.05) - return httpd, t - - -def tool(name: str, description: str, params: List[str] = None) -> Dict[str, Any]: - schema = {"type": "object", "properties": {p: {"type": "string"} for p in (params or [])}} - return {"name": name, "description": description, "inputSchema": schema} - - -def run_eval_suite(verbose: bool = False) -> int: - # Two mock servers: indexer and memory - indexer_tools = [ - tool("repo_search", "General code search", ["query", "limit", "include_snippet", "language", "under", "symbol", "ext"]), - tool("search_config_for", "Intent-specific search for configuration files", ["query", "limit", "include_snippet"]), - tool("search_importers_for", "Find files importing a module or symbol", ["query", "limit", "language", "under"]), - tool("context_answer_compat", "Answer a question using code context (compat)", ["arguments"]), - tool("context_answer", "Answer a question using code context", ["query", "limit"]), - tool("qdrant_status", "Qdrant status"), - tool("qdrant_list", "Qdrant list"), - ] - memory_tools = [tool("store", "Store memory", ["information"]), tool("find", "Find memory", ["query", "limit"])] - idx, _ = start_mock_server(18031, indexer_tools) - mem, _ = start_mock_server(18032, memory_tools) - - try: - os.environ["MCP_INDEXER_HTTP_URL"] = "http://localhost:18031/mcp" - os.environ["MCP_MEMORY_HTTP_URL"] = "http://localhost:18032/mcp" - 
os.environ["ROUTER_SEARCH_LIMIT"] = "8" - os.environ["ROUTER_INCLUDE_SNIPPET"] = "1" - - # Import router after env set so its defaults bind to mock URLs - import importlib.util as _ilu - _p = os.path.join(os.path.dirname(__file__), "mcp_router.py") - _spec = _ilu.spec_from_file_location("mcp_router", _p) - router = _ilu.module_from_spec(_spec) - assert _spec and _spec.loader - _spec.loader.exec_module(router) # type: ignore - - failures = [] - intent_logs: List[Dict[str, Any]] = [] - - def run_plan(q: str) -> List[Tuple[str, Dict[str, Any]]]: - plan = router.build_plan(q) - debug = getattr(router, "_LAST_INTENT_DEBUG", {}) - if isinstance(debug, dict): - log_entry = copy.deepcopy(debug) - else: - log_entry = {"debug": debug} - log_entry["query"] = q - log_entry["plan_first_tool"] = plan[0][0] if plan else None - intent_logs.append(log_entry) - return plan - - # 1) Signature selection: prefer search_config_for for config changes - p1 = run_plan("compare callers to config changes") - if not p1 or p1[0][0] != "search_config_for": - failures.append("signature selection: expected search_config_for") - - # 2) Repo hints: language+under parsed - p2 = run_plan("who imports foo in python under src/lib") - if not p2 or p2[0][0] != "search_importers_for": - failures.append("repo hints: expected search_importers_for") - else: - args2 = p2[0][1] - if args2.get("language") != "python": - failures.append("repo hints: language not parsed") - if args2.get("under") != "src/lib": - failures.append("repo hints: under not parsed") - - # 3) Design recap: memory find precedes answer - p3 = run_plan("recap our architecture decisions for the indexer") - expect_order = ["find", "context_answer_compat"] - if not p3 or [p3[0][0], p3[1][0]] != expect_order: - failures.append("design recap plan: expected find -> context_answer_compat") - - # 4) Multi-intent: store + reindex - p4 = run_plan("remember this: prefer concise answers; then reindex fresh") - if not p4 or [p4[0][0], p4[1][0]] != 
["store", "qdrant_index_root"]: - failures.append("multi-intent: expected store then index") - else: - store_args = p4[0][1] or {} - info4 = (store_args.get("information") or "").lower() - if "remember" in info4: - failures.append("multi-intent: trigger phrase leaked into stored information") - if "reindex" in info4: - failures.append("multi-intent: reindex fragment leaked into stored information") - if not p4[1][1].get("recreate"): - failures.append("multi-intent: expected recreate true") - - # 5) Memory metadata extraction - p_meta = run_plan("remember this [priority=high tags=ux,frontend]: update the signup banner copy") - if not p_meta or p_meta[0][0] != "store": - failures.append("memory metadata: expected store intent") - else: - store_args = p_meta[0][1] or {} - if store_args.get("information") != "update the signup banner copy": - failures.append("memory metadata: information not cleaned") - md = store_args.get("metadata") or {} - if md.get("priority") != "high": - failures.append("memory metadata: priority missing") - if md.get("tags") != ["ux", "frontend"]: - failures.append("memory metadata: tags mismatch") - - # 6) Glob/exclude filters - p5 = run_plan("search only *.py files exclude vendor") - if not p5: - failures.append("glob: plan empty") - else: - args5 = p5[0][1] - gl = (args5 or {}).get("path_glob") or [] - ng = (args5 or {}).get("not_glob") or [] - if "**/*.py" not in gl: - failures.append("glob: missing **/*.py") - if "**/vendor/**" not in ng: - failures.append("glob: missing exclude vendor") - - # 7) Run end-to-end for recap and ensure compat accepted and short answers not rejected - # Capture stdout of router.main - def run_router(args: List[str]) -> str: - from io import StringIO - old = sys.stdout - try: - buf = StringIO() - sys.stdout = buf - router.main(args) - return buf.getvalue() - finally: - sys.stdout = old - out = run_router(["--run", "recap our architecture decisions for the indexer"]) - def run_router_code(args: List[str]) -> int: 
- from io import StringIO - old = sys.stdout - try: - sys.stdout = StringIO() # suppress stdout capture to avoid noise - return int(router.main(args)) - finally: - sys.stdout = old - - if "compat requires nested arguments" in out: - failures.append("compat: still sending flattened args") - if "Memory context:" not in out: - failures.append("memory→answer: query was not augmented with memory context") - print("--- router stdout ---\n" + out + "\n--- end stdout ---") - - - # 6b) Repeat immediately after recap should skip fresh memory.find step - out_repeat = run_router(["--run", "repeat that"]) - if '"skipped": "scratchpad_fresh"' not in out_repeat: - failures.append("repeat: find step not skipped on fresh cache") - - # 7) Repeat last: persist then repeat - _ = run_router(["--run", "who imports foo in python under src/lib"]) - p7a = run_plan("who imports foo in python under src/lib") - p7b = run_plan("repeat that") - if p7a != p7b: - failures.append("repeat: last plan not reused") - - # 7b) "same filters" carry-over in planning - p7c = run_plan("search with same filters for bar baz") - if not p7c: - failures.append("same filters: plan empty") - else: - args7c = p7c[0][1] - if (args7c or {}).get("language") != "python" or (args7c or {}).get("under") != "src/lib": - failures.append("same filters: did not reuse prior language/under") - - - - # 8) Fallback on compat failure - setattr(idx, "fail_context_compat", True) - out2 = run_router(["--run", "recap our architecture decisions for the indexer"]) - if '"tool": "context_answer"' not in out2: - failures.append("fallback: did not call context_answer after compat failure") - setattr(idx, "fail_context_compat", False) - - # 9) tools/list flakiness toleration - setattr(idx, "fail_list_once", True) - p9 = run_plan("find config changes") - if not p9: - failures.append("discovery flakiness: plan empty after retry") - setattr(idx, "fail_list_once", False) - - - # 10) Expand on last summary uses prior summary and citations 
(fresh) - out3 = run_router(["--run", "expand on that summary"]) - if "Prior summary:" not in out3 or "/work/file.py" not in out3: - failures.append("expand: prior summary/citations not injected when fresh") - - # 11) TTL expiry should suppress prior summary injection - os.environ["ROUTER_SCRATCHPAD_TTL_SEC"] = "0" - out4 = run_router(["--run", "expand on that summary"]) - if "Prior summary:" in out4: - failures.append("ttl: prior summary injected despite stale cache") - os.environ.pop("ROUTER_SCRATCHPAD_TTL_SEC", None) - - # 13) Divergence fatal per-tool: repo_search set to fatal should cause nonzero exit - os.environ["ROUTER_DIVERGENCE_FATAL_TOOLS"] = "repo_search" - setattr(idx, "search_total", 6) - _ = run_router(["--run", "search for demo"]) - setattr(idx, "search_total", 2) - code_div = run_router_code(["--run", "search for demo"]) - if code_div == 0: - failures.append("divergence fatal: router returned success despite fatal policy") - os.environ.pop("ROUTER_DIVERGENCE_FATAL_TOOLS", None) - setattr(idx, "search_total", 5) - - # 12) Divergence detection: baseline high → lower later should print a divergence notice - setattr(idx, "search_total", 6) - _ = run_router(["--run", "search for demo"]) - setattr(idx, "search_total", 2) - out_div = run_router(["--run", "search for demo"]) - if '"divergence"' not in out_div: - failures.append("divergence: no divergence flagged on material drop") - - fallback_logs = [] - for log in intent_logs: - if log.get("strategy") == "ml": - if "confidence" not in log: - failures.append(f"intent log missing confidence for query: {log.get('query')}") - if ( - log.get("intent") == router.INTENT_SEARCH - and log.get("top_candidate") - and log.get("top_candidate") != router.INTENT_SEARCH - ): - if log.get("confidence", 0.0) >= log.get("threshold", 0.25): - failures.append(f"intent fallback without low confidence for query: {log.get('query')}") - else: - fallback_logs.append(log) - if fallback_logs: - print("Intent fallback diagnostics:") 
- for item in fallback_logs: - try: - score = float(item.get("confidence") or 0.0) - except Exception: - score = 0.0 - print( - f" query={item.get('query')!r} top={item.get('top_candidate')} " - f"score={score:.3f} -> intent={item.get('intent')} first_tool={item.get('plan_first_tool')}" - ) - if verbose: - print("Intent diagnostics (all):") - for item in intent_logs: - try: - score = float(item.get("confidence") or 0.0) - except Exception: - score = 0.0 - print( - f" query={item.get('query')!r} strategy={item.get('strategy')} " - f"intent={item.get('intent')} score={score:.3f} " - f"top={item.get('top_candidate')} first_tool={item.get('plan_first_tool')}" - ) - - if failures: - print("Router eval: FAIL\n- " + "\n- ".join(failures)) - return 1 - print("Router eval: PASS (all checks)") - return 0 - finally: - idx.shutdown(); mem.shutdown() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run router evaluation suite.") - parser.add_argument( - "--verbose", - action="store_true", - help="Print intent confidence diagnostics after the suite completes.", - ) - args = parser.parse_args() - raise SystemExit(run_eval_suite(verbose=args.verbose)) diff --git a/scripts/run_init_maintenance.py b/scripts/run_init_maintenance.py new file mode 100644 index 00000000..b629b267 --- /dev/null +++ b/scripts/run_init_maintenance.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +"""Run the init maintenance script sequence under the shared watcher lock.""" + +from __future__ import annotations + +from scripts.watch_index_core.init_maintenance import run_init_maintenance_once + + +def main() -> int: + return 0 if run_init_maintenance_once() else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/semantic_expansion.py b/scripts/semantic_expansion.py index 76e472c1..2dd999fd 100644 --- a/scripts/semantic_expansion.py +++ b/scripts/semantic_expansion.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + """ 
Semantic similarity-based query expansion for Context-Engine. @@ -9,44 +11,34 @@ import os import math import re -from typing import List, Dict, Any, Tuple, Optional, Set +from typing import List, Dict, Any, Tuple, Optional, Set, TYPE_CHECKING from collections import defaultdict import logging logger = logging.getLogger("semantic_expansion") -# Import embedding functionality (prefer embedder factory for Qwen3 support) -try: - from scripts.embedder import get_embedding_model as _get_embedding_model - _EMBEDDER_FACTORY = True - FASTEMBED_AVAILABLE = True -except ImportError: - _EMBEDDER_FACTORY = False - try: - from fastembed import TextEmbedding - FASTEMBED_AVAILABLE = True - except ImportError: - FASTEMBED_AVAILABLE = False - TextEmbedding = None - -# Import Qdrant client for vector operations -try: - from qdrant_client import QdrantClient, models - QDRANT_AVAILABLE = True -except ImportError: - QDRANT_AVAILABLE = False - QdrantClient = None - models = None - -# Import local utilities -try: - from scripts.utils import ( - lex_hash_vector_queries as _lex_hash_vector_queries, - sanitize_vector_name as _sanitize_vector_name, - ) -except ImportError: - _lex_hash_vector_queries = None - _sanitize_vector_name = None +if TYPE_CHECKING: + from qdrant_client import QdrantClient, models as models +else: + QdrantClient = Any + + class _LazyQdrantModels: + def __getattr__(self, name: str) -> Any: + from qdrant_client import models as _models + + return getattr(_models, name) + + models = _LazyQdrantModels() + +from scripts.embedder import get_embedding_model as _get_embedding_model +from scripts.utils import ( + lex_hash_vector_queries as _lex_hash_vector_queries, + sanitize_vector_name as _sanitize_vector_name, +) + +_EMBEDDER_FACTORY = True +FASTEMBED_AVAILABLE = True +QDRANT_AVAILABLE = True # Configuration defaults # NOTE: SEMANTIC_EXPANSION_ENABLED is intentionally *not* a module-level constant. 
@@ -62,18 +54,15 @@ def _semantic_expansion_enabled() -> bool: SEMANTIC_EXPANSION_CACHE_TTL = float(os.environ.get("SEMANTIC_EXPANSION_CACHE_TTL", "3600") or "3600") # Use UnifiedCache for proper LRU eviction instead of simple FIFO -try: - from scripts.cache_manager import UnifiedCache, EvictionPolicy - _expansion_cache = UnifiedCache( - name="semantic_expansion", - max_size=SEMANTIC_EXPANSION_CACHE_SIZE, - eviction_policy=EvictionPolicy.LRU, - default_ttl=SEMANTIC_EXPANSION_CACHE_TTL, - ) - _UNIFIED_CACHE = True -except ImportError: - _expansion_cache: Dict[str, List[str]] = {} # type: ignore - _UNIFIED_CACHE = False +from scripts.cache_manager import UnifiedCache, EvictionPolicy + +_expansion_cache = UnifiedCache( + name="semantic_expansion", + max_size=SEMANTIC_EXPANSION_CACHE_SIZE, + eviction_policy=EvictionPolicy.LRU, + default_ttl=SEMANTIC_EXPANSION_CACHE_TTL, +) +_UNIFIED_CACHE = True _cache_hits = 0 _cache_misses = 0 diff --git a/scripts/smoke_test.py b/scripts/smoke_test.py index 2a699278..827dcda5 100644 --- a/scripts/smoke_test.py +++ b/scripts/smoke_test.py @@ -1,21 +1,9 @@ #!/usr/bin/env python3 import os import sys -from pathlib import Path from qdrant_client import QdrantClient -# Ensure scripts is importable -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -# Use embedder factory for Qwen3 support -try: - from scripts.embedder import get_embedding_model - _EMBEDDER_FACTORY = True -except ImportError: - _EMBEDDER_FACTORY = False - from fastembed import TextEmbedding +from scripts.embedder import get_embedding_model from scripts.utils import sanitize_vector_name QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333") @@ -37,10 +25,7 @@ count = None # Prepare query embedding -if _EMBEDDER_FACTORY: - model = get_embedding_model(MODEL) -else: - model = TextEmbedding(model_name=MODEL) +model = get_embedding_model(MODEL) query = "python code indexer for qdrant" vec = 
next(model.embed([query])) diff --git a/scripts/standalone_upload_client.py b/scripts/standalone_upload_client.py index 7cbd9dd1..7053dad1 100644 --- a/scripts/standalone_upload_client.py +++ b/scripts/standalone_upload_client.py @@ -39,15 +39,60 @@ except ImportError: WATCHDOG_AVAILABLE = False -try: - from upload_auth_utils import get_auth_session # type: ignore[import] -except ImportError: - def get_auth_session(upload_endpoint: str) -> str: - return "" +from scripts.upload_auth_utils import get_auth_session # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +_git_history_skip_log_key: Optional[str] = None + + +def _is_usable_delta_status(status: Any) -> bool: + if not isinstance(status, dict): + return False + state = str(status.get("status") or "").strip().lower() + return ( + bool(status.get("success")) and + "workspace_path" in status and + "collection_name" in status and + state in {"ready", "processing", "completed"} + ) + + +def _server_status_error_message(status: Any) -> str: + if isinstance(status, dict): + error = status.get("error") + if isinstance(error, dict): + msg = str(error.get("message") or "").strip() + if msg: + return msg + state = str(status.get("status") or "").strip() + if state: + return f"Server status is {state}" + return "Invalid server status response" + + +def _env_flag(name: str, default: bool) -> bool: + raw = os.environ.get(name) + if raw is None: + return default + return str(raw).strip().lower() in {"1", "true", "yes", "on"} + + +def _format_cached_sha1(value: Optional[str]) -> Optional[str]: + raw = str(value or "").strip() + if not raw: + return None + return raw if raw.lower().startswith("sha1:") else f"sha1:{raw}" + + +def _log_git_history_skip_once(reason: str, key: str) -> None: + global _git_history_skip_log_key + marker = f"{reason}:{key}" + if _git_history_skip_log_key == marker: + return + _git_history_skip_log_key = marker + logger.info("[git_history] skip (%s): %s", 
reason, key) DEFAULT_MAX_TEMP_CLEAN_ATTEMPTS = 3 DEFAULT_TEMP_CLEAN_SLEEP = 1.0 @@ -176,17 +221,6 @@ def hash_id(text: str, path: str, start: int, end: int) -> str: ).hexdigest() return h[:16] -def get_collection_name(repo_name: Optional[str] = None) -> str: - """Generate collection name with 8-char hash for local workspaces. - - Simplified version from workspace_state.py. - """ - if not repo_name: - return "default-collection" - hash_obj = hashlib.sha256(repo_name.encode()) - short_hash = hash_obj.hexdigest()[:8] - return f"{repo_name}-{short_hash}" - def _extract_repo_name_from_path(workspace_path: str) -> str: """Extract repository name from workspace path. @@ -290,6 +324,10 @@ def remove_hash(self, file_path: str) -> None: self._cache = file_hashes self._cache_loaded = True + def flush(self) -> None: + """Persist the current in-memory cache state to disk.""" + self._save_cache(dict(self._load_cache())) + def _cache_seems_stale(self, file_hashes: Dict[str, str]) -> bool: """Return True if a large portion of cached paths no longer exist on disk.""" total = len(file_hashes) @@ -341,6 +379,13 @@ def remove_cached_file(file_path: str, repo_name: Optional[str] = None) -> None: _hash_cache.remove_hash(file_path) +def flush_cached_file_hashes() -> None: + """Persist the current workspace hash cache to disk.""" + global _hash_cache + if _hash_cache: + _hash_cache.flush() + + def _find_git_root(start: Path) -> Optional[Path]: """Best-effort detection of the git repository root for a workspace. 
@@ -426,10 +471,12 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str } if max_commits <= 0: + _log_git_history_skip_once("disabled", f"max_commits={max_commits}") return None root = _find_git_root(Path(workspace_path)) if not root: + _log_git_history_skip_once("no_repo", workspace_path) return None # Git history cache: avoid emitting identical manifests when HEAD/settings are unchanged @@ -463,6 +510,7 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str cache = {} if current_head and cache.get("last_head") == current_head and cache.get("max_commits") == max_commits and str(cache.get("since") or "") == since: + _log_git_history_skip_once("cache_hit", f"head={current_head[:10]} since={since or '-'} max={max_commits}") return None base_head = "" @@ -513,12 +561,20 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str errors="replace", ) if proc.returncode != 0 or not proc.stdout.strip(): + _log_git_history_skip_once( + "rev_list_empty", + f"head={current_head[:10] if current_head else '-'} rc={proc.returncode}", + ) return None commits = [l.strip() for l in proc.stdout.splitlines() if l.strip()] except Exception: return None if not commits: + _log_git_history_skip_once( + "no_commits", + f"head={current_head[:10] if current_head else '-'}", + ) return None if len(commits) > max_commits: commits = commits[:max_commits] @@ -592,6 +648,10 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str continue if not records: + _log_git_history_skip_once( + "no_records", + f"commits={len(commits)} head={current_head[:10] if current_head else '-'}", + ) return None try: @@ -611,6 +671,14 @@ def _collect_git_history_for_workspace(workspace_path: str) -> Optional[Dict[str "since": since, "commits": records, } + logger.info( + "[git_history] prepared manifest mode=%s commits=%d head=%s prev=%s base=%s", + manifest["mode"], + len(records), + (current_head[:10] 
if current_head else "-"), + (prev_head[:10] if prev_head else "-"), + (base_head[:10] if base_head else "-"), + ) # Update git history cache with the HEAD and settings used for this manifest try: @@ -663,7 +731,7 @@ def _translate_to_container_path(self, host_path: str) -> str: return host_path.replace('\\', '/').replace(':', '') - def __init__(self, upload_endpoint: str, workspace_path: str, collection_name: str, + def __init__(self, upload_endpoint: str, workspace_path: str, collection_name: Optional[str] = None, max_retries: int = 3, timeout: int = 30, metadata_path: Optional[str] = None, logical_repo_id: Optional[str] = None): """Initialize remote upload client.""" @@ -675,9 +743,6 @@ def __init__(self, upload_endpoint: str, workspace_path: str, collection_name: s self.temp_dir = None self.logical_repo_id = logical_repo_id - # Set environment variables for cache functions - os.environ["WORKSPACE_PATH"] = workspace_path - # Store repo name and initialize hash cache self.repo_name = _extract_repo_name_from_path(workspace_path) # Fallback to directory name if repo detection fails (for non-git repos) @@ -695,6 +760,129 @@ def __init__(self, upload_endpoint: str, workspace_path: str, collection_name: s adapter = HTTPAdapter(max_retries=retry_strategy) self.session.mount("http://", adapter) self.session.mount("https://", adapter) + self.last_upload_result: Dict[str, Any] = {"outcome": "idle"} + self._last_plan_payload: Optional[Dict[str, Any]] = None + + def _set_last_upload_result(self, outcome: str, **details: Any) -> Dict[str, Any]: + result: Dict[str, Any] = {"outcome": outcome} + result.update(details) + self.last_upload_result = result + return result + + def log_watch_upload_result(self) -> None: + outcome = str((self.last_upload_result or {}).get("outcome") or "") + if outcome == "skipped_by_plan": + logger.info("[watch] No upload needed after plan") + elif outcome == "queued": + logger.info("[watch] Upload request accepted; server processing 
asynchronously") + elif outcome == "uploaded_async": + processed = (self.last_upload_result or {}).get("processed_operations") + logger.info("[watch] Upload processed asynchronously: %s", processed or {}) + elif outcome == "uploaded": + logger.info("[watch] Successfully uploaded changes") + elif outcome == "no_changes": + logger.info("[watch] No meaningful changes to upload") + else: + logger.info("[watch] Upload handling completed") + + def _finalize_successful_changes(self, changes: Dict[str, List]) -> None: + for path in changes.get("created", []): + try: + abs_path = str(path.resolve()) + current_hash = hashlib.sha1(path.read_bytes()).hexdigest() + set_cached_file_hash(abs_path, current_hash, self.repo_name) + stat = path.stat() + self._stat_cache[abs_path] = ( + getattr(stat, "st_mtime_ns", int(stat.st_mtime * 1e9)), + stat.st_size, + ) + except Exception: + continue + for path in changes.get("updated", []): + try: + abs_path = str(path.resolve()) + current_hash = hashlib.sha1(path.read_bytes()).hexdigest() + set_cached_file_hash(abs_path, current_hash, self.repo_name) + stat = path.stat() + self._stat_cache[abs_path] = ( + getattr(stat, "st_mtime_ns", int(stat.st_mtime * 1e9)), + stat.st_size, + ) + except Exception: + continue + for path in changes.get("deleted", []): + try: + abs_path = str(path.resolve()) + remove_cached_file(abs_path, self.repo_name) + self._stat_cache.pop(abs_path, None) + except Exception: + continue + for source_path, dest_path in changes.get("moved", []): + try: + source_abs_path = str(source_path.resolve()) + remove_cached_file(source_abs_path, self.repo_name) + self._stat_cache.pop(source_abs_path, None) + except Exception: + pass + try: + dest_abs_path = str(dest_path.resolve()) + current_hash = hashlib.sha1(dest_path.read_bytes()).hexdigest() + set_cached_file_hash(dest_abs_path, current_hash, self.repo_name) + stat = dest_path.stat() + self._stat_cache[dest_abs_path] = ( + getattr(stat, "st_mtime_ns", int(stat.st_mtime * 1e9)), + 
stat.st_size, + ) + except Exception: + continue + + def _await_async_upload_result( + self, + bundle_id: Optional[str], + sequence_number: Optional[int], + ) -> Optional[Dict[str, Any]]: + try: + max_wait = float(os.environ.get("CTXCE_REMOTE_UPLOAD_STATUS_WAIT_SECS", "5")) + except Exception: + max_wait = 5.0 + if max_wait <= 0: + return None + + try: + poll_interval = float(os.environ.get("CTXCE_REMOTE_UPLOAD_STATUS_POLL_INTERVAL_SECS", "1")) + except Exception: + poll_interval = 1.0 + poll_interval = max(0.1, poll_interval) + + deadline = time.time() + max_wait + while time.time() < deadline: + status = self.get_server_status() + if not status.get("success"): + return None + server_info = status.get("server_info", {}) if isinstance(status, dict) else {} + last_bundle_id = server_info.get("last_bundle_id") + last_upload_status = server_info.get("last_upload_status") + last_sequence = status.get("last_sequence") + bundle_matches = bool(bundle_id) and last_bundle_id == bundle_id + sequence_matches = sequence_number is not None and last_sequence == sequence_number + if bundle_matches or sequence_matches: + if last_upload_status == "completed": + return { + "outcome": "uploaded_async", + "bundle_id": last_bundle_id or bundle_id, + "sequence_number": last_sequence if last_sequence is not None else sequence_number, + "processed_operations": server_info.get("last_processed_operations"), + "processing_time_ms": server_info.get("last_processing_time_ms"), + } + if last_upload_status in ("failed", "error"): + return { + "outcome": "failed", + "bundle_id": last_bundle_id or bundle_id, + "sequence_number": last_sequence if last_sequence is not None else sequence_number, + "error": server_info.get("last_error"), + } + time.sleep(poll_interval) + return None def __enter__(self): """Context manager entry.""" @@ -715,7 +903,7 @@ def get_mapping_summary(self) -> Dict[str, Any]: container_path = self._translate_to_container_path(self.workspace_path) return { "repo_name": 
self.repo_name, - "collection_name": self.collection_name, + "collection_name": self.collection_name or "", "source_path": self.workspace_path, "container_path": container_path, "upload_endpoint": self.upload_endpoint, @@ -730,6 +918,51 @@ def log_mapping_summary(self) -> None: logger.info(f" source_path: {info['source_path']}") logger.info(f" container_path: {info['container_path']}") + def _excluded_dirnames(self) -> frozenset: + # Keep in sync with get_all_code_files exclusions. + # NOTE: This caches the exclusion set per client instance. + # Runtime changes to DEV_REMOTE_MODE/REMOTE_UPLOAD_MODE won't be reflected + # until a new client is created (typically via process restart), which is + # acceptable for the standalone upload client use case. + cached = getattr(self, "_excluded_dirnames_cache", None) + if cached is not None: + return cached + excluded = { + "node_modules", "vendor", "dist", "build", "target", "out", + ".git", ".hg", ".svn", ".vscode", ".idea", ".venv", "venv", + "__pycache__", ".pytest_cache", ".mypy_cache", ".cache", + ".context-engine", ".context-engine-uploader", ".codebase", + } + dev_remote = os.environ.get("DEV_REMOTE_MODE") == "1" or os.environ.get("REMOTE_UPLOAD_MODE") == "development" + if dev_remote: + excluded.add("dev-workspace") + cached = frozenset(excluded) + self._excluded_dirnames_cache = cached + return cached + + def _is_ignored_path(self, path: Path) -> bool: + """Return True when path is outside workspace or under excluded dirs.""" + try: + workspace_root = Path(self.workspace_path).resolve() + rel = path.resolve().relative_to(workspace_root) + except Exception: + return True + + dir_parts = set(rel.parts[:-1]) if len(rel.parts) > 1 else set() + if dir_parts & self._excluded_dirnames(): + return True + # Ignore hidden directories anywhere under the workspace, but allow + # extensionless dotfiles like `.gitignore` that we explicitly support. 
+ if any(p.startswith(".") for p in rel.parts[:-1]): + return True + if rel.name.startswith(".") and rel.name.lower() not in EXTENSIONLESS_FILES: + return True + return False + + def _is_watchable_path(self, path: Path) -> bool: + """Return True when a filesystem event path is eligible for upload processing.""" + return not self._is_ignored_path(path) and detect_language(path) != "unknown" + def _get_temp_bundle_dir(self) -> Path: """Get or create temporary directory for bundle creation.""" if not self.temp_dir: @@ -757,6 +990,19 @@ def detect_file_changes(self, changed_paths: List[Path]) -> Dict[str, List]: } for path in changed_paths: + if self._is_ignored_path(path): + try: + abs_path = str(path.resolve()) + except Exception: + continue + cached_hash = get_cached_file_hash(abs_path, self.repo_name) + if cached_hash: + changes["deleted"].append(path) + try: + self._stat_cache.pop(abs_path, None) + except Exception: + pass + continue try: abs_path = str(path.resolve()) except Exception: @@ -819,8 +1065,6 @@ def detect_file_changes(self, changed_paths: List[Path]) -> Dict[str, List]: self._stat_cache[abs_path] = (getattr(stat, "st_mtime_ns", int(stat.st_mtime * 1e9)), stat.st_size) except Exception: pass - set_cached_file_hash(abs_path, current_hash, self.repo_name) - # Detect moves by looking for files with same content hash # but different paths (requires additional tracking) changes["moved"] = self._detect_moves(changes["created"], changes["deleted"]) @@ -945,8 +1189,6 @@ def create_delta_bundle( operations.append(operation) file_hashes[rel_path] = f"sha1:{file_hash}" total_size += stat.st_size - set_cached_file_hash(str(path.resolve()), file_hash, self.repo_name) - except Exception as e: print(f"[bundle_create] Error processing created file {path}: {e}") continue @@ -985,8 +1227,6 @@ def create_delta_bundle( operations.append(operation) file_hashes[rel_path] = f"sha1:{file_hash}" total_size += stat.st_size - set_cached_file_hash(str(path.resolve()), file_hash, 
self.repo_name) - except Exception as e: print(f"[bundle_create] Error processing updated file {path}: {e}") continue @@ -1027,8 +1267,6 @@ def create_delta_bundle( operations.append(operation) file_hashes[dest_rel_path] = f"sha1:{file_hash}" total_size += stat.st_size - set_cached_file_hash(str(dest_path.resolve()), file_hash, self.repo_name) - except Exception as e: print(f"[bundle_create] Error processing moved file {source_path} -> {dest_path}: {e}") continue @@ -1063,7 +1301,6 @@ def create_delta_bundle( "version": "1.0", "bundle_id": bundle_id, "workspace_path": self.workspace_path, - "collection_name": self.collection_name, "created_at": created_at, # CLI is stateless - server handles sequence numbers "sequence_number": None, # Server will assign @@ -1115,6 +1352,305 @@ def create_delta_bundle( return str(bundle_path), manifest + def _build_plan_payload(self, changes: Dict[str, List]) -> Dict[str, Any]: + created_at = datetime.now().isoformat() + bundle_id = str(uuid.uuid4()) + operations: List[Dict[str, Any]] = [] + file_hashes: Dict[str, str] = {} + total_size = 0 + + for path in changes["created"]: + rel_path = path.relative_to(Path(self.workspace_path)).as_posix() + try: + content = path.read_bytes() + file_hash = hashlib.sha1(content).hexdigest() + stat = path.stat() + operations.append( + { + "operation": "created", + "path": rel_path, + "size_bytes": stat.st_size, + "content_hash": f"sha1:{file_hash}", + "language": detect_language(path), + } + ) + file_hashes[rel_path] = f"sha1:{file_hash}" + total_size += stat.st_size + except Exception as e: + logger.warning("[remote_upload] Failed to prepare created plan entry for %s: %s", path, e) + + for path in changes["updated"]: + rel_path = path.relative_to(Path(self.workspace_path)).as_posix() + try: + content = path.read_bytes() + file_hash = hashlib.sha1(content).hexdigest() + stat = path.stat() + previous_hash = _format_cached_sha1( + get_cached_file_hash(str(path.resolve()), self.repo_name) + ) + 
operations.append( + { + "operation": "updated", + "path": rel_path, + "size_bytes": stat.st_size, + "content_hash": f"sha1:{file_hash}", + "previous_hash": previous_hash, + "language": detect_language(path), + } + ) + file_hashes[rel_path] = f"sha1:{file_hash}" + total_size += stat.st_size + except Exception as e: + logger.warning("[remote_upload] Failed to prepare updated plan entry for %s: %s", path, e) + + for source_path, dest_path in changes["moved"]: + dest_rel_path = dest_path.relative_to(Path(self.workspace_path)).as_posix() + source_rel_path = source_path.relative_to(Path(self.workspace_path)).as_posix() + try: + content = dest_path.read_bytes() + file_hash = hashlib.sha1(content).hexdigest() + stat = dest_path.stat() + operations.append( + { + "operation": "moved", + "path": dest_rel_path, + "source_path": source_rel_path, + "size_bytes": stat.st_size, + "content_hash": f"sha1:{file_hash}", + "language": detect_language(dest_path), + } + ) + file_hashes[dest_rel_path] = f"sha1:{file_hash}" + total_size += stat.st_size + except Exception as e: + logger.warning( + "[remote_upload] Failed to prepare moved plan entry for %s -> %s: %s", + source_path, + dest_path, + e, + ) + + for path in changes["deleted"]: + rel_path = path.relative_to(Path(self.workspace_path)).as_posix() + try: + previous_hash = _format_cached_sha1( + get_cached_file_hash(str(path.resolve()), self.repo_name) + ) + operations.append( + { + "operation": "deleted", + "path": rel_path, + "previous_hash": previous_hash, + "language": detect_language(path), + } + ) + except Exception as e: + logger.warning("[remote_upload] Failed to prepare deleted plan entry for %s: %s", path, e) + + manifest = { + "version": "1.0", + "bundle_id": bundle_id, + "workspace_path": self.workspace_path, + "created_at": created_at, + "sequence_number": None, + "parent_sequence": None, + "operations": { + "created": len(changes["created"]), + "updated": len(changes["updated"]), + "deleted": len(changes["deleted"]), + 
"moved": len(changes["moved"]), + }, + "total_files": len(operations), + "total_size_bytes": total_size, + "compression": "gzip", + "encoding": "utf-8", + } + return { + "manifest": manifest, + "operations": operations, + "file_hashes": file_hashes, + } + + def _plan_delta_upload(self, changes: Dict[str, List]) -> Optional[Dict[str, Any]]: + if not _env_flag("CTXCE_REMOTE_UPLOAD_PLAN_ENABLED", True): + return None + try: + payload = self._build_plan_payload(changes) + self._last_plan_payload = payload + data = { + "workspace_path": self._translate_to_container_path(self.workspace_path), + "source_path": self.workspace_path, + "logical_repo_id": _compute_logical_repo_id(self.workspace_path), + "manifest": payload["manifest"], + "operations": payload["operations"], + "file_hashes": payload["file_hashes"], + } + sess = get_auth_session(self.upload_endpoint) + if sess: + data["session"] = sess + if getattr(self, "logical_repo_id", None): + data["logical_repo_id"] = self.logical_repo_id + + response = self.session.post( + f"{self.upload_endpoint}/api/v1/delta/plan", + json=data, + timeout=min(self.timeout, 60), + ) + if response.status_code in {404, 405}: + logger.info("[remote_upload] Plan endpoint unavailable; falling back to full bundle upload") + return None + response.raise_for_status() + body = response.json() + if not body.get("success", False): + logger.warning("[remote_upload] Plan request failed; falling back: %s", body.get("error")) + return None + return body + except Exception as e: + logger.warning("[remote_upload] Plan request failed; falling back to full bundle upload: %s", e) + return None + + def _build_apply_only_payload(self, changes: Dict[str, List], plan: Dict[str, Any]) -> Dict[str, Any]: + payload = self._last_plan_payload or self._build_plan_payload(changes) + needed = plan.get("needed_files", {}) if isinstance(plan, dict) else {} + created_needed = set(needed.get("created", []) or []) + updated_needed = set(needed.get("updated", []) or []) + 
moved_needed = set(needed.get("moved", []) or []) + + # Check if ALL operations are hash-matched (nothing needs content at all) + # This happens when all needed_files lists are empty and there are no actual changes requiring content + has_changes_needing_content = bool(created_needed or updated_needed or moved_needed) + has_deletes = bool(changes.get("deleted", [])) + + # Only skip apply-only if there are NO operations needing content AND NO deletes + if not has_changes_needing_content and not has_deletes: + return { + "manifest": payload.get("manifest", {}), + "operations": [], + "file_hashes": {}, + } + + filtered_ops: List[Dict[str, Any]] = [] + filtered_hashes: Dict[str, str] = {} + for operation in payload.get("operations", []): + op_type = str(operation.get("operation") or "") + rel_path = str(operation.get("path") or "") + # Determine if this operation needs content (only those skip filtered_hashes) + needs_content = ( + (op_type == "created" and rel_path in created_needed) + or (op_type == "updated" and rel_path in updated_needed) + or (op_type == "moved" and rel_path in moved_needed) + ) + if needs_content: + # Skip operations that need content - they'll be uploaded separately + continue + # IMPORTANT: server-side apply_delta_operations() only accepts "deleted" and "moved" + # operations. Hash-matched "created" and "updated" operations must NOT be routed + # through apply_ops since the server will reject them. 
+ if op_type not in {"deleted", "moved"}: + continue + # Preserve all other operations so server advances state + filtered_ops.append(operation) + # Include hash for non-deleted operations + if op_type != "deleted": + hash_value = payload.get("file_hashes", {}).get(rel_path) + if hash_value: + filtered_hashes[rel_path] = hash_value + return { + "manifest": payload.get("manifest", {}), + "operations": filtered_ops, + "file_hashes": filtered_hashes, + } + + def _apply_operations_without_content(self, changes: Dict[str, List], plan: Dict[str, Any]) -> Optional[bool]: + payload = self._build_apply_only_payload(changes, plan) + operations = payload.get("operations", []) + if not operations: + return None + try: + data = { + "workspace_path": self._translate_to_container_path(self.workspace_path), + "source_path": self.workspace_path, + "logical_repo_id": _compute_logical_repo_id(self.workspace_path), + "manifest": payload["manifest"], + "operations": operations, + "file_hashes": payload["file_hashes"], + } + sess = get_auth_session(self.upload_endpoint) + if sess: + data["session"] = sess + if getattr(self, "logical_repo_id", None): + data["logical_repo_id"] = self.logical_repo_id + + logger.info( + "[remote_upload] Applying metadata-only operations without bundle: deleted=%s moved=%s", + sum(1 for op in operations if op.get("operation") == "deleted"), + sum(1 for op in operations if op.get("operation") == "moved"), + ) + response = self.session.post( + f"{self.upload_endpoint}/api/v1/delta/apply_ops", + json=data, + timeout=min(self.timeout, 60), + ) + if response.status_code in {404, 405}: + logger.info("[remote_upload] apply_ops endpoint unavailable; falling back to bundle upload") + return None + response.raise_for_status() + body = response.json() + if not body.get("success", False): + logger.warning("[remote_upload] apply_ops failed; falling back to bundle upload: %s", body.get("error")) + return None + # Only finalize changes that were actually processed by the 
server + # apply_delta_operations only handles deleted/moved operations + processed_ops = body.get("processed_operations") or {} + applied_changes = { + "deleted": changes.get("deleted", []), + "moved": changes.get("moved", []), + "created": [], + "updated": [], + } + self._finalize_successful_changes(applied_changes) + self._set_last_upload_result( + "uploaded", + bundle_id=body.get("bundle_id"), + sequence_number=body.get("sequence_number"), + processed_operations=processed_ops, + ) + logger.info( + "[remote_upload] Metadata-only operations applied: %s", + processed_ops, + ) + return True + except Exception as e: + logger.warning("[remote_upload] apply_ops failed; falling back to bundle upload: %s", e) + return None + + def _filter_changes_by_plan(self, changes: Dict[str, List], plan: Dict[str, Any]) -> Dict[str, List]: + needed = plan.get("needed_files", {}) if isinstance(plan, dict) else {} + created_needed = set(needed.get("created", []) or []) + updated_needed = set(needed.get("updated", []) or []) + moved_needed = set(needed.get("moved", []) or []) + + filtered_created = [ + path for path in changes["created"] + if path.relative_to(Path(self.workspace_path)).as_posix() in created_needed + ] + filtered_updated = [ + path for path in changes["updated"] + if path.relative_to(Path(self.workspace_path)).as_posix() in updated_needed + ] + filtered_moved = [ + (source_path, dest_path) + for source_path, dest_path in changes["moved"] + if dest_path.relative_to(Path(self.workspace_path)).as_posix() in moved_needed + ] + return { + "created": filtered_created, + "updated": filtered_updated, + "deleted": list(changes["deleted"]), + "moved": filtered_moved, + "unchanged": [], + } + def upload_bundle(self, bundle_path: str, manifest: Dict[str, Any]) -> Dict[str, Any]: """Upload delta bundle to remote server with exponential backoff retry. 
@@ -1147,13 +1683,11 @@ def upload_bundle(self, bundle_path: str, manifest: Dict[str, Any]) -> Dict[str, } data = { "workspace_path": self._translate_to_container_path(self.workspace_path), - "collection_name": self.collection_name, "sequence_number": manifest.get("sequence_number"), "force": False, "source_path": self.workspace_path, "logical_repo_id": _compute_logical_repo_id(self.workspace_path), } - sess = get_auth_session(self.upload_endpoint) if sess: data["session"] = sess @@ -1361,7 +1895,16 @@ def get_server_status(self) -> Dict[str, Any]: ) if response.status_code == 200: - return response.json() + payload = response.json() + if not isinstance(payload, dict): + return { + "success": False, + "error": { + "code": "STATUS_INVALID", + "message": "Invalid status response payload", + }, + } + return {"success": True, **payload} # Handle error response error_msg = f"Status check failed with HTTP {response.status_code}" @@ -1385,6 +1928,93 @@ def has_meaningful_changes(self, changes: Dict[str, List]) -> bool: total_changes = sum(len(files) for op, files in changes.items() if op != "unchanged") return total_changes > 0 + def _collect_force_cleanup_paths(self) -> List[Path]: + """ + Return ignored paths that force mode should actively delete remotely. + + In dev-remote mode, dev-workspace is intentionally ignored during upload + scans to avoid recursive dogfooding. If that tree already exists on the + remote side from an older buggy upload, force mode should remove it even + when the standalone client's cache does not know about those paths. 
+ """ + cleanup_paths: List[Path] = [] + if "dev-workspace" not in self._excluded_dirnames(): + return cleanup_paths + + dev_root = Path(self.workspace_path) / "dev-workspace" + if not dev_root.exists(): + return cleanup_paths + + for root, dirnames, filenames in os.walk(dev_root): + dirnames[:] = [d for d in dirnames if not d.startswith(".")] + for filename in filenames: + path = Path(root) / filename + try: + if path.is_file(): + cleanup_paths.append(path) + except Exception: + continue + return cleanup_paths + + def build_force_changes(self, all_files: List[Path]) -> Dict[str, List]: + """ + Build force-upload changes while still cleaning stale cached paths. + + Force mode should re-upload every currently managed file, but it must also + emit deletes for files that only exist in the local cache now, including + paths that are ignored under the current client policy such as + dev-workspace in dev-remote mode. + """ + created_files: List[Path] = [] + path_map: Dict[Path, Path] = {} + for path in all_files: + if self._is_ignored_path(path): + continue + try: + resolved = path.resolve() + except Exception: + continue + created_files.append(path) + path_map[resolved] = path + + for cached_abs in get_all_cached_paths(self.repo_name): + try: + cached_path = Path(cached_abs) + resolved = cached_path.resolve() + except Exception: + continue + if resolved not in path_map: + path_map[resolved] = cached_path + + force_cleanup_paths = self._collect_force_cleanup_paths() + for cleanup_path in force_cleanup_paths: + try: + resolved = cleanup_path.resolve() + except Exception: + continue + if resolved not in path_map: + path_map[resolved] = cleanup_path + + probed = self.detect_file_changes(list(path_map.values())) + deleted_by_resolved: Dict[Path, Path] = {} + for deleted_path in probed.get("deleted", []): + try: + deleted_by_resolved[deleted_path.resolve()] = deleted_path + except Exception: + continue + for cleanup_path in force_cleanup_paths: + try: + 
deleted_by_resolved.setdefault(cleanup_path.resolve(), cleanup_path) + except Exception: + continue + return { + "created": created_files, + "updated": [], + "deleted": list(deleted_by_resolved.values()), + "moved": [], + "unchanged": [], + } + def upload_git_history_only(self, git_history: Dict[str, Any]) -> bool: try: empty_changes = { @@ -1429,10 +2059,12 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: # Validate input if not changes: logger.info("[remote_upload] No changes provided") + self._set_last_upload_result("no_changes") return True if not self.has_meaningful_changes(changes): logger.info("[remote_upload] No meaningful changes detected, skipping upload") + self._set_last_upload_result("no_changes") return True # Log change summary @@ -1441,10 +2073,46 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: f"{len(changes['created'])} created, {len(changes['updated'])} updated, " f"{len(changes['deleted'])} deleted, {len(changes['moved'])} moved") + planned_changes = changes + plan = self._plan_delta_upload(changes) + if plan: + preview = plan.get("operation_counts_preview", {}) + logger.info( + "[remote_upload] Plan preview: needed created=%s updated=%s deleted=%s moved=%s " + "skipped_hash_match=%s needed_bytes=%s", + preview.get("created", 0), + preview.get("updated", 0), + preview.get("deleted", 0), + preview.get("moved", 0), + preview.get("skipped_hash_match", 0), + plan.get("needed_size_bytes", 0), + ) + planned_changes = self._filter_changes_by_plan(changes, plan) + has_content_work = bool( + planned_changes.get("created") + or planned_changes.get("updated") + or planned_changes.get("moved") + ) + if not has_content_work: + apply_only_result = self._apply_operations_without_content(changes, plan) + if apply_only_result is True: + flush_cached_file_hashes() + return True + if not self.has_meaningful_changes(planned_changes): + logger.info("[remote_upload] Plan found no upload work; skipping bundle 
upload") + self._finalize_successful_changes(changes) + self._set_last_upload_result( + "skipped_by_plan", + plan_preview=preview, + needed_size_bytes=plan.get("needed_size_bytes", 0), + ) + flush_cached_file_hashes() + return True + # Create delta bundle bundle_path = None try: - bundle_path, manifest = self.create_delta_bundle(changes) + bundle_path, manifest = self.create_delta_bundle(planned_changes) logger.info(f"[remote_upload] Created delta bundle: {manifest['bundle_id']} " f"(size: {manifest['total_size_bytes']} bytes)") @@ -1456,6 +2124,7 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: logger.error(f"[remote_upload] Error creating delta bundle: {e}") # Clean up any temporary files on failure self.cleanup() + self._set_last_upload_result("failed", stage="bundle_creation", error=str(e)) return False # Upload bundle with retry logic @@ -1463,9 +2132,86 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: response = self.upload_bundle(bundle_path, manifest) if response.get("success", False): - processed_ops = response.get('processed_operations', {}) - logger.info(f"[remote_upload] Successfully uploaded bundle {manifest['bundle_id']}") - logger.info(f"[remote_upload] Processed operations: {processed_ops}") + async_failed = False + async_pending = False + processed_ops = response.get("processed_operations") + if processed_ops is None: + logger.info( + "[remote_upload] Bundle %s accepted by server; processing asynchronously (sequence=%s)", + manifest["bundle_id"], + response.get("sequence_number"), + ) + self._set_last_upload_result( + "queued", + bundle_id=manifest["bundle_id"], + sequence_number=response.get("sequence_number"), + ) + async_result = self._await_async_upload_result( + manifest["bundle_id"], + response.get("sequence_number"), + ) + if async_result is None: + # Server accepted the bundle but status is still pending. 
+ async_pending = True + logger.warning( + "[remote_upload] Async upload timed out awaiting server response for bundle %s", + manifest["bundle_id"], + ) + else: + self.last_upload_result = async_result + outcome = str(async_result.get("outcome") or "") + if outcome == "uploaded_async": + self._finalize_successful_changes(planned_changes) + logger.info( + "[remote_upload] Async processing completed for bundle %s: %s", + manifest["bundle_id"], + async_result.get("processed_operations") or {}, + ) + elif outcome == "failed": + async_failed = True + logger.error( + "[remote_upload] Async processing failed for bundle %s: %s", + manifest["bundle_id"], + async_result.get("error"), + ) + self._set_last_upload_result( + "failed", + stage="async_processing", + bundle_id=async_result.get("bundle_id") or manifest["bundle_id"], + sequence_number=async_result.get("sequence_number") or response.get("sequence_number"), + error=async_result.get("error"), + ) + else: + async_pending = True + # Keep queued state for non-terminal async outcomes. 
+ self._set_last_upload_result( + "queued", + bundle_id=async_result.get("bundle_id") or manifest["bundle_id"], + sequence_number=async_result.get("sequence_number") or response.get("sequence_number"), + ) + logger.warning( + "[remote_upload] Async upload still pending for bundle %s (sequence=%s, outcome=%s)", + manifest["bundle_id"], + response.get("sequence_number"), + outcome or "", + ) + else: + logger.info(f"[remote_upload] Successfully uploaded bundle {manifest['bundle_id']}") + logger.info(f"[remote_upload] Processed operations: {processed_ops}") + self._finalize_successful_changes(planned_changes) + self._set_last_upload_result( + "uploaded", + bundle_id=manifest["bundle_id"], + sequence_number=response.get("sequence_number"), + processed_operations=processed_ops, + ) + if async_pending: + logger.info( + "[remote_upload] Bundle %s accepted and queued; deferring local finalization", + manifest["bundle_id"], + ) + if not async_failed and not async_pending: + flush_cached_file_hashes() # Clean up temporary bundle after successful upload try: @@ -1477,18 +2223,21 @@ def process_changes_and_upload(self, changes: Dict[str, List]) -> bool: except Exception as cleanup_error: logger.warning(f"[remote_upload] Failed to cleanup bundle {bundle_path}: {cleanup_error}") - return True + return not async_failed else: error_msg = response.get('error', {}).get('message', 'Unknown upload error') logger.error(f"[remote_upload] Upload failed: {error_msg}") + self._set_last_upload_result("failed", stage="upload", error=error_msg) return False except Exception as e: logger.error(f"[remote_upload] Error uploading bundle: {e}") + self._set_last_upload_result("failed", stage="upload", error=str(e)) return False except Exception as e: logger.error(f"[remote_upload] Unexpected error in process_changes_and_upload: {e}") + self._set_last_upload_result("failed", stage="unexpected", error=str(e)) return False def watch_loop(self, interval: int = 5): @@ -1515,6 +2264,7 @@ def 
__init__(self, client, debounce_seconds=2.0): self._pending_paths = set() self._check_for_deletions = False self._lock = threading.Lock() + self._processing = False def on_any_event(self, event): """Handle any file system event.""" @@ -1532,13 +2282,13 @@ def on_any_event(self, event): # Always check src_path src_path = Path(event.src_path) - if detect_language(src_path) != "unknown": + if self.client._is_watchable_path(src_path): paths_to_process.append(src_path) # For FileMovedEvent, also process the destination path if hasattr(event, 'dest_path') and event.dest_path: dest_path = Path(event.dest_path) - if detect_language(dest_path) != "unknown": + if self.client._is_watchable_path(dest_path): paths_to_process.append(dest_path) if not paths_to_process: @@ -1559,22 +2309,30 @@ def on_any_event(self, event): def _process_pending_changes(self): """Process accumulated changes after debounce period.""" with self._lock: + # Timer fired; allow a new debounce to be armed while we process. + self._debounce_timer = None + if self._processing: + return if not self._pending_paths: return + self._processing = True pending = list(self._pending_paths) self._pending_paths.clear() check_deletions = self._check_for_deletions self._check_for_deletions = False + upload_succeeded = False try: # Only include cached paths when deletion-related events occurred if check_deletions: - all_paths = list(set(pending + [ + cached_paths = [ Path(p) for p in get_all_cached_paths(self.client.repo_name) - ])) + ] + all_paths = list(set(pending + cached_paths)) else: all_paths = pending - + + changes = self.client.detect_file_changes(all_paths) meaningful_changes = ( len(changes.get("created", [])) + @@ -1582,12 +2340,13 @@ def _process_pending_changes(self): len(changes.get("deleted", [])) + len(changes.get("moved", [])) ) - + if meaningful_changes > 0: logger.info(f"[watch] Detected {meaningful_changes} changes: { {k: len(v) for k, v in changes.items() if k != 'unchanged'} }") success = 
self.client.process_changes_and_upload(changes) if success: - logger.info("[watch] Successfully uploaded changes") + self.client.log_watch_upload_result() + upload_succeeded = True else: logger.error("[watch] Failed to upload changes") else: @@ -1597,16 +2356,34 @@ def _process_pending_changes(self): git_history = _collect_git_history_for_workspace(self.client.workspace_path) except Exception: git_history = None - + if git_history: logger.info("[watch] Detected git history update; uploading git history metadata") success = self.client.upload_git_history_only(git_history) if success: logger.info("[watch] Successfully uploaded git history metadata") + upload_succeeded = True else: logger.error("[watch] Failed to upload git history metadata") + else: + upload_succeeded = True # No changes to process except Exception as e: logger.error(f"[watch] Error processing changes: {e}") + finally: + with self._lock: + self._processing = False + # Re-queue pending paths if upload failed + if not upload_succeeded and pending: + # Merge pending paths back into _pending_paths + for p in pending: + self._pending_paths.add(p) + # Arm next pass if there are pending paths + if self._pending_paths and self._debounce_timer is None: + self._debounce_timer = threading.Timer( + self.debounce_seconds, + self._process_pending_changes, + ) + self._debounce_timer.start() observer = Observer() handler = CodeFileEventHandler(self, debounce_seconds=2.0) @@ -1676,7 +2453,7 @@ def _watch_loop_polling(self, interval: int = 5): success = self.process_changes_and_upload(changes) if success: - logger.info(f"[watch] Successfully uploaded changes") + self.log_watch_upload_result() else: logger.error(f"[watch] Failed to upload changes") else: @@ -1719,16 +2496,10 @@ def get_all_code_files(self) -> List[Path]: # Single walk with early pruning and set-based matching to reduce IO ext_suffixes = {str(ext).lower() for ext in CODE_EXTS if str(ext).startswith('.')} - extensionless_names = 
set(EXTENSIONLESS_FILES.keys()) + extensionless_names = {k.lower() for k in EXTENSIONLESS_FILES.keys()} # Always exclude dev-workspace to prevent recursive upload loops # (upload service creates dev-workspace// which would otherwise get re-uploaded) - excluded = { - "node_modules", "vendor", "dist", "build", "target", "out", - ".git", ".hg", ".svn", ".vscode", ".idea", ".venv", "venv", - "__pycache__", ".pytest_cache", ".mypy_cache", ".cache", - ".context-engine", ".context-engine-uploader", ".codebase", - "dev-workspace" - } + excluded = self._excluded_dirnames() seen = set() for root, dirnames, filenames in os.walk(workspace_path): @@ -1741,6 +2512,8 @@ def get_all_code_files(self) -> List[Path]: if filename.startswith('.') and fname_lower not in extensionless_names: continue candidate = Path(root) / filename + if self._is_ignored_path(candidate): + continue suffix = candidate.suffix.lower() # Match by extension, extensionless name, or Dockerfile.* prefix if (suffix in ext_suffixes or @@ -1780,87 +2553,14 @@ def process_and_upload_changes(self, changed_paths: List[Path]) -> bool: except Exception as e: logger.error(f"[remote_upload] Error detecting file changes: {e}") return False - - if not self.has_meaningful_changes(changes): - logger.info("[remote_upload] No meaningful changes detected, skipping upload") - return True - - # Log change summary - total_changes = sum(len(files) for op, files in changes.items() if op != "unchanged") - logger.info(f"[remote_upload] Detected {total_changes} meaningful changes: " - f"{len(changes['created'])} created, {len(changes['updated'])} updated, " - f"{len(changes['deleted'])} deleted, {len(changes['moved'])} moved") - - # Create delta bundle - bundle_path = None - try: - bundle_path, manifest = self.create_delta_bundle(changes) - logger.info(f"[remote_upload] Created delta bundle: {manifest['bundle_id']} " - f"(size: {manifest['total_size_bytes']} bytes)") - - # Validate bundle was created successfully - if not bundle_path 
or not os.path.exists(bundle_path): - raise RuntimeError(f"Failed to create bundle at {bundle_path}") - - except Exception as e: - logger.error(f"[remote_upload] Error creating delta bundle: {e}") - # Clean up any temporary files on failure - self.cleanup() - return False - - # Upload bundle with retry logic - try: - response = self.upload_bundle(bundle_path, manifest) - - if response.get("success", False): - processed_ops = response.get('processed_operations', {}) - logger.info(f"[remote_upload] Successfully uploaded bundle {manifest['bundle_id']}") - logger.info(f"[remote_upload] Processed operations: {processed_ops}") - - # Clean up temporary bundle after successful upload - try: - if os.path.exists(bundle_path): - os.remove(bundle_path) - logger.debug(f"[remote_upload] Cleaned up temporary bundle: {bundle_path}") - # Also clean up the entire temp directory if this is the last bundle - self.cleanup() - except Exception as cleanup_error: - logger.warning(f"[remote_upload] Failed to cleanup bundle {bundle_path}: {cleanup_error}") - - return True - else: - error = response.get("error", {}) - error_code = error.get("code", "UNKNOWN") - error_msg = error.get("message", "Unknown error") - - logger.error(f"[remote_upload] Upload failed: {error_msg}") - - # Handle specific error types - # CLI is stateless - server handles sequence management - if error_code in ["BUNDLE_TOO_LARGE", "BUNDLE_NOT_FOUND"]: - # These are unrecoverable errors - logger.error(f"[remote_upload] Unrecoverable error ({error_code}): {error_msg}") - return False - elif error_code in ["TIMEOUT_ERROR", "CONNECTION_ERROR", "NETWORK_ERROR"]: - # These might be temporary, suggest fallback - logger.warning(f"[remote_upload] Network-related error ({error_code}): {error_msg}") - logger.warning("[remote_upload] Consider falling back to local mode if this persists") - return False - else: - # Other errors - logger.error(f"[remote_upload] Upload error ({error_code}): {error_msg}") - return False - - except 
Exception as e: - logger.error(f"[remote_upload] Unexpected error during upload: {e}") - return False + return self.process_changes_and_upload(changes) except Exception as e: logger.error(f"[remote_upload] Critical error in process_and_upload_changes: {e}") logger.exception("[remote_upload] Full traceback:") return False -def get_remote_config(cli_path: Optional[str] = None) -> Dict[str, str]: +def get_remote_config(cli_path: Optional[str] = None) -> Dict[str, Any]: """Get remote upload configuration from environment variables and command-line arguments.""" # Use command-line path if provided, otherwise fall back to environment variables if cli_path: @@ -1870,17 +2570,10 @@ def get_remote_config(cli_path: Optional[str] = None) -> Dict[str, str]: logical_repo_id = _compute_logical_repo_id(workspace_path) - # Use auto-generated collection name based on repo name - repo_name = _extract_repo_name_from_path(workspace_path) - # Fallback to directory name if repo detection fails - if not repo_name: - repo_name = Path(workspace_path).name - collection_name = get_collection_name(repo_name) - return { "upload_endpoint": os.environ.get("REMOTE_UPLOAD_ENDPOINT", "http://localhost:8080"), "workspace_path": workspace_path, - "collection_name": collection_name, + "collection_name": None, "logical_repo_id": logical_repo_id, # Use higher, more robust defaults but still allow env overrides "max_retries": int(os.environ.get("REMOTE_UPLOAD_MAX_RETRIES", "5")), @@ -2006,7 +2699,7 @@ def main(): config["timeout"] = args.timeout logger.info(f"Workspace path: {config['workspace_path']}") - logger.info(f"Collection name: {config['collection_name']}") + logger.info(f"Collection name: {config['collection_name'] or ''}") logger.info(f"Upload endpoint: {config['upload_endpoint']}") if args.show_mapping: @@ -2040,15 +2733,8 @@ def main(): # Test server connection first logger.info("Checking server status...") status = client.get_server_status() - is_success = ( - isinstance(status, dict) and - 
'workspace_path' in status and - 'collection_name' in status and - status.get('status') == 'ready' - ) - if not is_success: - error = status.get("error", {}) - logger.error(f"Cannot connect to server: {error.get('message', 'Unknown error')}") + if not _is_usable_delta_status(status): + logger.error("Cannot connect to server: %s", _server_status_error_message(status)) return 1 logger.info("Server connection successful") @@ -2085,16 +2771,8 @@ def main(): # Test server connection logger.info("Checking server status...") status = client.get_server_status() - # For delta endpoint, success is indicated by having expected fields (not a "success" boolean) - is_success = ( - isinstance(status, dict) and - 'workspace_path' in status and - 'collection_name' in status and - status.get('status') == 'ready' - ) - if not is_success: - error = status.get("error", {}) - logger.error(f"Cannot connect to server: {error.get('message', 'Unknown error')}") + if not _is_usable_delta_status(status): + logger.error("Cannot connect to server: %s", _server_status_error_message(status)) return 1 logger.info("Server connection successful") @@ -2113,8 +2791,7 @@ def main(): # Detect changes (treat all files as changes for initial upload) if args.force: - # Force mode: treat all files as created - changes = {"created": all_files, "updated": [], "deleted": [], "moved": [], "unchanged": []} + changes = client.build_force_changes(all_files) else: changes = client.detect_file_changes(all_files) @@ -2129,8 +2806,19 @@ def main(): success = client.process_changes_and_upload(changes) if success: - logger.info("Repository upload completed successfully!") - logger.info(f"Collection name: {config['collection_name']}") + outcome = str((client.last_upload_result or {}).get("outcome") or "") + if outcome == "skipped_by_plan": + logger.info("No upload needed after plan") + elif outcome == "queued": + logger.info("Repository upload request accepted; server processing asynchronously") + elif outcome == 
"uploaded_async": + logger.info( + "Repository upload processed asynchronously: %s", + (client.last_upload_result or {}).get("processed_operations") or {}, + ) + else: + logger.info("Repository upload completed successfully!") + logger.info(f"Collection name: {config['collection_name'] or ''}") logger.info(f"Files uploaded: {len(all_files)}") else: logger.error("Repository upload failed!") diff --git a/scripts/upload_delta_bundle.py b/scripts/upload_delta_bundle.py index 973be132..0ccd2906 100644 --- a/scripts/upload_delta_bundle.py +++ b/scripts/upload_delta_bundle.py @@ -1,24 +1,21 @@ import os import json +import shutil import tarfile import hashlib import re import logging from pathlib import Path -from typing import Dict, Any, Optional +from typing import Any, Dict, Optional - -try: - from scripts.workspace_state import ( - _extract_repo_name_from_path, - get_staging_targets, - get_collection_state_snapshot, - is_staging_enabled, - ) -except ImportError as exc: - raise ImportError( - "upload_delta_bundle requires scripts.workspace_state; ensure the module is available" - ) from exc +from scripts.workspace_state import ( + _normalize_cache_key_path, + _extract_repo_name_from_path, + get_staging_targets, + get_collection_state_snapshot, + is_staging_enabled, + upsert_index_journal_entries, +) logger = logging.getLogger(__name__) @@ -27,6 +24,102 @@ _SLUGGED_REPO_RE = re.compile(r"^.+-[0-9a-f]{16}(?:_old)?$") +def _normalize_hash_value(value: Any) -> str: + raw = str(value or "").strip() + if not raw: + return "" + if ":" in raw: + _, _, digest = raw.partition(":") + if digest.strip(): + return digest.strip().lower() + return raw.lower() + + +def _build_upsert_journal_entry(path: Path | str, content_hash: Optional[str]) -> Dict[str, Any]: + entry: Dict[str, Any] = { + "path": str(path), + "op_type": "upsert", + } + if content_hash: + entry["content_hash"] = content_hash + return entry + + +def _build_delete_journal_entry(path: Path | str, content_hash: 
Optional[str] = None) -> Dict[str, Any]: + entry: Dict[str, Any] = { + "path": str(path), + "op_type": "delete", + } + if content_hash: + entry["content_hash"] = content_hash + return entry + + +def _load_cache_hashes(cache_path: Path) -> Dict[str, str]: + try: + with cache_path.open("r", encoding="utf-8-sig") as f: + data = json.load(f) + except (OSError, ValueError, json.JSONDecodeError): + return {} + + file_hashes = data.get("file_hashes", {}) + if not isinstance(file_hashes, dict): + return {} + + normalized: Dict[str, str] = {} + for path_key, value in file_hashes.items(): + if isinstance(value, dict): + hash_value = value.get("hash") + else: + hash_value = value + digest = _normalize_hash_value(hash_value) + if digest: + normalized[_normalize_cache_key_path(str(path_key))] = digest + return normalized + + +def _load_replica_cache_hashes(workspace_root: Path, slug: str) -> Dict[str, str]: + merged: Dict[str, str] = {} + cache_paths = ( + Path(WORK_DIR) / ".codebase" / "repos" / slug / "cache.json", + workspace_root / ".codebase" / "cache.json", + ) + for cache_path in cache_paths: + if not cache_path.exists(): + continue + merged.update(_load_cache_hashes(cache_path)) + return merged + + +def _flush_replica_cache_hashes(workspace_root: Path, slug: str, hashes: Dict[str, str]) -> None: + """Flush replica hashes to workspace cache.json.""" + try: + cache_path = workspace_root / ".codebase" / "cache.json" + cache_path.parent.mkdir(parents=True, exist_ok=True) + + # Read existing cache to preserve other entries + existing_data = {} + if cache_path.exists(): + try: + with cache_path.open("r", encoding="utf-8-sig") as f: + existing_data = json.load(f) + except (OSError, ValueError, json.JSONDecodeError): + existing_data = {} + + # Update file_hashes section + if not isinstance(existing_data, dict): + existing_data = {} + existing_data["file_hashes"] = hashes + + # Write back atomically + temp_path = cache_path.with_suffix(".tmp") + with temp_path.open("w", 
encoding="utf-8") as f: + json.dump(existing_data, f, indent=2) + temp_path.replace(cache_path) + except Exception as e: + logger.debug(f"[upload_service] Failed to flush cache for {slug}: {e}") + + def get_workspace_key(workspace_path: str) -> str: """Generate 16-char hash for collision avoidance in remote uploads. @@ -61,139 +154,126 @@ def _cleanup_empty_dirs(path: Path, stop_at: Path) -> None: break -def process_delta_bundle(workspace_path: str, bundle_path: Path, manifest: Dict[str, Any]) -> Dict[str, int]: - """Process delta bundle and return operation counts.""" - operations_count = { - "created": 0, - "updated": 0, - "deleted": 0, - "moved": 0, - "skipped": 0, - "failed": 0, - } +def _resolve_replica_roots(workspace_path: str, *, create_missing: bool = True) -> Dict[str, Path]: + workspace_leaf = Path(workspace_path).name + repo_name_for_state: Optional[str] = None + serving_slug: Optional[str] = None + active_slug: Optional[str] = None try: - # CRITICAL: Always materialize writes under WORK_DIR using a slugged repo directory. - # Do NOT write directly into the client-supplied workspace_path, since that may be a host - # path (e.g. /home/user/repo) that is not mounted/visible to the watcher/indexer. 
- workspace_leaf = Path(workspace_path).name - - repo_name_for_state: Optional[str] = None + repo_name_for_state = _extract_repo_name_from_path(workspace_path) + if repo_name_for_state: + snapshot = get_collection_state_snapshot( + workspace_path=None, + repo_name=repo_name_for_state, + ) # type: ignore[arg-type] + serving_slug = snapshot.get("serving_repo_slug") + active_slug = snapshot.get("active_repo_slug") + except Exception: + serving_slug = None + active_slug = None + + slug_order: list[str] = [] + serving_candidate: Optional[str] = None + if serving_slug and _SLUGGED_REPO_RE.match(serving_slug): + serving_candidate = serving_slug + if active_slug and _SLUGGED_REPO_RE.match(active_slug) and active_slug not in slug_order: + slug_order.append(active_slug) + + staging_active = False + staging_gate = bool(is_staging_enabled()) + try: + if serving_slug and str(serving_slug).endswith("_old"): + staging_active = True + except Exception: + staging_active = False - serving_slug: Optional[str] = None - active_slug: Optional[str] = None - if _extract_repo_name_from_path and get_collection_state_snapshot: - try: - repo_name_for_state = _extract_repo_name_from_path(workspace_path) - if repo_name_for_state: - snapshot = get_collection_state_snapshot(workspace_path=None, repo_name=repo_name_for_state) # type: ignore[arg-type] - serving_slug = snapshot.get("serving_repo_slug") - active_slug = snapshot.get("active_repo_slug") - except Exception: - serving_slug = None - active_slug = None - - slug_order: list[str] = [] - serving_candidate: Optional[str] = None - if serving_slug and _SLUGGED_REPO_RE.match(serving_slug): - serving_candidate = serving_slug - if active_slug and _SLUGGED_REPO_RE.match(active_slug) and active_slug not in slug_order: - slug_order.append(active_slug) - - # If staging is active, we must mirror uploads into BOTH the canonical slug and - # the "*_old" slug. Relying purely on snapshot detection is brittle (e.g. 
when - # the client workspace_path is a host path). When we can infer a canonical slug, - # force both targets. + if not staging_gate: staging_active = False - staging_gate = bool(is_staging_enabled() if callable(is_staging_enabled) else False) - try: - if serving_slug and str(serving_slug).endswith("_old"): - staging_active = True - except Exception: - staging_active = False - if not staging_gate: - staging_active = False + def _append_slug(slug: Optional[str]) -> None: + if slug and _SLUGGED_REPO_RE.match(slug) and slug not in slug_order: + slug_order.append(slug) + + if repo_name_for_state and _SLUGGED_REPO_RE.match(repo_name_for_state): + canonical_slug = ( + repo_name_for_state[:-4] + if repo_name_for_state.endswith("_old") + else repo_name_for_state + ) + old_slug_candidate = ( + repo_name_for_state + if repo_name_for_state.endswith("_old") + else f"{canonical_slug}_old" + ) + if staging_active: + slug_order = [] + _append_slug(canonical_slug) + _append_slug(old_slug_candidate) + elif not slug_order: + _append_slug(canonical_slug) + old_slug_path = Path(WORK_DIR) / old_slug_candidate + if old_slug_path.exists(): + _append_slug(old_slug_candidate) - def _append_slug(slug: Optional[str]) -> None: - if slug and _SLUGGED_REPO_RE.match(slug) and slug not in slug_order: - slug_order.append(slug) + if not slug_order: + if _SLUGGED_REPO_RE.match(workspace_leaf): + slug_order.append(workspace_leaf) + else: + repo_name = _extract_repo_name_from_path(workspace_path) or workspace_leaf + workspace_key = get_workspace_key(workspace_path) + slug_order.append(f"{repo_name}-{workspace_key}") - if repo_name_for_state and _SLUGGED_REPO_RE.match(repo_name_for_state): - canonical_slug = repo_name_for_state[:-4] if repo_name_for_state.endswith("_old") else repo_name_for_state - old_slug_candidate = ( - repo_name_for_state if repo_name_for_state.endswith("_old") else f"{canonical_slug}_old" + if staging_gate and not staging_active: + try: + repo_name_for_staging = 
_extract_repo_name_from_path(workspace_path) or slug_order[0] + targets = get_staging_targets( + workspace_path=workspace_path, + repo_name=repo_name_for_staging, ) - if staging_active: - slug_order = [] - _append_slug(canonical_slug) - _append_slug(old_slug_candidate) - elif not slug_order: - _append_slug(canonical_slug) - old_slug_path = Path(WORK_DIR) / old_slug_candidate - if old_slug_path.exists(): - _append_slug(old_slug_candidate) - - if not slug_order: - if _SLUGGED_REPO_RE.match(workspace_leaf): - slug_order.append(workspace_leaf) - else: - if _extract_repo_name_from_path: - repo_name = _extract_repo_name_from_path(workspace_path) or workspace_leaf - else: - repo_name = workspace_leaf - workspace_key = get_workspace_key(workspace_path) - slug_order.append(f"{repo_name}-{workspace_key}") + if isinstance(targets, dict) and targets.get("staging"): + staging_active = True + except Exception as staging_err: + logger.debug("[upload_service] Failed to detect staging: %s", staging_err) - # Best-effort: if staging is active according to workspace_state, ensure we mirror to - # both the canonical slug and its *_old slug. 
- if staging_gate and (not staging_active) and get_staging_targets and _extract_repo_name_from_path: - try: - repo_name_for_staging = _extract_repo_name_from_path(workspace_path) or slug_order[0] - targets = get_staging_targets(workspace_path=workspace_path, repo_name=repo_name_for_staging) - if isinstance(targets, dict) and targets.get("staging"): - staging_active = True - except Exception as staging_err: - logger.debug(f"[upload_service] Failed to detect staging: {staging_err}") - - def _slug_exists(slug: str) -> bool: - try: - return ( - (Path(WORK_DIR) / slug).exists() - or (Path(WORK_DIR) / ".codebase" / "repos" / slug).exists() - ) - except Exception: - return False - - if staging_gate and (not staging_active) and slug_order: - primary = slug_order[0] - if _SLUGGED_REPO_RE.match(primary): - canonical = primary[:-4] if primary.endswith("_old") else primary - inferred_old = primary if primary.endswith("_old") else f"{canonical}_old" - if _slug_exists(inferred_old): - staging_active = True - - if staging_gate and staging_active and slug_order: - primary = slug_order[0] - if _SLUGGED_REPO_RE.match(primary): - canonical = primary[:-4] if primary.endswith("_old") else primary - old_slug = primary if primary.endswith("_old") else f"{canonical}_old" - desired = [canonical, old_slug] - slug_order = [s for s in desired if _SLUGGED_REPO_RE.match(s)] - elif staging_gate and not staging_active and serving_candidate: - # Ignore serving slugs when staging is disabled; keep deterministic non-staging writes. 
- if serving_candidate in slug_order: - slug_order = [s for s in slug_order if s != serving_candidate] - - if staging_gate: - try: - logger.info(f"[upload_service] Delta bundle targets (staging={staging_active}): {slug_order}") - except Exception: - pass + def _slug_exists(slug: str) -> bool: + try: + return ( + (Path(WORK_DIR) / slug).exists() + or (Path(WORK_DIR) / ".codebase" / "repos" / slug).exists() + ) + except Exception: + return False + + if staging_gate and (not staging_active) and slug_order: + primary = slug_order[0] + if _SLUGGED_REPO_RE.match(primary): + canonical = primary[:-4] if primary.endswith("_old") else primary + inferred_old = primary if primary.endswith("_old") else f"{canonical}_old" + if _slug_exists(inferred_old): + staging_active = True + + if staging_gate and staging_active and slug_order: + primary = slug_order[0] + if _SLUGGED_REPO_RE.match(primary): + canonical = primary[:-4] if primary.endswith("_old") else primary + old_slug = primary if primary.endswith("_old") else f"{canonical}_old" + desired = [canonical, old_slug] + slug_order = [s for s in desired if _SLUGGED_REPO_RE.match(s)] + elif staging_gate and not staging_active and serving_candidate: + if serving_candidate in slug_order: + slug_order = [s for s in slug_order if s != serving_candidate] + + if staging_gate: + try: + logger.info("[upload_service] Delta bundle targets (staging=%s): %s", staging_active, slug_order) + except Exception: + pass - replica_roots: Dict[str, Path] = {} - for slug in slug_order: - path = Path(WORK_DIR) / slug + replica_roots: Dict[str, Path] = {} + for slug in slug_order: + path = Path(WORK_DIR) / slug + if create_missing: path.mkdir(parents=True, exist_ok=True) try: marker_dir = Path(WORK_DIR) / ".codebase" / "repos" / slug @@ -201,35 +281,393 @@ def _slug_exists(slug: str) -> bool: (marker_dir / ".ctxce_managed_upload").write_text("1\n") except Exception: pass - replica_roots[slug] = path.resolve() + replica_roots[slug] = path.resolve() + return 
replica_roots + + +def _enqueue_replica_journal_entries( + *, + workspace_root: Path, + slug: str, + entries: list[Dict[str, Any]], +) -> None: + if not entries: + return + try: + upsert_index_journal_entries( + entries, + workspace_path=str(workspace_root), + repo_name=slug, + ) + except Exception as exc: + logger.debug( + "[upload_service] Failed to enqueue index journal entries for %s: %s", + workspace_root, + exc, + ) + + +def _safe_join(base: Path, rel: str) -> Path: + rp = Path(str(rel)) + if str(rp) in {".", ""}: + raise ValueError("Invalid operation path") + if rp.is_absolute(): + raise ValueError(f"Absolute paths are not allowed: {rel}") + base_resolved = base.resolve() + candidate = (base_resolved / rp).resolve() + try: + ok = candidate.is_relative_to(base_resolved) + except Exception: + ok = os.path.commonpath([str(base_resolved), str(candidate)]) == str(base_resolved) + if not ok: + raise ValueError(f"Path escapes workspace: {rel}") + return candidate + + +def _sanitize_operation_path(rel_path: str, replica_roots: Dict[str, Path]) -> Optional[str]: + sanitized_path = rel_path + skipped_due_to_exact_slug = False + for slug in replica_roots.keys(): + if sanitized_path == slug: + skipped_due_to_exact_slug = True + break + prefix = f"{slug}/" + if sanitized_path.startswith(prefix): + sanitized_path = sanitized_path[len(prefix):] + break + if skipped_due_to_exact_slug or not sanitized_path: + return None + return sanitized_path + + +def plan_delta_upload( + workspace_path: str, + operations: list[Dict[str, Any]], + file_hashes: Optional[Dict[str, str]] = None, +) -> Dict[str, Any]: + needed_files = { + "created": [], + "updated": [], + "moved": [], + } + operations_count = { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + } + needed_size_bytes = 0 + replica_roots = _resolve_replica_roots(workspace_path, create_missing=False) + replica_cache_hashes = { + slug: 
_load_replica_cache_hashes(root, slug) + for slug, root in replica_roots.items() + } + normalized_hashes = { + str(rel_path): _normalize_hash_value(hash_value) + for rel_path, hash_value in (file_hashes or {}).items() + if _normalize_hash_value(hash_value) + } - primary_slug = slug_order[0] + for operation in operations: + op_type = str(operation.get("operation") or "") + rel_path = operation.get("path") + if not rel_path: + operations_count["skipped"] += 1 + continue + + sanitized = _sanitize_operation_path(str(rel_path), replica_roots) + if not sanitized: + operations_count["skipped"] += 1 + continue + + if op_type == "deleted": + operations_count["deleted"] += 1 + continue + if op_type == "moved": + operations_count["moved"] += 1 + source_rel_path = operation.get("source_path") or operation.get("source_relative_path") + if not source_rel_path: + needed_files["moved"].append(sanitized) + needed_size_bytes += int(operation.get("size_bytes") or 0) + continue + + move_needs_content = False + for _slug, root in replica_roots.items(): + try: + safe_source_path = _safe_join(root, str(source_rel_path)) + except ValueError: + logger.warning( + "[upload_service] Invalid move source path during plan: %s (root=%s)", + source_rel_path, + root, + ) + move_needs_content = True + break + if not safe_source_path.exists(): + move_needs_content = True + break + if move_needs_content: + needed_files["moved"].append(sanitized) + needed_size_bytes += int(operation.get("size_bytes") or 0) + continue + if op_type not in {"created", "updated"}: + operations_count["failed"] += 1 + continue + + op_content_hash = _normalize_hash_value( + operation.get("content_hash") or normalized_hashes.get(sanitized) + ) + if not op_content_hash: + needed_files[op_type].append(sanitized) + operations_count[op_type] += 1 + needed_size_bytes += int(operation.get("size_bytes") or 0) + continue + + needs_content = False + for slug, root in replica_roots.items(): + try: + target_path = _safe_join(root, 
sanitized) + except ValueError: + logger.warning( + "[upload_service] Invalid %s path during plan: %s (root=%s)", + op_type, + sanitized, + root, + ) + continue + target_key = _normalize_cache_key_path(str(target_path)) + cached_hash = replica_cache_hashes.get(slug, {}).get(target_key) + if cached_hash != op_content_hash: + needs_content = True + break + + if needs_content: + needed_files[op_type].append(sanitized) + operations_count[op_type] += 1 + needed_size_bytes += int(operation.get("size_bytes") or 0) + else: + operations_count["skipped"] += 1 + operations_count["skipped_hash_match"] += 1 + + return { + "needed_files": needed_files, + "operation_counts_preview": operations_count, + "needed_size_bytes": needed_size_bytes, + "replica_targets": list(replica_roots.keys()), + } + + +def apply_delta_operations( + workspace_path: str, + operations: list[Dict[str, Any]], + file_hashes: Optional[Dict[str, str]] = None, +) -> Dict[str, int]: + """Apply metadata-only delta operations without requiring a tar bundle.""" + operations_count = { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + } + + try: + replica_roots = _resolve_replica_roots(workspace_path) + if not replica_roots: + raise ValueError(f"No replica roots available for workspace: {workspace_path}") + replica_cache_hashes = { + slug: _load_replica_cache_hashes(root, slug) + for slug, root in replica_roots.items() + } + journal_entries_by_slug: Dict[str, list[Dict[str, Any]]] = { + slug: [] for slug in replica_roots.keys() + } + normalized_hashes = { + str(rel_path): _normalize_hash_value(hash_value) + for rel_path, hash_value in (file_hashes or {}).items() + if _normalize_hash_value(hash_value) + } + + for operation in operations: + op_type = str(operation.get("operation") or "") + rel_path = operation.get("path") + + if not rel_path: + operations_count["skipped"] += 1 + continue + + sanitized_path = 
_sanitize_operation_path(str(rel_path), replica_roots) + if not sanitized_path: + operations_count["skipped"] += 1 + continue + + rel_path = sanitized_path + + if op_type not in {"deleted", "moved"}: + operations_count["failed"] += 1 + continue + + source_rel_path = None + if op_type == "moved": + raw_source = operation.get("source_path") or operation.get("source_relative_path") + if not raw_source: + operations_count["failed"] += 1 + continue + source_rel_path = _sanitize_operation_path(str(raw_source), replica_roots) + if not source_rel_path: + operations_count["failed"] += 1 + continue + + replica_results: Dict[str, str] = {} + for slug, root in replica_roots.items(): + target_path = _safe_join(root, rel_path) + target_key = _normalize_cache_key_path(str(target_path)) + replica_hashes = replica_cache_hashes.setdefault(slug, {}) + op_content_hash = _normalize_hash_value( + operation.get("content_hash") or normalized_hashes.get(rel_path) + ) + + try: + if op_type == "deleted": + if target_path.exists(): + target_path.unlink(missing_ok=True) + _cleanup_empty_dirs(target_path.parent, root) + replica_hashes.pop(target_key, None) + journal_entries_by_slug.setdefault(slug, []).append( + _build_delete_journal_entry(target_path) + ) + replica_results[slug] = "applied" + continue + + safe_source_path = _safe_join(root, source_rel_path or "") + if not safe_source_path.exists(): + replica_results[slug] = "failed" + continue + + target_path.parent.mkdir(parents=True, exist_ok=True) + if target_path.exists(): + if target_path.is_dir(): + raise IsADirectoryError( + f"[upload_delta_bundle] move target is a directory: {target_path}" + ) + else: + target_path.unlink() + shutil.move(str(safe_source_path), str(target_path)) + _cleanup_empty_dirs(safe_source_path.parent, root) + source_key = _normalize_cache_key_path(str(safe_source_path)) + moved_hash = replica_hashes.pop(source_key, None) + if op_content_hash: + replica_hashes[target_key] = op_content_hash + elif moved_hash: + 
replica_hashes[target_key] = moved_hash + move_entry_hash = op_content_hash or moved_hash + journal_entries_by_slug.setdefault(slug, []).extend( + [ + _build_delete_journal_entry(safe_source_path, move_entry_hash), + _build_upsert_journal_entry(target_path, move_entry_hash), + ] + ) + replica_results[slug] = "applied" + except Exception as exc: + logger.debug( + "[upload_service] Failed to apply metadata-only %s to %s in %s: %s", + op_type, + rel_path, + root, + exc, + ) + replica_results[slug] = "failed" + + applied_any = any(result == "applied" for result in replica_results.values()) + success_all = all(result == "applied" for result in replica_results.values()) + if applied_any: + operations_count[op_type] += 1 + if not success_all: + logger.debug( + "[upload_service] Partial metadata-only success for %s %s: %s", + op_type, + rel_path, + replica_results, + ) + else: + operations_count["failed"] += 1 + + for slug, root in replica_roots.items(): + _enqueue_replica_journal_entries( + workspace_root=root, + slug=slug, + entries=journal_entries_by_slug.get(slug, []), + ) + # Flush updated replica hashes to disk (including empty caches) + replica_hashes = replica_cache_hashes.get(slug, {}) + _flush_replica_cache_hashes(root, slug, replica_hashes) + + return operations_count + except Exception as e: + logger.error(f"Error applying metadata-only delta operations: {e}") + raise + + +def process_delta_bundle(workspace_path: str, bundle_path: Path, manifest: Dict[str, Any]) -> Dict[str, int]: + """Process delta bundle and return operation counts.""" + operations_count = { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + } + + try: + replica_roots = _resolve_replica_roots(workspace_path) + if not replica_roots: + raise ValueError(f"No replica roots available for workspace: {workspace_path}") + primary_slug = next(iter(replica_roots)) workspace_root = replica_roots[primary_slug] - def _safe_join(base: 
Path, rel: str) -> Path: - # SECURITY: Prevent path traversal / absolute-path writes by ensuring the resolved - # candidate path stays within the intended workspace root. - rp = Path(str(rel)) - if str(rp) in {".", ""}: - raise ValueError("Invalid operation path") - if rp.is_absolute(): - raise ValueError(f"Absolute paths are not allowed: {rel}") - base_resolved = base.resolve() - candidate = (base_resolved / rp).resolve() - try: - ok = candidate.is_relative_to(base_resolved) - except Exception: - ok = os.path.commonpath([str(base_resolved), str(candidate)]) == str(base_resolved) - if not ok: - raise ValueError(f"Path escapes workspace: {rel}") - return candidate + def _member_suffix(name: str, marker: str) -> Optional[str]: + idx = name.find(marker) + if idx < 0: + return None + suffix = name[idx + len(marker):] + return suffix or None with tarfile.open(bundle_path, "r:gz") as tar: ops_member = None - for member in tar.getnames(): - if member.endswith("metadata/operations.json"): + hashes_member = None + git_member = None + created_members: Dict[str, tarfile.TarInfo] = {} + updated_members: Dict[str, tarfile.TarInfo] = {} + moved_members: Dict[str, tarfile.TarInfo] = {} + for member in tar.getmembers(): + name = member.name + if name.endswith("metadata/operations.json"): ops_member = member - break + continue + if name.endswith("metadata/hashes.json"): + hashes_member = member + continue + if name.endswith("metadata/git_history.json"): + git_member = member + continue + created_rel = _member_suffix(name, "files/created/") + if created_rel: + created_members[created_rel] = member + continue + updated_rel = _member_suffix(name, "files/updated/") + if updated_rel: + updated_members[updated_rel] = member + continue + moved_rel = _member_suffix(name, "files/moved/") + if moved_rel: + moved_members[moved_rel] = member if not ops_member: raise ValueError("operations.json not found in bundle") @@ -240,14 +678,28 @@ def _safe_join(base: Path, rel: str) -> Path: 
operations_data = json.loads(ops_file.read().decode("utf-8")) operations = operations_data.get("operations", []) + bundle_hashes: Dict[str, str] = {} + if hashes_member: + hashes_file = tar.extractfile(hashes_member) + if hashes_file: + hashes_data = json.loads(hashes_file.read().decode("utf-8")) + raw_hashes = hashes_data.get("file_hashes", {}) + if isinstance(raw_hashes, dict): + for rel_path, hash_value in raw_hashes.items(): + digest = _normalize_hash_value(hash_value) + if digest: + bundle_hashes[str(rel_path)] = digest + + replica_cache_hashes = { + slug: _load_replica_cache_hashes(root, slug) + for slug, root in replica_roots.items() + } + journal_entries_by_slug: Dict[str, list[Dict[str, Any]]] = { + slug: [] for slug in replica_roots.keys() + } # Best-effort: extract git history metadata for watcher to ingest try: - git_member = None - for member in tar.getnames(): - if member.endswith("metadata/git_history.json"): - git_member = member - break if git_member: git_file = tar.extractfile(git_member) if git_file: @@ -266,11 +718,20 @@ def _safe_join(base: Path, rel: str) -> Path: except Exception as git_err: logger.debug(f"[upload_service] Error extracting git history metadata: {git_err}") - def _apply_operation_to_workspace(workspace_root: Path) -> bool: - """Apply a single file operation to a workspace. 
Returns True on success.""" - nonlocal operations_count, op_type, rel_path, tar - + def _apply_operation_to_workspace( + slug: str, + workspace_root: Path, + op_type: str, + rel_path: str, + operation: Dict[str, Any], + ) -> str: + """Apply a single file operation to a workspace.""" target_path = _safe_join(workspace_root, rel_path) + target_key = _normalize_cache_key_path(str(target_path)) + replica_hashes = replica_cache_hashes.setdefault(slug, {}) + op_content_hash = _normalize_hash_value( + operation.get("content_hash") or bundle_hashes.get(rel_path) + ) safe_source_path = None source_rel_path = None @@ -281,76 +742,113 @@ def _apply_operation_to_workspace(workspace_root: Path) -> bool: try: if op_type == "created": - file_member = None - for member in tar.getnames(): - if member.endswith(f"files/created/{rel_path}"): - file_member = member - break - + if op_content_hash and target_path.exists(): + cached_hash = replica_hashes.get(target_key) + if cached_hash and cached_hash == op_content_hash: + return "skipped_hash_match" + file_member = created_members.get(rel_path) if file_member: file_content = tar.extractfile(file_member) if file_content: target_path.parent.mkdir(parents=True, exist_ok=True) target_path.write_bytes(file_content.read()) - return True + if op_content_hash: + replica_hashes[target_key] = op_content_hash + journal_entries_by_slug.setdefault(slug, []).append( + _build_upsert_journal_entry(target_path, op_content_hash) + ) + return "applied" else: - return False + return "failed" else: - return False + return "failed" elif op_type == "updated": - file_member = None - for member in tar.getnames(): - if member.endswith(f"files/updated/{rel_path}"): - file_member = member - break - + if op_content_hash and target_path.exists(): + cached_hash = replica_hashes.get(target_key) + if cached_hash and cached_hash == op_content_hash: + return "skipped_hash_match" + file_member = updated_members.get(rel_path) if file_member: file_content = 
tar.extractfile(file_member) if file_content: target_path.parent.mkdir(parents=True, exist_ok=True) target_path.write_bytes(file_content.read()) - return True + if op_content_hash: + replica_hashes[target_key] = op_content_hash + journal_entries_by_slug.setdefault(slug, []).append( + _build_upsert_journal_entry(target_path, op_content_hash) + ) + return "applied" else: - return False + return "failed" else: - return False + return "failed" elif op_type == "deleted": if target_path.exists(): target_path.unlink(missing_ok=True) - return True - else: - return True # Already deleted + _cleanup_empty_dirs(target_path.parent, workspace_root) + replica_hashes.pop(target_key, None) + journal_entries_by_slug.setdefault(slug, []).append( + _build_delete_journal_entry(target_path) + ) + return "applied" elif op_type == "moved": if safe_source_path and safe_source_path.exists(): target_path.parent.mkdir(parents=True, exist_ok=True) - safe_source_path.rename(target_path) - return True + if target_path.exists(): + if target_path.is_dir(): + raise IsADirectoryError( + f"[upload_service] move target is a directory: {target_path}" + ) + else: + target_path.unlink() + shutil.move(str(safe_source_path), str(target_path)) + _cleanup_empty_dirs(safe_source_path.parent, workspace_root) + source_key = _normalize_cache_key_path(str(safe_source_path)) + moved_hash = replica_hashes.pop(source_key, None) + if op_content_hash: + replica_hashes[target_key] = op_content_hash + elif moved_hash: + replica_hashes[target_key] = moved_hash + move_entry_hash = op_content_hash or moved_hash + journal_entries_by_slug.setdefault(slug, []).extend( + [ + _build_delete_journal_entry(safe_source_path, move_entry_hash), + _build_upsert_journal_entry(target_path, move_entry_hash), + ] + ) + return "applied" # Remote uploads may not have the source file on the server (e.g. staging # mirrors). In that case, clients can embed the destination content under # files/moved/. 
- file_member = None - for member in tar.getnames(): - if member.endswith(f"files/moved/{rel_path}"): - file_member = member - break + file_member = moved_members.get(rel_path) if file_member: file_content = tar.extractfile(file_member) if file_content: target_path.parent.mkdir(parents=True, exist_ok=True) target_path.write_bytes(file_content.read()) - return True - return False - return False + if op_content_hash: + replica_hashes[target_key] = op_content_hash + if safe_source_path: + journal_entries_by_slug.setdefault(slug, []).append( + _build_delete_journal_entry(safe_source_path, op_content_hash) + ) + journal_entries_by_slug.setdefault(slug, []).append( + _build_upsert_journal_entry(target_path, op_content_hash) + ) + return "applied" + return "failed" + return "failed" else: logger.warning(f"[upload_service] Unknown operation type: {op_type}") - return False + return "failed" except Exception as e: logger.debug(f"[upload_service] Failed to apply {op_type} to {rel_path} in {workspace_root}: {e}") - return False + return "failed" for operation in operations: op_type = operation.get("operation") @@ -360,18 +858,8 @@ def _apply_operation_to_workspace(workspace_root: Path) -> bool: operations_count["skipped"] += 1 continue - sanitized_path = rel_path - skipped_due_to_exact_slug = False - for slug in replica_roots.keys(): - if sanitized_path == slug: - skipped_due_to_exact_slug = True - break - prefix = f"{slug}/" - if sanitized_path.startswith(prefix): - sanitized_path = sanitized_path[len(prefix):] - break - - if skipped_due_to_exact_slug or not sanitized_path: + sanitized_path = _sanitize_operation_path(str(rel_path), replica_roots) + if not sanitized_path: logger.debug( f"[upload_service] Skipping operation {op_type} for path {rel_path}: " "appears to reference slug root directly.", @@ -381,22 +869,44 @@ def _apply_operation_to_workspace(workspace_root: Path) -> bool: rel_path = sanitized_path - replica_results: Dict[str, bool] = {} + replica_results: 
Dict[str, str] = {} for slug, root in replica_roots.items(): - replica_results[slug] = _apply_operation_to_workspace(root) + replica_results[slug] = _apply_operation_to_workspace( + slug, + root, + op_type, + rel_path, + operation, + ) - success_any = any(replica_results.values()) - success_all = all(replica_results.values()) - if success_any: + applied_any = any(result == "applied" for result in replica_results.values()) + skipped_hash_match = bool(replica_results) and all( + result == "skipped_hash_match" for result in replica_results.values() + ) + success_all = all(result in {"applied", "skipped_hash_match"} for result in replica_results.values()) + if applied_any: operations_count.setdefault(op_type, 0) - operations_count[op_type] = operations_count.get(op_type, 0) + 1 + operations_count[op_type] += 1 if not success_all: logger.debug( f"[upload_service] Partial success for {op_type} {rel_path}: {replica_results}" ) + elif skipped_hash_match: + operations_count["skipped"] += 1 + operations_count["skipped_hash_match"] += 1 else: operations_count["failed"] += 1 + for slug, root in replica_roots.items(): + _enqueue_replica_journal_entries( + workspace_root=root, + slug=slug, + entries=journal_entries_by_slug.get(slug, []), + ) + # Flush updated replica hashes to disk (including empty caches) + replica_hashes = replica_cache_hashes.get(slug, {}) + _flush_replica_cache_hashes(root, slug, replica_hashes) + return operations_count except Exception as e: diff --git a/scripts/upload_service.py b/scripts/upload_service.py index 6771d652..e01a75fc 100644 --- a/scripts/upload_service.py +++ b/scripts/upload_service.py @@ -45,21 +45,13 @@ from fastapi.responses import JSONResponse, RedirectResponse from fastapi.middleware.cors import CORSMiddleware -from scripts.upload_delta_bundle import get_workspace_key, process_delta_bundle - -from scripts.indexing_admin import ( - build_admin_collections_view, - resolve_collection_root, - spawn_ingest_code, - 
recreate_collection_qdrant, +from scripts.upload_delta_bundle import ( + apply_delta_operations, + get_workspace_key, + plan_delta_upload, + process_delta_bundle, ) -try: - from scripts.workspace_state import is_staging_enabled -except Exception: - is_staging_enabled = None # type: ignore - - from pydantic import BaseModel, Field from scripts.auth_backend import ( AuthDisabledError, @@ -81,74 +73,31 @@ revoke_collection_access, ) -try: - from scripts.collection_admin import delete_collection_everywhere, copy_collection_qdrant -except Exception: - delete_collection_everywhere = None - copy_collection_qdrant = None -try: - from scripts.admin_ui import ( - render_admin_acl, - render_admin_bootstrap, - render_admin_error, - render_admin_login, - ) -except Exception: - - def _admin_ui_unavailable(*args, **kwargs): - raise HTTPException(status_code=500, detail="Admin UI unavailable") - - render_admin_acl = _admin_ui_unavailable - render_admin_bootstrap = _admin_ui_unavailable - render_admin_error = _admin_ui_unavailable - render_admin_login = _admin_ui_unavailable - -# Import staging/indexing admin helpers -try: - from scripts.indexing_admin import ( - start_staging_rebuild, - activate_staging_rebuild, - abort_staging_rebuild, - ) -except ImportError: - start_staging_rebuild = None # type: ignore - activate_staging_rebuild = None # type: ignore - abort_staging_rebuild = None # type: ignore +from scripts.admin_ui import ( + render_admin_acl, + render_admin_bootstrap, + render_admin_error, + render_admin_login, +) -# Import existing workspace state and indexing functions -try: - from scripts.workspace_state import ( - log_activity, - get_collection_name, - get_cached_file_hash, - set_cached_file_hash, - _extract_repo_name_from_path, - update_repo_origin, - get_collection_mappings, - find_collection_for_logical_repo, - update_workspace_state, - set_staging_state, - update_staging_status, - clear_staging_collection, - logical_repo_reuse_enabled, - 
get_collection_state_snapshot, - ) -except ImportError: - # Fallback for testing without full environment - log_activity = None - get_collection_name = None - get_cached_file_hash = None - set_cached_file_hash = None - _extract_repo_name_from_path = None - update_repo_origin = None - get_collection_mappings = None - find_collection_for_logical_repo = None - update_workspace_state = None - set_staging_state = None - update_staging_status = None - clear_staging_collection = None - def logical_repo_reuse_enabled() -> bool: # type: ignore[no-redef] - return False +from scripts.workspace_state import ( + is_staging_enabled, + log_activity, + get_collection_name, + get_cached_file_hash, + set_cached_file_hash, + _extract_repo_name_from_path, + update_repo_origin, + get_collection_mappings, + find_collection_for_logical_repo, + update_workspace_state, + set_staging_state, + update_staging_status, + clear_staging_collection, + clear_index_journal_entries, + logical_repo_reuse_enabled, + get_collection_state_snapshot, +) # Configure logging @@ -182,6 +131,15 @@ def logical_repo_reuse_enabled() -> bool: # type: ignore[no-redef] ) BRIDGE_STATE_TOKEN = (os.environ.get("CTXCE_BRIDGE_STATE_TOKEN") or "").strip() + +# Admin collection operations import Qdrant clients, subprocess indexing helpers, +# and collection-copy/delete wiring. Keep those off the upload/status import path. 
+def _indexing_admin(): + from scripts import indexing_admin + + return indexing_admin + + # FastAPI app app = FastAPI( title="Context-Engine Delta Upload Service", @@ -200,6 +158,7 @@ def logical_repo_reuse_enabled() -> bool: # type: ignore[no-redef] # In-memory sequence tracking (in production, use persistent storage) _sequence_tracker: Dict[str, int] = {} +_upload_result_tracker: Dict[str, Dict[str, Any]] = {} def _int_env(name: str, default: int) -> int: @@ -224,6 +183,39 @@ class UploadResponse(BaseModel): next_sequence: Optional[int] = None error: Optional[Dict[str, Any]] = None + +class PlanRequest(BaseModel): + workspace_path: str + collection_name: Optional[str] = None + source_path: Optional[str] = None + logical_repo_id: Optional[str] = None + session: Optional[str] = None + manifest: Dict[str, Any] = Field(default_factory=dict) + operations: List[Dict[str, Any]] = Field(default_factory=list) + file_hashes: Dict[str, str] = Field(default_factory=dict) + + +class PlanResponse(BaseModel): + success: bool + workspace_path: str + needed_files: Dict[str, List[str]] + operation_counts_preview: Dict[str, int] + needed_size_bytes: int + replica_targets: List[str] + fallback_used: bool = False + error: Optional[Dict[str, Any]] = None + + +class ApplyOperationsRequest(BaseModel): + workspace_path: str + collection_name: Optional[str] = None + source_path: Optional[str] = None + logical_repo_id: Optional[str] = None + session: Optional[str] = None + manifest: Dict[str, Any] = Field(default_factory=dict) + operations: List[Dict[str, Any]] = Field(default_factory=list) + file_hashes: Dict[str, str] = Field(default_factory=dict) + class StatusResponse(BaseModel): workspace_path: str collection_name: str @@ -407,9 +399,8 @@ def _resolve_bridge_state_target( repo = (repo_name or "").strip() or None if collection: - if resolve_collection_root is None: - raise HTTPException(status_code=400, detail="collection mapping unavailable") - root, resolved_repo = 
resolve_collection_root(collection=collection, work_dir=WORK_DIR) + indexing_admin = _indexing_admin() + root, resolved_repo = indexing_admin.resolve_collection_root(collection=collection, work_dir=WORK_DIR) if not root: raise HTTPException(status_code=404, detail="collection mapping not found") workspace_path = root @@ -480,34 +471,83 @@ async def _process_bundle_background( sequence_number: Optional[int], bundle_id: Optional[str], ) -> None: + key = get_workspace_key(workspace_path) try: start_time = datetime.now() + _upload_result_tracker[key] = { + "workspace_path": workspace_path, + "bundle_id": bundle_id, + "sequence_number": sequence_number, + "processed_operations": None, + "processing_time_ms": None, + "status": "processing", + "completed_at": None, + } operations_count = await asyncio.to_thread( process_delta_bundle, workspace_path, bundle_path, manifest ) + processing_time = int((datetime.now() - start_time).total_seconds() * 1000) + failed_count = int((operations_count or {}).get("failed") or 0) + applied_count = int( + (operations_count or {}).get("created", 0) + + (operations_count or {}).get("updated", 0) + + (operations_count or {}).get("deleted", 0) + + (operations_count or {}).get("moved", 0) + ) + status_value = "completed" if failed_count == 0 else "failed" if sequence_number is not None: - key = get_workspace_key(workspace_path) _sequence_tracker[key] = sequence_number - if log_activity: - try: - repo = _extract_repo_name_from_path(workspace_path) if _extract_repo_name_from_path else None - log_activity( - repo_name=repo, - action="uploaded", - file_path=bundle_id, - details={ - "bundle_id": bundle_id, - "operations": operations_count, - "source": "delta_upload", - }, - ) - except Exception as activity_err: - logger.debug(f"[upload_service] Failed to log activity for bundle {bundle_id}: {activity_err}") - processing_time = (datetime.now() - start_time).total_seconds() * 1000 - logger.info( - f"[upload_service] Finished processing bundle 
{bundle_id} seq {sequence_number} in {int(processing_time)}ms" - ) + _upload_result_tracker[key] = { + "workspace_path": workspace_path, + "bundle_id": bundle_id, + "sequence_number": sequence_number, + "processed_operations": operations_count, + "processing_time_ms": processing_time, + "status": status_value, + "failed_count": failed_count, + "partial": bool(failed_count > 0 and applied_count > 0), + "completed_at": datetime.now().isoformat(), + } + try: + repo = _extract_repo_name_from_path(workspace_path) + log_activity( + repo_name=repo, + action="uploaded", + file_path=bundle_id, + details={ + "bundle_id": bundle_id, + "operations": operations_count, + "source": "delta_upload", + }, + ) + except Exception as activity_err: + logger.debug(f"[upload_service] Failed to log activity for bundle {bundle_id}: {activity_err}") + if failed_count > 0: + logger.warning( + "[upload_service] Finished processing bundle %s seq %s with failures in %sms " + "failed=%d ops=%s", + bundle_id, + sequence_number, + processing_time, + failed_count, + operations_count, + ) + else: + logger.info( + f"[upload_service] Finished processing bundle {bundle_id} seq {sequence_number} " + f"in {processing_time}ms ops={operations_count}" + ) except Exception as e: + _upload_result_tracker[key] = { + "workspace_path": workspace_path, + "bundle_id": bundle_id, + "sequence_number": sequence_number, + "processed_operations": None, + "processing_time_ms": None, + "status": "error", + "completed_at": datetime.now().isoformat(), + "error": str(e), + } logger.error(f"[upload_service] Error in background processing for bundle {bundle_id}: {e}") finally: try: @@ -707,8 +747,9 @@ async def admin_acl_page(request: Request): logger.error(f"[upload_service] Failed to load admin UI data: {e}") raise HTTPException(status_code=500, detail="Failed to load admin data") + indexing_admin = _indexing_admin() enriched = await asyncio.to_thread( - build_admin_collections_view, collections=collections, 
work_dir=WORK_DIR + indexing_admin.build_admin_collections_view, collections=collections, work_dir=WORK_DIR ) resp = render_admin_acl( @@ -753,8 +794,9 @@ async def admin_collections_status(request: Request): except Exception: raise HTTPException(status_code=500, detail="Failed to load collections") + indexing_admin = _indexing_admin() enriched = await asyncio.to_thread( - lambda: build_admin_collections_view(collections=collections, work_dir=WORK_DIR) + lambda: indexing_admin.build_admin_collections_view(collections=collections, work_dir=WORK_DIR) ) return JSONResponse({"collections": enriched}) @@ -774,7 +816,8 @@ async def admin_reindex_collection( back_href="/admin/acl", ) - root, repo_name = resolve_collection_root(collection=name, work_dir=WORK_DIR) + indexing_admin = _indexing_admin() + root, repo_name = indexing_admin.resolve_collection_root(collection=name, work_dir=WORK_DIR) if not root: return render_admin_error( request, @@ -784,7 +827,7 @@ async def admin_reindex_collection( ) try: - spawn_ingest_code( + indexing_admin.spawn_ingest_code( root=root, work_dir=WORK_DIR, collection=name, @@ -810,9 +853,6 @@ async def bridge_collection_state( workspace: Optional[str] = None, repo_name: Optional[str] = None, ): - if get_collection_state_snapshot is None: - raise HTTPException(status_code=503, detail="workspace_state helper unavailable") - _bridge_state_authorized(request) workspace_path, repo = _resolve_bridge_state_target(collection=collection, workspace=workspace, repo_name=repo_name) @@ -821,7 +861,7 @@ async def bridge_collection_state( if not snapshot: raise HTTPException(status_code=404, detail="Workspace state not found") - if not (is_staging_enabled() if callable(is_staging_enabled) else False): + if not is_staging_enabled(): # Classic mode: ignore any serving_* overrides from staging/migration. 
snapshot = dict(snapshot) snapshot.pop("serving_collection", None) @@ -856,7 +896,8 @@ async def admin_recreate_collection( back_href="/admin/acl", ) - root, repo_name = resolve_collection_root(collection=name, work_dir=WORK_DIR) + indexing_admin = _indexing_admin() + root, repo_name = indexing_admin.resolve_collection_root(collection=name, work_dir=WORK_DIR) if not root: return render_admin_error( request, @@ -866,12 +907,12 @@ async def admin_recreate_collection( ) try: - recreate_collection_qdrant( + indexing_admin.recreate_collection_qdrant( qdrant_url=QDRANT_URL, api_key=os.environ.get("QDRANT_API_KEY") or None, collection=name, ) - spawn_ingest_code( + indexing_admin.spawn_ingest_code( root=root, work_dir=WORK_DIR, collection=name, @@ -890,6 +931,50 @@ async def admin_recreate_collection( return RedirectResponse(url="/admin/acl", status_code=302) +@app.post("/admin/collections/clear-journal") +async def admin_clear_collection_journal( + request: Request, + collection: str = Form(...), +): + _require_admin_session(request) + name = (collection or "").strip() + if not name: + return render_admin_error( + request, + title="Clear Journal Failed", + message="collection is required", + back_href="/admin/acl", + ) + + indexing_admin = _indexing_admin() + root, repo_name = indexing_admin.resolve_collection_root(collection=name, work_dir=WORK_DIR) + if not root: + return render_admin_error( + request, + title="Clear Journal Failed", + message="No workspace mapping found for collection", + back_href="/admin/acl", + ) + + try: + removed = clear_index_journal_entries(workspace_path=root, repo_name=repo_name) + except Exception as e: + return render_admin_error( + request, + title="Clear Journal Failed", + message=str(e), + back_href="/admin/acl", + ) + + try: + from urllib.parse import urlencode + + url = "/admin/acl?" 
+ urlencode({"journal_cleared": name, "journal_removed": str(removed)}) + except Exception: + url = "/admin/acl" + return RedirectResponse(url=url, status_code=302) + + @app.post("/admin/collections/delete") async def admin_delete_collection( request: Request, @@ -920,14 +1005,6 @@ async def admin_delete_collection( back_href="/admin/acl", ) - if delete_collection_everywhere is None: - return render_admin_error( - request, - title="Delete Collection Failed", - message="Collection delete helper unavailable", - back_href="/admin/acl", - ) - # Default is Qdrant-only (no filesystem cleanup). Users must explicitly opt in. try: cleanup_fs = (delete_fs or "").strip().lower() in {"1", "true", "yes", "on"} @@ -935,7 +1012,10 @@ async def admin_delete_collection( cleanup_fs = False try: - delete_collection_everywhere( + # Collection deletion imports Qdrant admin helpers only on the admin route. + from scripts.collection_admin import delete_collection_everywhere + + out = delete_collection_everywhere( collection=name, work_dir=WORK_DIR, qdrant_url=QDRANT_URL, @@ -949,7 +1029,23 @@ async def admin_delete_collection( back_href="/admin/acl", ) - return RedirectResponse(url="/admin/acl", status_code=302) + graph_deleted: Optional[str] = None + try: + if isinstance(out, dict) and not name.endswith("_graph"): + graph_deleted = "1" if bool(out.get("qdrant_graph_deleted")) else "0" + except Exception: + graph_deleted = None + + try: + from urllib.parse import urlencode + + params = {"deleted": name} + if graph_deleted is not None: + params["graph_deleted"] = graph_deleted + url = "/admin/acl?" 
+ urlencode(params) + except Exception: + url = "/admin/acl" + return RedirectResponse(url=url, status_code=302) @app.post("/admin/staging/start") @@ -958,7 +1054,7 @@ async def admin_start_staging( collection: str = Form(...), ): _require_admin_session(request) - if not (is_staging_enabled() if callable(is_staging_enabled) else False): + if not is_staging_enabled(): return render_admin_error( request, title="Start Staging Failed", @@ -975,18 +1071,11 @@ async def admin_start_staging( back_href="/admin/acl", ) - if start_staging_rebuild is None: - return render_admin_error( - request, - title="Start Staging Failed", - message="Staging helper unavailable", - back_href="/admin/acl", - ) - root: Optional[str] = None repo_name: Optional[str] = None try: - root, repo_name = resolve_collection_root(collection=name, work_dir=WORK_DIR) + indexing_admin = _indexing_admin() + root, repo_name = indexing_admin.resolve_collection_root(collection=name, work_dir=WORK_DIR) except Exception: root, repo_name = None, None if not root: @@ -1014,12 +1103,10 @@ async def admin_start_staging( }, } - if set_staging_state: - try: - set_staging_state(workspace_path=root, repo_name=repo_name, staging=staging_payload) - except Exception as set_err: - logger.warning(f"[admin] Failed to persist queued staging state for {name}: {set_err}") - elif update_workspace_state: + try: + set_staging_state(workspace_path=root, repo_name=repo_name, staging=staging_payload) + except Exception as set_err: + logger.warning(f"[admin] Failed to persist queued staging state for {name}: {set_err}") try: update_workspace_state( workspace_path=root, @@ -1029,30 +1116,29 @@ async def admin_start_staging( except Exception as set_err: logger.warning(f"[admin] Failed to update workspace state for queued staging {name}: {set_err}") - if update_staging_status: - try: - update_staging_status( - workspace_path=root, - repo_name=repo_name, - status={"state": "queued", "queued_at": now, "request_id": request_id}, - ) - 
except Exception as status_err: - logger.debug(f"[admin] Failed to mark staging status queued for {name}: {status_err}") + try: + update_staging_status( + workspace_path=root, + repo_name=repo_name, + status={"state": "queued", "queued_at": now, "request_id": request_id}, + ) + except Exception as status_err: + logger.debug(f"[admin] Failed to mark staging status queued for {name}: {status_err}") try: async def _bg_start() -> None: try: staging_collection = await asyncio.to_thread( - start_staging_rebuild, collection=name, work_dir=WORK_DIR + indexing_admin.start_staging_rebuild, collection=name, work_dir=WORK_DIR ) logger.info(f"[admin] Started staging rebuild for {name} -> {staging_collection}") except Exception as e: logger.error(f"[admin] Background staging start failed for {name}: {e}") # Ensure we don't leave the workspace stuck in a queued staging state. try: - if clear_staging_collection: + try: clear_staging_collection(workspace_path=root, repo_name=repo_name) - elif update_workspace_state: + except Exception: update_workspace_state( workspace_path=root, repo_name=repo_name, @@ -1081,7 +1167,7 @@ async def admin_activate_staging( collection: str = Form(...), ): _require_admin_session(request) - if not (is_staging_enabled() if callable(is_staging_enabled) else False): + if not is_staging_enabled(): return render_admin_error( request, title="Activate Staging Failed", @@ -1097,18 +1183,11 @@ async def admin_activate_staging( back_href="/admin/acl", ) - if activate_staging_rebuild is None: - return render_admin_error( - request, - title="Activate Staging Failed", - message="Staging helper unavailable", - back_href="/admin/acl", - ) - try: async def _bg_activate() -> None: try: - await asyncio.to_thread(activate_staging_rebuild, collection=name, work_dir=WORK_DIR) + indexing_admin = _indexing_admin() + await asyncio.to_thread(indexing_admin.activate_staging_rebuild, collection=name, work_dir=WORK_DIR) logger.info(f"[admin] Activated staging for {name}") except 
Exception as e: logger.error(f"[admin] Background staging activate failed for {name}: {e}") @@ -1140,7 +1219,8 @@ async def admin_abort_staging( back_href="/admin/acl", ) - root, repo_name = resolve_collection_root(collection=name, work_dir=WORK_DIR) + indexing_admin = _indexing_admin() + root, repo_name = indexing_admin.resolve_collection_root(collection=name, work_dir=WORK_DIR) if not root: return render_admin_error( request, @@ -1150,21 +1230,13 @@ async def admin_abort_staging( ) try: - if abort_staging_rebuild is not None: - # Run abort to completion so we always clear staging metadata before returning. - await asyncio.to_thread( - abort_staging_rebuild, - collection=name, - work_dir=WORK_DIR, - delete_collection=True, - ) - logger.info(f"[admin] Aborted staging rebuild for {name}") - elif clear_staging_collection: - # Fallback for older deployments: clear staging metadata only. - clear_staging_collection(workspace_path=root, repo_name=repo_name) - logger.info(f"[admin] Aborted staging for {name} (metadata only)") - else: - raise RuntimeError("staging abort helpers unavailable") + await asyncio.to_thread( + indexing_admin.abort_staging_rebuild, + collection=name, + work_dir=WORK_DIR, + delete_collection=True, + ) + logger.info(f"[admin] Aborted staging rebuild for {name}") except Exception as e: return render_admin_error( request, @@ -1193,20 +1265,15 @@ async def admin_copy_collection( back_href="/admin/acl", ) - if copy_collection_qdrant is None: - return render_admin_error( - request, - title="Copy Collection Failed", - message="copy helper unavailable", - back_href="/admin/acl", - ) - try: allow_overwrite = str(overwrite or "").strip().lower() in {"1", "true", "yes", "on"} except Exception: allow_overwrite = False try: + # Collection copy imports Qdrant admin helpers only on the admin route. 
+ from scripts.collection_admin import copy_collection_qdrant + new_name = copy_collection_qdrant( source=name, target=(target or None), @@ -1222,7 +1289,60 @@ async def admin_copy_collection( back_href="/admin/acl", ) - return RedirectResponse(url="/admin/acl", status_code=302) + graph_copied: Optional[str] = None + try: + if not name.endswith("_graph") and not str(new_name).endswith("_graph"): + used_pooled = True + try: + # Qdrant client pool is only needed to verify the copied graph collection. + from scripts.qdrant_client_manager import pooled_qdrant_client + + with pooled_qdrant_client( + url=QDRANT_URL, + api_key=os.environ.get("QDRANT_API_KEY"), + ) as cli: + try: + cli.get_collection(collection_name=f"{new_name}_graph") + graph_copied = "1" + except Exception: + graph_copied = "0" + except Exception: + # Failed to acquire pooled client; fall back to non-pooled + used_pooled = False + if not used_pooled: + try: + from qdrant_client import QdrantClient # type: ignore + + cli = QdrantClient( + url=QDRANT_URL, + api_key=os.environ.get("QDRANT_API_KEY"), + timeout=float(os.environ.get("QDRANT_TIMEOUT", "5") or 5), + ) + try: + cli.get_collection(collection_name=f"{new_name}_graph") + graph_copied = "1" + except Exception: + graph_copied = "0" + finally: + try: + cli.close() + except Exception: + pass + except Exception: + graph_copied = "0" + except Exception: + graph_copied = None + + try: + from urllib.parse import urlencode + + params = {"copied": name, "new": new_name} + if graph_copied is not None: + params["graph_copied"] = graph_copied + url = "/admin/acl?" 
+ urlencode(params) + except Exception: + url = "/admin/acl" + return RedirectResponse(url=url, status_code=302) @app.post("/admin/users") @@ -1347,16 +1467,17 @@ async def get_status(workspace_path: str): """Get upload status for workspace.""" try: # Get collection name - if get_collection_name: - repo_name = _extract_repo_name_from_path(workspace_path) if _extract_repo_name_from_path else None - collection_name = get_collection_name(repo_name) - else: - collection_name = DEFAULT_COLLECTION + repo_name = _extract_repo_name_from_path(workspace_path) + collection_name = get_collection_name(repo_name) # Get last sequence last_sequence = get_last_sequence(workspace_path) + key = get_workspace_key(workspace_path) + upload_result = _upload_result_tracker.get(key, {}) - last_upload = None + last_upload = upload_result.get("completed_at") + upload_status = str(upload_result.get("status") or "") + workspace_status = "processing" if upload_status == "processing" else "ready" return StatusResponse( workspace_path=workspace_path, @@ -1364,11 +1485,16 @@ async def get_status(workspace_path: str): last_sequence=last_sequence, last_upload=last_upload, pending_operations=0, - status="ready", + status=workspace_status, server_info={ "version": "1.0.0", "max_bundle_size_mb": MAX_BUNDLE_SIZE_MB, - "supported_formats": ["tar.gz"] + "supported_formats": ["tar.gz"], + "last_bundle_id": upload_result.get("bundle_id"), + "last_processing_time_ms": upload_result.get("processing_time_ms"), + "last_processed_operations": upload_result.get("processed_operations"), + "last_upload_status": upload_status or None, + "last_error": upload_result.get("error"), } ) @@ -1376,6 +1502,357 @@ async def get_status(workspace_path: str): logger.error(f"Error getting status: {e}") raise HTTPException(status_code=500, detail=str(e)) + +def _resolve_collection_for_request( + workspace_path: str, + client_collection_name: Optional[str], + logical_repo_id: Optional[str], + source_path: Optional[str] = None, +) 
-> Tuple[str, Optional[str]]: + """ + Resolve collection name and repo_name for upload/plan/apply requests. + + Returns: + Tuple of (collection_name, repo_name) + """ + # Resolve collection name for ACL enforcement + collection_name: Optional[str] = None + repo_name: Optional[str] = None + + repo_source = (source_path or "").strip() or workspace_path + repo_name = _extract_repo_name_from_path(repo_source) + if not repo_name: + repo_name = Path(repo_source).name + + resolved_collection: Optional[str] = None + + # Resolve collection name, preferring server-side mapping for logical_repo_id when enabled + if logical_repo_reuse_enabled() and logical_repo_id: + try: + existing = find_collection_for_logical_repo(logical_repo_id, search_root=WORK_DIR) + except Exception: + existing = None + if existing: + resolved_collection = existing + + # Latent migration: when no explicit mapping exists yet for this logical_repo_id, but there is a + # single existing collection mapping, prefer reusing it rather than creating a fresh collection. + if logical_repo_reuse_enabled() and logical_repo_id and resolved_collection is None: + try: + mappings = get_collection_mappings(search_root=WORK_DIR) or [] + except Exception: + mappings = [] + + if len(mappings) == 1: + canonical = mappings[0] + canonical_coll = canonical.get("collection_name") + if canonical_coll: + resolved_collection = canonical_coll + try: + update_workspace_state( + workspace_path=canonical.get("container_path") or canonical.get("state_file"), + updates={"logical_repo_id": logical_repo_id}, + repo_name=canonical.get("repo_name"), + ) + except Exception as migrate_err: + logger.debug( + f"[upload_service] Failed to migrate logical_repo_id for existing mapping: {migrate_err}" + ) + + # Upload-managed requests are server-owned; ignore client-supplied collection routing. 
+ if resolved_collection is not None: + collection_name = resolved_collection + else: + collection_name = get_collection_name(repo_name) if repo_name else DEFAULT_COLLECTION + + return collection_name, repo_name + + +@app.post("/api/v1/delta/plan", response_model=PlanResponse) +async def plan_delta(request: PlanRequest): + """Plan which file bodies are needed before uploading content.""" + try: + workspace = Path(request.workspace_path) + if not workspace.is_absolute(): + workspace = Path(WORK_DIR) / workspace + workspace_path = str(workspace.resolve()) + + if AUTH_ENABLED: + session_value = str(request.session or "").strip() + try: + record = validate_session(session_value) + except AuthDisabledError: + record = None + except Exception as e: + logger.error(f"[upload_service] Failed to validate auth session for plan: {e}") + raise HTTPException(status_code=500, detail="Failed to validate auth session") + if record is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired session", + ) + + # Resolve collection name for ACL enforcement + collection_name, repo_name = _resolve_collection_for_request( + workspace_path=workspace_path, + client_collection_name=request.collection_name, + logical_repo_id=request.logical_repo_id, + source_path=request.source_path, + ) + + # Enforce collection write access for plan/apply when auth is enabled + if AUTH_ENABLED and CTXCE_MCP_ACL_ENFORCE: + if not collection_name: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Collection resolution failed for ACL enforcement", + ) + uid = str((record or {}).get("user_id") or "").strip() + if not uid: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired session", + ) + try: + allowed = has_collection_access(uid, str(collection_name), "write") + except AuthDisabledError: + allowed = True + if not allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + 
detail=f"User does not have write access to collection '{collection_name}'", + ) + + plan = plan_delta_upload( + workspace_path=workspace_path, + operations=request.operations, + file_hashes=request.file_hashes, + ) + return PlanResponse( + success=True, + workspace_path=workspace_path, + needed_files=plan.get("needed_files", {"created": [], "updated": [], "moved": []}), + operation_counts_preview=plan.get( + "operation_counts_preview", + { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + }, + ), + needed_size_bytes=int(plan.get("needed_size_bytes", 0) or 0), + replica_targets=list(plan.get("replica_targets", []) or []), + fallback_used=False, + error=None, + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"[upload_service] Error planning delta upload: {e}") + return PlanResponse( + success=False, + workspace_path=request.workspace_path, + needed_files={"created": [], "updated": [], "moved": []}, + operation_counts_preview={ + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + }, + needed_size_bytes=0, + replica_targets=[], + fallback_used=True, + error={ + "code": "PLAN_ERROR", + "message": str(e), + }, + ) + + +@app.post("/api/v1/delta/apply_ops", response_model=UploadResponse) +async def apply_delta_ops(request: ApplyOperationsRequest): + """Apply metadata-only delta operations without uploading a tar bundle.""" + key: Optional[str] = None + bundle_id: Optional[str] = None + sequence_number: Optional[int] = None + try: + workspace = Path(request.workspace_path) + if not workspace.is_absolute(): + workspace = Path(WORK_DIR) / workspace + workspace_path = str(workspace.resolve()) + + if AUTH_ENABLED: + session_value = str(request.session or "").strip() + try: + record = validate_session(session_value) + except AuthDisabledError: + record = None + except Exception as e: + 
logger.error(f"[upload_service] Failed to validate auth session for apply_ops: {e}") + raise HTTPException(status_code=500, detail="Failed to validate auth session") + if record is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired session", + ) + + # Resolve collection name for ACL enforcement + collection_name, repo_name = _resolve_collection_for_request( + workspace_path=workspace_path, + client_collection_name=request.collection_name, + logical_repo_id=request.logical_repo_id, + source_path=request.source_path, + ) + + # Enforce collection write access for plan/apply when auth is enabled + if AUTH_ENABLED and CTXCE_MCP_ACL_ENFORCE: + if not collection_name: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Collection resolution failed for ACL enforcement", + ) + uid = str((record or {}).get("user_id") or "").strip() + if not uid: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired session", + ) + try: + allowed = has_collection_access(uid, str(collection_name), "write") + except AuthDisabledError: + allowed = True + if not allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"User does not have write access to collection '{collection_name}'", + ) + + manifest = request.manifest or {} + bundle_id = manifest.get("bundle_id") + manifest_sequence = manifest.get("sequence_number") + key = get_workspace_key(workspace_path) + last_sequence = get_last_sequence(workspace_path) + sequence_number = manifest_sequence if manifest_sequence is not None else last_sequence + 1 + + if sequence_number is not None and sequence_number != last_sequence + 1: + return UploadResponse( + success=False, + error={ + "code": "SEQUENCE_MISMATCH", + "message": f"Expected sequence {last_sequence + 1}, got {sequence_number}", + "expected_sequence": last_sequence + 1, + "received_sequence": sequence_number, + "retry_after": 5000, + }, + ) + + 
start_time = datetime.now() + _upload_result_tracker[key] = { + "workspace_path": workspace_path, + "bundle_id": bundle_id, + "sequence_number": sequence_number, + "processed_operations": None, + "processing_time_ms": None, + "status": "processing", + "completed_at": None, + } + + operations_count = await asyncio.to_thread( + apply_delta_operations, + workspace_path, + request.operations, + request.file_hashes, + ) + processing_time = int((datetime.now() - start_time).total_seconds() * 1000) + failed_count = int((operations_count or {}).get("failed") or 0) + applied_count = int( + (operations_count or {}).get("created", 0) + + (operations_count or {}).get("updated", 0) + + (operations_count or {}).get("deleted", 0) + + (operations_count or {}).get("moved", 0) + ) + status_value = "completed" if failed_count == 0 else "failed" + if applied_count > 0: + _sequence_tracker[key] = sequence_number + _upload_result_tracker[key] = { + "workspace_path": workspace_path, + "bundle_id": bundle_id, + "sequence_number": sequence_number, + "processed_operations": operations_count, + "processing_time_ms": processing_time, + "status": status_value, + "failed_count": failed_count, + "partial": bool(failed_count > 0 and applied_count > 0), + "completed_at": datetime.now().isoformat(), + } + if failed_count > 0: + logger.warning( + "[upload_service] apply_ops completed with failures bundle=%s seq=%s failed=%d ops=%s", + bundle_id, + sequence_number, + failed_count, + operations_count, + ) + return UploadResponse( + success=False, + bundle_id=bundle_id, + sequence_number=sequence_number, + processed_operations=operations_count, + processing_time_ms=processing_time, + next_sequence=sequence_number + 1 if sequence_number is not None else None, + error={ + "code": "APPLY_OPS_PARTIAL_FAILURE", + "message": f"One or more operations failed during apply_ops (failed={failed_count})", + "failed_count": failed_count, + "processed_operations": operations_count, + }, + ) + logger.info( + 
"[upload_service] Applied metadata-only operations bundle=%s seq=%s in %sms ops=%s", + bundle_id, + sequence_number, + processing_time, + operations_count, + ) + return UploadResponse( + success=True, + bundle_id=bundle_id, + sequence_number=sequence_number, + processed_operations=operations_count, + processing_time_ms=processing_time, + next_sequence=sequence_number + 1 if sequence_number is not None else None, + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"[upload_service] Error applying metadata-only operations: {e}") + if key: + _upload_result_tracker[key] = { + "workspace_path": request.workspace_path, + "bundle_id": bundle_id, + "sequence_number": sequence_number, + "processed_operations": None, + "processing_time_ms": None, + "status": "error", + "error": str(e), + "message": str(e), + "completed_at": datetime.now().isoformat(), + } + return UploadResponse( + success=False, + error={ + "code": "APPLY_OPS_ERROR", + "message": str(e), + }, + ) + @app.post("/api/v1/delta/upload", response_model=UploadResponse) async def upload_delta_bundle( request: Request, @@ -1422,60 +1899,13 @@ async def upload_delta_bundle( workspace_path = str(workspace.resolve()) - # Always derive repo_name from workspace_path for origin tracking - repo_name = _extract_repo_name_from_path(workspace_path) if _extract_repo_name_from_path else None - if not repo_name: - repo_name = Path(workspace_path).name - - # Preserve any client-supplied collection name but allow server-side overrides - client_collection_name = collection_name - resolved_collection: Optional[str] = None - - # Resolve collection name, preferring server-side mapping for logical_repo_id when enabled - if logical_repo_reuse_enabled() and logical_repo_id and find_collection_for_logical_repo: - try: - existing = find_collection_for_logical_repo(logical_repo_id, search_root=WORK_DIR) - except Exception: - existing = None - if existing: - resolved_collection = existing - - # Latent migration: when 
no explicit mapping exists yet for this logical_repo_id, but there is a - # single existing collection mapping, prefer reusing it rather than creating a fresh collection. - if logical_repo_reuse_enabled() and logical_repo_id and resolved_collection is None and get_collection_mappings: - try: - mappings = get_collection_mappings(search_root=WORK_DIR) or [] - except Exception: - mappings = [] - - if len(mappings) == 1: - canonical = mappings[0] - canonical_coll = canonical.get("collection_name") - if canonical_coll: - resolved_collection = canonical_coll - if update_workspace_state: - try: - update_workspace_state( - workspace_path=canonical.get("container_path") or canonical.get("state_file"), - updates={"logical_repo_id": logical_repo_id}, - repo_name=canonical.get("repo_name"), - ) - except Exception as migrate_err: - logger.debug( - f"[upload_service] Failed to migrate logical_repo_id for existing mapping: {migrate_err}" - ) - - # Finalize collection_name: prefer resolved server-side mapping, then client-supplied name, - # then standard get_collection_name/DEFAULT_COLLECTION fallbacks. - if resolved_collection is not None: - collection_name = resolved_collection - elif client_collection_name: - collection_name = client_collection_name - else: - if get_collection_name and repo_name: - collection_name = get_collection_name(repo_name) - else: - collection_name = DEFAULT_COLLECTION + # Resolve collection name and repo name + collection_name, repo_name = _resolve_collection_for_request( + workspace_path=workspace_path, + client_collection_name=collection_name, + logical_repo_id=logical_repo_id, + source_path=source_path, + ) # Enforce collection write access for uploads when auth is enabled. # Semantics: "write" is sufficient for uploading/indexing content. 
diff --git a/scripts/warm_all_collections.py b/scripts/warm_all_collections.py index 0344da82..19f37241 100644 --- a/scripts/warm_all_collections.py +++ b/scripts/warm_all_collections.py @@ -5,6 +5,7 @@ import os import sys import subprocess +from pathlib import Path from qdrant_client import QdrantClient def main(): @@ -12,6 +13,7 @@ def main(): qdrant_url = os.environ.get("QDRANT_URL", "http://qdrant:6333") ef = os.environ.get("EF", "256") limit = os.environ.get("LIMIT", "3") + script_dir = Path(__file__).resolve().parent print(f"Connecting to Qdrant at {qdrant_url}") @@ -37,8 +39,8 @@ def main(): result = subprocess.run( [ - "python", - "/app/scripts/warm_start.py", + sys.executable or "python", + str(script_dir / "warm_start.py"), "--ef", ef, "--limit", limit ], diff --git a/scripts/warm_start.py b/scripts/warm_start.py index 3119721b..a024f33a 100644 --- a/scripts/warm_start.py +++ b/scripts/warm_start.py @@ -1,15 +1,10 @@ #!/usr/bin/env python3 import os import argparse -import sys -from pathlib import Path from qdrant_client import QdrantClient, models -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - from scripts.utils import sanitize_vector_name +from scripts.embedder import get_embedding_model as _get_model # Warm start: load embedding model and warm Qdrant HNSW search path with a small query # Useful to reduce first-query latency and set a higher runtime ef for quality @@ -21,14 +16,7 @@ def derive_vector_name(model_name: str) -> str: def get_embedding_model(model_name: str): """Get embedding model with Qwen3 support via embedder factory.""" - try: - from scripts.embedder import get_embedding_model as _get_model - return _get_model(model_name) - except ImportError: - pass - # Fallback to direct fastembed - from fastembed import TextEmbedding - return TextEmbedding(model_name=model_name) + return _get_model(model_name) def main(): diff --git a/scripts/watch_index.py 
b/scripts/watch_index.py index 8fe5a740..0a9890fe 100644 --- a/scripts/watch_index.py +++ b/scripts/watch_index.py @@ -2,53 +2,51 @@ from __future__ import annotations import os -import sys import time +from collections import Counter from pathlib import Path from typing import Optional from qdrant_client import QdrantClient from watchdog.observers import Observer -ROOT_DIR = Path(__file__).resolve().parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) - -from scripts.watch_index_core.config import ( # noqa: E402 - LOGGER, - MODEL, - QDRANT_URL, - ROOT as WATCH_ROOT, - default_collection_name, -) +from scripts.watch_index_core import config as watch_config +from scripts.watch_index_core.config import LOGGER, MODEL, QDRANT_URL, default_collection_name from scripts.watch_index_core.utils import ( get_boolean_env, resolve_vector_name_config, create_observer, ) -from scripts.watch_index_core.handler import IndexHandler # noqa: E402 -from scripts.watch_index_core.pseudo import _start_pseudo_backfill_worker # noqa: E402 -from scripts.watch_index_core.processor import _process_paths # noqa: E402 -from scripts.watch_index_core.queue import ChangeQueue # noqa: E402 -from scripts.workspace_state import ( # noqa: E402 - _extract_repo_name_from_path, +from scripts.watch_index_core.handler import IndexHandler +from scripts.watch_index_core.init_maintenance import start_init_maintenance_worker +from scripts.watch_index_core.pseudo import _start_pseudo_backfill_worker +from scripts.watch_index_core.processor import _process_paths +from scripts.watch_index_core.queue import ChangeQueue +from scripts.watch_index_core.consistency import ( + run_consistency_audit, + run_empty_dir_sweep_maintenance, +) +from scripts.workspace_state import ( compute_indexing_config_hash, - get_collection_name, get_indexing_config_snapshot, + list_pending_index_journal_entries, is_multi_repo_mode, persist_indexing_config, update_indexing_status, - update_workspace_state, 
initialize_watcher_state, ) -import scripts.ingest_code as idx # noqa: E402 +_sleep = time.sleep + +import scripts.ingest_code as idx logger = LOGGER -ROOT = WATCH_ROOT +ROOT = watch_config.ROOT # Back-compat: legacy modules/tests expect a module-level COLLECTION constant. # We use a sentinel and a getter to ensure the resolved value is returned. _COLLECTION: Optional[str] = None +_JOURNAL_DRAIN_LAST_LOG = 0.0 +_JOURNAL_DRAIN_LAST_TOTAL = 0 def get_collection() -> str: @@ -58,7 +56,132 @@ def get_collection() -> str: return default_collection_name() +def _set_runtime_root() -> None: + global ROOT + runtime_root = Path( + os.environ.get("WATCH_ROOT") + or os.environ.get("WORKSPACE_PATH") + or str(ROOT) + ) + try: + runtime_root = runtime_root.resolve() + except Exception: + pass + + ROOT = runtime_root + watch_config.ROOT = runtime_root + + +def _journal_log_interval_secs() -> float: + try: + return max(0.0, float(os.environ.get("WATCH_JOURNAL_LOG_INTERVAL_SECS", "120") or 120.0)) + except Exception: + return 120.0 + + +def _maybe_log_journal_drain( + *, + total: int, + queued: int, + op_counts: Counter[str], + queue: ChangeQueue, +) -> None: + global _JOURNAL_DRAIN_LAST_LOG, _JOURNAL_DRAIN_LAST_TOTAL + now = time.time() + interval = _journal_log_interval_secs() + should_log = False + if total <= 0 and _JOURNAL_DRAIN_LAST_TOTAL > 0: + should_log = True + elif total > 0 and (_JOURNAL_DRAIN_LAST_LOG <= 0 or (now - _JOURNAL_DRAIN_LAST_LOG) >= interval): + should_log = True + if not should_log: + _JOURNAL_DRAIN_LAST_TOTAL = total + return + + queue_stats = {} + try: + queue_stats = queue.stats() + except Exception: + queue_stats = {} + logger.info( + "watch_index::journal_drain backlog=%d queued=%d ops=%s queue=%s", + total, + queued, + dict(op_counts), + queue_stats, + extra={ + "root": str(ROOT), + "backlog": total, + "queued": queued, + "op_counts": dict(op_counts), + "queue_stats": queue_stats, + }, + ) + _JOURNAL_DRAIN_LAST_LOG = now + _JOURNAL_DRAIN_LAST_TOTAL = 
total + + +def _drain_pending_journal(queue: ChangeQueue) -> None: + pending_path: Optional[str] = None + try: + pending_entries = list_pending_index_journal_entries(str(ROOT)) + queued = 0 + op_counts: Counter[str] = Counter() + for pending_entry in pending_entries: + op_type = str(pending_entry.get("op_type") or "unknown").strip() or "unknown" + op_counts[op_type] += 1 + pending_path = str(pending_entry.get("path") or "").strip() + if pending_path: + queue.add(Path(pending_path), force=True) + queued += 1 + _maybe_log_journal_drain( + total=len(pending_entries), + queued=queued, + op_counts=op_counts, + queue=queue, + ) + except Exception as exc: + logger.exception( + "watch_index::pending_journal_drain_failed", + extra={"root": str(ROOT), "pending_path": pending_path, "error": str(exc)}, + ) + + +def _run_periodic_maintenance(client: QdrantClient) -> None: + try: + run_consistency_audit(client, ROOT) + except Exception as exc: + logger.exception( + "watch_index::consistency_audit_failed", + extra={"root": str(ROOT), "error": str(exc)}, + ) + try: + run_empty_dir_sweep_maintenance(ROOT) + except Exception as exc: + logger.exception( + "watch_index::empty_dir_sweep_failed", + extra={"root": str(ROOT), "error": str(exc)}, + ) + + +def _maintenance_interval_secs() -> float: + try: + return max(0.0, float(os.environ.get("WATCH_MAINTENANCE_INTERVAL_SECS", "300") or 300.0)) + except Exception: + return 300.0 + + +def _journal_drain_enabled(multi_repo_enabled: bool) -> bool: + return get_boolean_env("WATCH_JOURNAL_DRAIN_ENABLED", default=multi_repo_enabled) + + +def _fs_events_enabled(multi_repo_enabled: bool) -> bool: + return get_boolean_env("WATCH_FS_EVENTS_ENABLED", default=(not multi_repo_enabled)) + + def main() -> None: + _set_runtime_root() + # Resolve collection name from workspace state before any client/state ops try: from scripts.workspace_state import get_collection_name_with_staging as _get_coll @@ -94,6 +217,16 @@ def main() -> None: f"Watch mode: 
root={ROOT} qdrant={QDRANT_URL} collection={default_collection} model={MODEL}" ) + # Guardrail: deferring pseudo to a worker only makes sense if the worker is enabled. + # Otherwise you'd silently disable pseudo generation (old behavior). + pseudo_defer = get_boolean_env("PSEUDO_DEFER_TO_WORKER") + pseudo_backfill_enabled = get_boolean_env("PSEUDO_BACKFILL_ENABLED") + if pseudo_defer and not pseudo_backfill_enabled: + print( + "[pseudo] Warning: PSEUDO_DEFER_TO_WORKER=1 but PSEUDO_BACKFILL_ENABLED=0; " + "inline pseudo will remain enabled (no deferral)." + ) + # Health check: detect and auto-heal cache/collection sync issues. # In multi-repo mode this can be expensive and may duplicate external init checks, # so default it OFF unless explicitly enabled. @@ -137,18 +270,10 @@ def main() -> None: url=QDRANT_URL, timeout=int(os.environ.get("QDRANT_TIMEOUT", "20") or 20) ) - # Use centralized embedder factory if available (supports Qwen3 feature flag) - try: - from scripts.embedder import get_embedding_model, get_model_dimension - - model = get_embedding_model(MODEL) - model_dim = get_model_dimension(MODEL) - except ImportError: - # Fallback to direct fastembed initialization - from fastembed import TextEmbedding + from scripts.embedder import get_embedding_model, get_model_dimension - model = TextEmbedding(model_name=MODEL) - model_dim = len(next(model.embed(["dimension probe"]))) + model = get_embedding_model(MODEL) + model_dim = get_model_dimension(MODEL) vector_name = resolve_vector_name_config(client, default_collection, model_dim, MODEL) @@ -166,7 +291,17 @@ def main() -> None: except Exception: pass - _start_pseudo_backfill_worker(client, default_collection, model_dim, vector_name) + # Start backfill worker even in multi-repo mode; it uses workspace mappings and + # will no-op if disabled. Only allow a single-repo fallback to the default + # collection when startup was explicitly permitted to touch that collection. 
+ _start_pseudo_backfill_worker( + client, + default_collection, + model_dim, + vector_name, + allow_default_collection_fallback=ensure_default_collection, + ) + init_maintenance_shutdown = start_init_maintenance_worker() try: initialize_watcher_state(str(ROOT), multi_repo_enabled, default_collection) @@ -178,21 +313,45 @@ def main() -> None: paths, client, model, vector_name, model_dim, str(ROOT) ) ) - handler = IndexHandler(ROOT, q, client, default_collection) + journal_drain_enabled = _journal_drain_enabled(multi_repo_enabled) + fs_events_enabled = _fs_events_enabled(multi_repo_enabled) + + print( + "[watch_mode] sources " + f"journal_drain={'on' if journal_drain_enabled else 'off'} " + f"fs_events={'on' if fs_events_enabled else 'off'}" + ) + + obs = None + if fs_events_enabled: + handler = IndexHandler(ROOT, q, client, default_collection) + use_polling = get_boolean_env("WATCH_USE_POLLING") + obs = create_observer(use_polling, observer_cls=Observer) + obs.schedule(handler, str(ROOT), recursive=True) + obs.start() - use_polling = get_boolean_env("WATCH_USE_POLLING") - obs = create_observer(use_polling, observer_cls=Observer) - obs.schedule(handler, str(ROOT), recursive=True) - obs.start() + maintenance_interval = _maintenance_interval_secs() + last_maintenance: Optional[float] = None try: while True: - time.sleep(1.0) + if journal_drain_enabled: + # Upload/apply records upsert/delete intent here so missed filesystem + # events can still be replayed after watcher/container restarts. 
+ _drain_pending_journal(q) + now = time.time() + if last_maintenance is None or (now - last_maintenance) >= maintenance_interval: + _run_periodic_maintenance(client) + last_maintenance = now + _sleep(1.0) except KeyboardInterrupt: pass finally: - obs.stop() - obs.join() + if init_maintenance_shutdown is not None: + init_maintenance_shutdown.set() + if obs is not None: + obs.stop() + obs.join() if __name__ == "__main__": diff --git a/scripts/watch_index_core/config.py b/scripts/watch_index_core/config.py index c9fa8354..600f840c 100644 --- a/scripts/watch_index_core/config.py +++ b/scripts/watch_index_core/config.py @@ -3,15 +3,12 @@ from __future__ import annotations import os -import sys from pathlib import Path from scripts.logger import get_logger ROOT_DIR = Path(__file__).resolve().parent.parent.parent -if str(ROOT_DIR) not in sys.path: - sys.path.insert(0, str(ROOT_DIR)) def build_logger(): @@ -33,6 +30,12 @@ def build_logger(): # Debounce interval for file system events DELAY_SECS = float(os.environ.get("WATCH_DEBOUNCE_SECS", "1.0")) +# Suppress repeated processing of the exact same observed file state for a short +# window. This is especially useful on shared/polled filesystems like CephFS. 
+RECENT_FINGERPRINT_TTL_SECS = float( + os.environ.get("WATCH_RECENT_FINGERPRINT_TTL_SECS", "0") +) + def default_collection_name() -> str: """Base fallback for collection name before runtime resolution.""" diff --git a/scripts/watch_index_core/consistency.py b/scripts/watch_index_core/consistency.py new file mode 100644 index 00000000..c3779837 --- /dev/null +++ b/scripts/watch_index_core/consistency.py @@ -0,0 +1,703 @@ +from __future__ import annotations + +import json +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Optional, Set, Tuple + +from qdrant_client import QdrantClient + +import scripts.ingest_code as idx +from scripts.workspace_state import ( + _get_state_lock, + _extract_repo_name_from_path, + _normalize_cache_key_path, + get_collection_state_snapshot, + get_workspace_state, + list_workspaces, + update_workspace_state, + upsert_index_journal_entries, +) + +from .config import LOGGER +from .utils import get_boolean_env +from .paths import is_internal_metadata_path + +logger = LOGGER +_DEFAULT_EMPTY_DIR_SWEEP_INTERVAL_SECONDS = 7 * 24 * 60 * 60 + + +def _consistency_audit_enabled() -> bool: + return get_boolean_env("WATCH_CONSISTENCY_AUDIT_ENABLED", default=True) + + +def _consistency_audit_interval_secs() -> int: + try: + return max(60, int(os.environ.get("WATCH_CONSISTENCY_AUDIT_INTERVAL_SECS", "86400") or 86400)) + except Exception: + return 86400 + + +def _consistency_audit_max_paths() -> int: + try: + return max(0, int(os.environ.get("WATCH_CONSISTENCY_AUDIT_MAX_PATHS", "200000") or 200000)) + except Exception: + return 200000 + + +def _consistency_repair_enabled() -> bool: + return get_boolean_env("WATCH_CONSISTENCY_REPAIR_ENABLED", default=True) + + +def _consistency_repair_max_ops() -> int: + try: + return max(0, int(os.environ.get("WATCH_CONSISTENCY_REPAIR_MAX_OPS", "5000") or 5000)) + except Exception: + return 5000 + + +def _consistency_graph_audit_enabled() -> bool: + return 
get_boolean_env("WATCH_CONSISTENCY_AUDIT_GRAPH_ENABLED", default=True) + + +def _empty_dir_sweep_enabled() -> bool: + if "WATCH_EMPTY_DIR_SWEEP_ENABLED" in os.environ: + return get_boolean_env("WATCH_EMPTY_DIR_SWEEP_ENABLED", default=True) + return get_boolean_env("CTXCE_UPLOAD_EMPTY_DIR_SWEEP", default=True) + + +def _empty_dir_sweep_interval_secs() -> int: + raw = os.environ.get("WATCH_EMPTY_DIR_SWEEP_INTERVAL_SECONDS") + if raw is None: + raw = os.environ.get( + "CTXCE_UPLOAD_EMPTY_DIR_SWEEP_INTERVAL_SECONDS", + str(_DEFAULT_EMPTY_DIR_SWEEP_INTERVAL_SECONDS), + ) + try: + return max(0, int(raw or _DEFAULT_EMPTY_DIR_SWEEP_INTERVAL_SECONDS)) + except Exception: + return _DEFAULT_EMPTY_DIR_SWEEP_INTERVAL_SECONDS + + +def _parse_ts(value: Any) -> Optional[datetime]: + raw = str(value or "").strip() + if not raw: + return None + try: + parsed = datetime.fromisoformat(raw.replace("Z", "+00:00")) + except ValueError: + return None + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + + +def _should_run_consistency_audit(workspace_path: str, repo_name: Optional[str]) -> bool: + if not _consistency_audit_enabled(): + return False + interval = _consistency_audit_interval_secs() + try: + state = get_workspace_state(workspace_path=workspace_path, repo_name=repo_name) or {} + except Exception: + return True + maintenance = dict(state.get("maintenance") or {}) + last = _parse_ts(maintenance.get("last_consistency_audit_at")) + if last is None: + return True + age = (datetime.now(timezone.utc) - last).total_seconds() + return age >= interval + + +def _sweep_empty_workspace_dirs(workspace_root: Path) -> bool: + """Sweep empty workspace directories and return True if fully successful.""" + protected_top_level = {".codebase", ".remote-git"} + try: + workspace_root = workspace_root.resolve() + except Exception: + return False + try: + for root, _dirnames, _filenames in os.walk(workspace_root, topdown=False): + current = 
Path(root) + if current == workspace_root: + continue + if current.parent == workspace_root and current.name in protected_top_level: + continue + try: + rel = current.relative_to(workspace_root) + except Exception: + continue + if rel.parts and rel.parts[0] in protected_top_level: + continue + try: + if any(current.iterdir()): + continue + current.rmdir() + except Exception: + # If any directory operation fails, the sweep was not fully successful + return False + except Exception: + return False + return True + + +def _should_run_empty_dir_sweep(workspace_path: str, repo_name: Optional[str]) -> bool: + if not _empty_dir_sweep_enabled(): + return False + interval_seconds = _empty_dir_sweep_interval_secs() + if interval_seconds == 0: + return True + try: + state = get_workspace_state(workspace_path=workspace_path, repo_name=repo_name) or {} + except Exception: + return True + maintenance = state.get("maintenance") or {} + last_sweep_at = _parse_ts(maintenance.get("last_empty_dir_sweep_at")) + if last_sweep_at is None: + return True + age_seconds = (datetime.now(timezone.utc) - last_sweep_at).total_seconds() + return age_seconds >= interval_seconds + + +def _record_empty_dir_sweep(workspace_path: str, repo_name: Optional[str]) -> None: + try: + lock = _get_state_lock(workspace_path, repo_name) + with lock: + state = get_workspace_state( + workspace_path=workspace_path, + repo_name=repo_name, + ) or {} + maintenance = dict(state.get("maintenance") or {}) + maintenance["last_empty_dir_sweep_at"] = datetime.now( + timezone.utc + ).isoformat() + update_workspace_state( + workspace_path=workspace_path, + repo_name=repo_name, + updates={"maintenance": maintenance}, + ) + except Exception as exc: + logger.warning( + "Failed to record empty dir sweep timestamp: %s (workspace=%s, repo=%s)", + exc, + workspace_path, + repo_name, + ) + + +def _load_cached_hashes( + workspace_path: str, + repo_name: Optional[str], + *, + metadata_root: Optional[Path] = None, +) -> Dict[str, str]: 
+ workspace_norm = _normalize_cache_key_path(workspace_path) + workspace_prefix = f"{workspace_norm.rstrip('/')}/" + candidates: list[Path] = [] + seen: set[str] = set() + + def _append_candidate(path: Path) -> None: + key = str(path) + if key in seen: + return + seen.add(key) + candidates.append(path) + + root = Path(metadata_root or workspace_path) + if repo_name: + _append_candidate(root / ".codebase" / "repos" / repo_name / "cache.json") + else: + _append_candidate(root / ".codebase" / "cache.json") + + for cache_path in candidates: + if not cache_path.exists(): + continue + try: + with cache_path.open("r", encoding="utf-8-sig") as f: + data = json.load(f) + hashes = data.get("file_hashes", {}) + if not isinstance(hashes, dict): + return {} + normalized: Dict[str, str] = {} + for path_key, value in hashes.items(): + norm = _normalize_cache_key_path(str(path_key)) + if not norm: + continue + if workspace_norm and not ( + norm == workspace_norm or norm.startswith(workspace_prefix) + ): + continue + if isinstance(value, dict): + digest = str(value.get("hash") or "").strip() + else: + digest = str(value or "").strip() + normalized[norm] = digest + return normalized + except Exception: + return {} + return {} + + +def _is_index_eligible_path(path_str: str, workspace_root: Path, excluder) -> bool: + try: + p = Path(path_str).resolve() + except Exception: + p = Path(path_str) + try: + rel = p.resolve().relative_to(workspace_root.resolve()) + except Exception: + return False + + if not rel.parts: + return False + if not p.exists() or p.is_dir(): + return False + try: + if int(p.stat().st_size) == 0: + # Empty files (e.g. many __init__.py stubs) produce no vectors; do not + # enqueue consistency upserts for them. + return False + except Exception: + return False + if is_internal_metadata_path(p): + return False + + # .remote-git manifests are control files and must not be treated as indexable. 
+ if _is_remote_git_manifest(p.as_posix()): + return False + + try: + rel_dir = "/" + str(rel.parent).replace(os.sep, "/") + if rel_dir == "/.": + rel_dir = "/" + if excluder.exclude_dir(rel_dir): + return False + except Exception: + return False + + if not idx.is_indexable_file(p): + return False + + try: + relf = (rel_dir.rstrip("/") + "/" + p.name).replace("//", "/") + if excluder.exclude_file(relf): + return False + except Exception: + return False + return True + + +def _scan_indexable_fs_paths(workspace_root: Path, *, max_paths: int) -> Tuple[Set[str], bool]: + paths: Set[str] = set() + excluder = idx._Excluder(workspace_root) + try: + workspace_root = workspace_root.resolve() + except Exception: + pass + + for root_str, dirnames, filenames in os.walk(workspace_root): + current = Path(root_str) + pruned_dirnames = [] + for dirname in dirnames: + child = current / dirname + if is_internal_metadata_path(child): + continue + try: + rel_dir = "/" + str(child.relative_to(workspace_root)).replace(os.sep, "/") + if excluder.exclude_dir(rel_dir): + continue + except Exception: + pass + pruned_dirnames.append(dirname) + dirnames[:] = pruned_dirnames + + for filename in filenames: + file_path = current / filename + normalized = _normalize_cache_key_path(str(file_path)) + if not normalized: + continue + if not _is_index_eligible_path(normalized, workspace_root, excluder): + continue + paths.add(normalized) + if max_paths > 0 and len(paths) >= max_paths: + return paths, True + return paths, False + + +def _load_indexed_paths_for_collection( + client: QdrantClient, + collection: str, + workspace_path: str, + *, + max_paths: int, +) -> Tuple[Set[str], bool]: + paths: Set[str] = set() + workspace_norm = _normalize_cache_key_path(workspace_path) + workspace_prefix = f"{workspace_norm.rstrip('/')}/" + offset = None + while True: + points, next_offset = client.scroll( + collection_name=collection, + limit=1000, + with_payload=True, + with_vectors=False, + offset=offset, + ) + 
for pt in points or []: + payload = getattr(pt, "payload", {}) or {} + metadata = payload.get("metadata", {}) or {} + path = _normalize_cache_key_path(str(metadata.get("path") or "")) + if path: + if workspace_norm and not ( + path == workspace_norm or path.startswith(workspace_prefix) + ): + continue + paths.add(path) + if max_paths > 0 and len(paths) >= max_paths: + return paths, True + if next_offset is None: + break + offset = next_offset + return paths, False + + +def _load_graph_paths_for_collection( + client: QdrantClient, + collection: str, + workspace_path: str, + *, + max_paths: int, +) -> Tuple[Set[str], bool]: + paths: Set[str] = set() + workspace_norm = _normalize_cache_key_path(workspace_path) + workspace_prefix = f"{workspace_norm.rstrip('/')}/" + graph_collection = f"{collection}_graph" + offset = None + while True: + points, next_offset = client.scroll( + collection_name=graph_collection, + limit=1000, + with_payload=True, + with_vectors=False, + offset=offset, + ) + for pt in points or []: + payload = getattr(pt, "payload", {}) or {} + path = _normalize_cache_key_path(str(payload.get("caller_path") or "")) + if path: + if workspace_norm and not ( + path == workspace_norm or path.startswith(workspace_prefix) + ): + continue + paths.add(path) + if max_paths > 0 and len(paths) >= max_paths: + return paths, True + if next_offset is None: + break + offset = next_offset + return paths, False + + +def _record_consistency_audit( + workspace_path: str, + repo_name: Optional[str], + summary: Dict[str, Any], +) -> None: + try: + lock = _get_state_lock(workspace_path, repo_name) + with lock: + state = get_workspace_state( + workspace_path=workspace_path, + repo_name=repo_name, + ) or {} + maintenance = dict(state.get("maintenance") or {}) + maintenance["last_consistency_audit_at"] = datetime.now( + timezone.utc + ).isoformat() + maintenance["last_consistency_audit_summary"] = summary + update_workspace_state( + workspace_path=workspace_path, + 
repo_name=repo_name, + updates={"maintenance": maintenance}, + ) + except Exception as exc: + logger.warning( + "Failed to record consistency audit: %s (workspace=%s, repo=%s)", + exc, + workspace_path, + repo_name, + ) + + +def _is_remote_git_manifest(path: str) -> bool: + """Check if path is a .remote-git git history manifest file (control file, not indexable content).""" + try: + p = Path(path) + return any(part == ".remote-git" for part in p.parts) and p.suffix.lower() == ".json" + except Exception: + return False + + +def _enqueue_consistency_repairs( + workspace_root: Path, + workspace_path: str, + repo_name: Optional[str], + stale_paths: list[str], + missing_paths: list[str], + cached_hashes: Dict[str, str], +) -> Tuple[int, int]: + if not _consistency_repair_enabled(): + return 0, 0 + max_ops = _consistency_repair_max_ops() + if max_ops <= 0: + return 0, 0 + + entries: list[Dict[str, Any]] = [] + enqueued_stale = 0 + enqueued_missing = 0 + missing_set = set(missing_paths) + excluder = idx._Excluder(workspace_root) + + for path in stale_paths: + if len(entries) >= max_ops: + break + # Skip .remote-git git history manifests - they are control files, not indexable content + if _is_remote_git_manifest(path): + continue + # Cache can lag after state resets/rebuilds; if the path still exists and is + # index-eligible, treat it as missing/upsert instead of stale/delete. 
+ if _is_index_eligible_path(path, workspace_root, excluder): + missing_set.add(path) + continue + entries.append({"path": path, "op_type": "delete"}) + enqueued_stale += 1 + for path in sorted(missing_set): + if len(entries) >= max_ops: + break + # Skip .remote-git git history manifests - they are control files, not indexable content + if _is_remote_git_manifest(path): + continue + entries.append( + { + "path": path, + "op_type": "upsert", + "content_hash": cached_hashes.get(path) or None, + } + ) + enqueued_missing += 1 + + if not entries: + return 0, 0 + + # Fetch existing journal entries to preserve retry state + existing_entries: Dict[str, Dict[str, Any]] = {} + try: + from scripts.workspace_state import list_pending_index_journal_entries + all_pending = list_pending_index_journal_entries( + workspace_path=workspace_path, + repo_name=repo_name, + ) + for entry in all_pending or []: + path = str(entry.get("path") or "") + if path: + existing_entries[path] = entry + except Exception: + pass # If we can't fetch existing entries, proceed without preserving state + + # Merge existing retry state into new entries where appropriate + merged_entries = [] + for entry in entries: + path = str(entry.get("path") or "") + existing = existing_entries.get(path) + + # Skip if already pending/in-progress to avoid duplicate work + if existing and existing.get("status") in {"pending", "in_progress"}: + continue + + # Preserve retry state from existing failed entries + if existing and existing.get("status") == "failed": + entry["status"] = "failed" + entry["attempts"] = existing.get("attempts", 0) + entry["last_error"] = existing.get("last_error") + # Keep created_at from existing entry to preserve original enqueue time + if existing.get("created_at"): + entry["created_at"] = existing["created_at"] + + merged_entries.append(entry) + + if not merged_entries: + return 0, 0 + + try: + upsert_index_journal_entries( + merged_entries, + workspace_path=workspace_path, + 
repo_name=repo_name, + ) + except Exception as exc: + logger.debug( + "[consistency_audit] failed to enqueue repairs workspace=%s repo=%s: %s", + workspace_path, + repo_name, + exc, + ) + return 0, 0 + + # Return counts based on actually enqueued entries + enqueued_stale = sum(1 for e in merged_entries if e.get("op_type") == "delete") + enqueued_missing = sum(1 for e in merged_entries if e.get("op_type") == "upsert") + return enqueued_stale, enqueued_missing + + +def run_consistency_audit(client: QdrantClient, root: Path) -> None: + if not _consistency_audit_enabled(): + return + max_paths = _consistency_audit_max_paths() + try: + candidates = list_workspaces(search_root=str(root), use_qdrant_fallback=False) + except Exception: + candidates = [] + for ws in candidates: + workspace_path = str(ws.get("workspace_path") or "").strip() + if not workspace_path: + continue + repo_name = _extract_repo_name_from_path(workspace_path) + if not _should_run_consistency_audit(workspace_path, repo_name): + continue + try: + snapshot = get_collection_state_snapshot( + workspace_path=workspace_path, + repo_name=repo_name, + ) + collection = str(snapshot.get("active_collection") or "").strip() + if not collection: + continue + cached_hashes = _load_cached_hashes( + workspace_path, + repo_name, + metadata_root=root, + ) + workspace_root = Path(workspace_path) + fs_paths, fs_truncated = _scan_indexable_fs_paths( + workspace_root, + max_paths=max_paths, + ) + excluder = idx._Excluder(workspace_root) + cached_paths = { + path + for path in cached_hashes.keys() + if _is_index_eligible_path(path, workspace_root, excluder) + } + indexed_paths, indexed_truncated = _load_indexed_paths_for_collection( + client, + collection, + workspace_path, + max_paths=max_paths, + ) + graph_paths: Set[str] = set() + graph_truncated = False + graph_orphans: list[str] = [] + if _consistency_graph_audit_enabled(): + try: + graph_paths, graph_truncated = _load_graph_paths_for_collection( + client, + 
collection, + workspace_path, + max_paths=max_paths, + ) + except Exception: + graph_paths, graph_truncated = set(), False + if fs_truncated or indexed_truncated: + stale = [] + missing = [] + enq_stale = 0 + enq_missing = 0 + else: + stale_set = set(indexed_paths - fs_paths) + if not graph_truncated: + graph_orphans = sorted(graph_paths - indexed_paths) + stale_set.update(graph_orphans) + stale = sorted(stale_set) + missing = sorted(fs_paths - indexed_paths) + enq_stale, enq_missing = _enqueue_consistency_repairs( + workspace_root, + workspace_path, + repo_name, + stale, + missing, + cached_hashes, + ) + summary = { + "fs_count": len(fs_paths), + "cache_count": len(cached_paths), + "qdrant_count": len(indexed_paths), + "graph_count": len(graph_paths), + "fs_scan_truncated": fs_truncated, + "qdrant_scan_truncated": indexed_truncated, + "graph_scan_truncated": graph_truncated, + "repair_skipped_due_to_truncation": bool(fs_truncated or indexed_truncated), + "stale_in_qdrant_count": len(stale), + "missing_in_qdrant_count": len(missing), + "orphan_graph_count": len(graph_orphans), + "repair_enqueued_stale_count": int(enq_stale), + "repair_enqueued_missing_count": int(enq_missing), + "sample_stale": stale[:20], + "sample_missing": missing[:20], + "sample_orphan_graph": graph_orphans[:20], + } + _record_consistency_audit(workspace_path, repo_name, summary) + logger.info( + "[consistency_audit] repo=%s collection=%s fs=%d cache=%d qdrant=%d graph=%d stale=%d missing=%d graph_orphans=%d repair_stale=%d repair_missing=%d", + repo_name or "", + collection, + len(fs_paths), + len(cached_paths), + len(indexed_paths), + len(graph_paths), + len(stale), + len(missing), + len(graph_orphans), + int(enq_stale), + int(enq_missing), + ) + except Exception as exc: + logger.debug( + "[consistency_audit] failed workspace=%s repo=%s: %s", + workspace_path, + repo_name, + exc, + ) + + +def run_empty_dir_sweep_maintenance(root: Path) -> None: + if not _empty_dir_sweep_enabled(): + return + 
try: + candidates = list_workspaces(search_root=str(root), use_qdrant_fallback=False) + except Exception: + candidates = [] + for ws in candidates: + workspace_path = str(ws.get("workspace_path") or "").strip() + if not workspace_path: + continue + repo_name = _extract_repo_name_from_path(workspace_path) + if not _should_run_empty_dir_sweep(workspace_path, repo_name): + continue + try: + logger.info("[empty_dir_sweep] Sweeping empty directories under %s", workspace_path) + sweep_success = _sweep_empty_workspace_dirs(Path(workspace_path)) + if sweep_success: + _record_empty_dir_sweep(workspace_path, repo_name) + else: + logger.debug( + "[empty_dir_sweep] sweep had failures workspace=%s repo=%s - not recording success", + workspace_path, + repo_name, + ) + except Exception as exc: + logger.debug( + "[empty_dir_sweep] failed workspace=%s repo=%s: %s", + workspace_path, + repo_name, + exc, + ) diff --git a/scripts/watch_index_core/handler.py b/scripts/watch_index_core/handler.py index bf5cb6d9..36c4d457 100644 --- a/scripts/watch_index_core/handler.py +++ b/scripts/watch_index_core/handler.py @@ -12,7 +12,6 @@ import scripts.ingest_code as idx from scripts.workspace_state import ( _extract_repo_name_from_path, - _get_global_state_dir, get_cached_file_hash, log_watcher_activity as _log_activity, remove_cached_file, @@ -27,6 +26,7 @@ safe_print, ) from .rename import _rename_in_store +from .paths import is_internal_metadata_path class IndexHandler(FileSystemEventHandler): @@ -81,6 +81,9 @@ def _maybe_reload_excluder(self) -> None: except Exception: pass + def _is_internal_metadata_path(self, p: Path) -> bool: + return is_internal_metadata_path(p) + def _maybe_enqueue(self, src_path: str) -> None: self._maybe_reload_excluder() p = Path(src_path) @@ -95,15 +98,7 @@ def _maybe_enqueue(self, src_path: str) -> None: except ValueError: return - try: - if callable(_get_global_state_dir): - global_state_dir = _get_global_state_dir() - if global_state_dir is not None and 
p.is_relative_to(global_state_dir): - return - except (OSError, ValueError): - pass - - if any(part == ".codebase" for part in p.parts): + if self._is_internal_metadata_path(p): return # Git history manifests are handled by a separate ingestion pipeline and should still @@ -140,7 +135,7 @@ def on_deleted(self, event): p = Path(event.src_path).resolve() except Exception: return - if any(part == ".codebase" for part in p.parts): + if self._is_internal_metadata_path(p): return if not idx.is_indexable_file(p): return @@ -162,6 +157,42 @@ def on_moved(self, event): dest = Path(event.dest_path).resolve() except Exception: return + # Handle internal-boundary moves properly + src_internal = self._is_internal_metadata_path(src) + dest_internal = self._is_internal_metadata_path(dest) + if src_internal and dest_internal: + # Both internal -> ignore + return + if dest_internal: + # External -> internal: delete source, don't index destination + if idx.is_indexable_file(src): + try: + coll = self._resolve_collection(src) + deleted = False + if self.client is not None and coll is not None: + idx.delete_points_by_path(self.client, coll, str(src)) + # Clean up graph edges for the moved file + try: + idx.delete_graph_edges_by_path( + self.client, + coll, + caller_path=str(src), + ) + except Exception: + pass # Graph cleanup is best-effort + deleted = True + if deleted: + safe_print(f"[moved:external_to_internal] deleted {src}") + except Exception as exc: + safe_print(f"[moved:external_to_internal:error] {src}: {exc}") + finally: + self._invalidate_cache(src) + return + if src_internal: + # Internal -> external: index destination as new file + if idx.is_indexable_file(dest): + self._maybe_enqueue(str(dest)) + return if not idx.is_indexable_file(dest) and not idx.is_indexable_file(src): return try: @@ -174,18 +205,25 @@ def on_moved(self, event): if idx.is_indexable_file(src): try: coll = self._resolve_collection(src) + deleted = False if self.client is not None and coll is not None: 
idx.delete_points_by_path(self.client, coll, str(src)) - safe_print(f"[moved:ignored_dest_deleted_src] {src} -> {dest}") - src_repo_path = _detect_repo_for_file(src) - src_repo_name = _repo_name_or_none(src_repo_path) - try: - if src_repo_name: - remove_cached_file(str(src), src_repo_name) - except Exception: - pass - except Exception: - pass + # Clean up graph edges for the moved file + try: + idx.delete_graph_edges_by_path( + self.client, + coll, + caller_path=str(src), + ) + except Exception: + pass # Graph cleanup is best-effort + deleted = True + if deleted: + safe_print(f"[moved:ignored_dest_deleted_src] {src} -> {dest}") + except Exception as exc: + safe_print(f"[moved:ignored_dest_deleted_src:error] {src}: {exc}") + finally: + self._invalidate_cache(src) return except Exception: pass @@ -270,6 +308,14 @@ def _delete_points(self, path: Path, collection: str | None) -> None: return try: idx.delete_points_by_path(self.client, collection, str(path)) + try: + idx.delete_graph_edges_by_path( + self.client, + collection, + caller_path=str(path), + ) + except Exception: + pass safe_print(f"[deleted] {path} -> {collection}") except Exception: pass diff --git a/scripts/watch_index_core/init_maintenance.py b/scripts/watch_index_core/init_maintenance.py new file mode 100644 index 00000000..bb2dece9 --- /dev/null +++ b/scripts/watch_index_core/init_maintenance.py @@ -0,0 +1,191 @@ +"""Periodic bootstrap/init maintenance for the long-lived watcher.""" + +from __future__ import annotations + +import os +import subprocess +import sys +import threading +from pathlib import Path +from typing import Optional, Sequence + +from scripts.workspace_state import _cross_process_lock, _get_global_state_dir + +from . 
import config as watch_config +from .config import LOGGER +from .utils import get_boolean_env + +logger = LOGGER + +DEFAULT_INTERVAL_MINUTES = 120.0 +DEFAULT_COMMAND_TIMEOUT_SECONDS = 1800.0 + + +def _interval_seconds() -> float: + raw = os.environ.get("WATCH_INIT_MAINTENANCE_INTERVAL_MINUTES") + if raw is None: + raw = os.environ.get("INIT_MAINTENANCE_INTERVAL_MINUTES") + try: + minutes = float(raw if raw is not None else DEFAULT_INTERVAL_MINUTES) + except Exception: + minutes = DEFAULT_INTERVAL_MINUTES + return max(0.0, minutes * 60.0) + + +def _command_timeout_seconds() -> float: + try: + return max( + 1.0, + float( + os.environ.get( + "WATCH_INIT_MAINTENANCE_COMMAND_TIMEOUT_SECS", + str(DEFAULT_COMMAND_TIMEOUT_SECONDS), + ) + or DEFAULT_COMMAND_TIMEOUT_SECONDS + ), + ) + except Exception: + return DEFAULT_COMMAND_TIMEOUT_SECONDS + + +def _script_root() -> Path: + return Path(__file__).resolve().parents[1] + + +def _wait_for_qdrant_command(script_root: Path) -> list[str]: + return [str(script_root / "wait-for-qdrant.sh")] + + +def _python_script_command(script_root: Path, script_name: str) -> list[str]: + return [sys.executable or "python", str(script_root / script_name)] + + +def _maintenance_commands(script_root: Optional[Path] = None) -> list[list[str]]: + scripts = script_root or _script_root() + return [ + _wait_for_qdrant_command(scripts), + _python_script_command(scripts, "create_indexes.py"), + _python_script_command(scripts, "warm_all_collections.py"), + _python_script_command(scripts, "health_check.py"), + ] + + +def _env_for_subprocess() -> dict[str, str]: + env = os.environ.copy() + root = str(Path(__file__).resolve().parents[2]) + existing = env.get("PYTHONPATH") + env["PYTHONPATH"] = f"{root}{os.pathsep}{existing}" if existing else root + if str(watch_config.ROOT): + env.setdefault("WORKSPACE_PATH", str(watch_config.ROOT)) + env.setdefault("WORKDIR", str(watch_config.ROOT)) + env.setdefault("WORK_DIR", str(watch_config.ROOT)) + return env + + +def 
_run_command(command: Sequence[str], *, timeout: float, env: dict[str, str]) -> bool: + label = " ".join(str(part) for part in command) + logger.info("[init_maintenance] running: %s", label) + try: + result = subprocess.run( + list(command), + cwd=str(watch_config.ROOT), + env=env, + text=True, + capture_output=True, + timeout=timeout, + check=False, + ) + except subprocess.TimeoutExpired: + logger.error("[init_maintenance] timed out after %.0fs: %s", timeout, label) + return False + except Exception as exc: + logger.error("[init_maintenance] failed to start %s: %s", label, exc, exc_info=True) + return False + + if result.returncode == 0: + logger.info("[init_maintenance] completed: %s", label) + if result.stdout: + logger.debug("[init_maintenance] stdout for %s:\n%s", label, result.stdout[-4000:]) + if result.stderr: + logger.debug("[init_maintenance] stderr for %s:\n%s", label, result.stderr[-4000:]) + return True + + logger.warning( + "[init_maintenance] command failed rc=%s: %s\nstdout:\n%s\nstderr:\n%s", + result.returncode, + label, + (result.stdout or "")[-4000:], + (result.stderr or "")[-4000:], + ) + return False + + +def run_init_maintenance_once( + *, + commands: Optional[Sequence[Sequence[str]]] = None, + lock_path: Optional[Path] = None, +) -> bool: + """Run the existing init scripts once under a cross-process lock.""" + + timeout = _command_timeout_seconds() + env = _env_for_subprocess() + cmd_list = [list(cmd) for cmd in (commands or _maintenance_commands())] + if not cmd_list: + return True + + target_lock = lock_path + if target_lock is None: + try: + target_lock = _get_global_state_dir(str(watch_config.ROOT)) / "init_maintenance.lock" + except Exception: + target_lock = Path("/tmp/context-engine-init-maintenance.lock") + + with _cross_process_lock(target_lock): + for command in cmd_list: + if not _run_command(command, timeout=timeout, env=env): + return False + return True + + +def start_init_maintenance_worker() -> Optional[threading.Event]: + 
"""Start periodic init maintenance, controlled by watcher env vars.""" + + if not get_boolean_env("WATCH_INIT_MAINTENANCE_ENABLED", default=True): + return None + + interval = _interval_seconds() + if interval <= 0: + return None + + run_on_start = get_boolean_env("WATCH_INIT_MAINTENANCE_RUN_ON_START", default=False) + shutdown_event = threading.Event() + + def _worker() -> None: + if not run_on_start: + shutdown_event.wait(timeout=interval) + while not shutdown_event.is_set(): + try: + ok = run_init_maintenance_once() + if ok: + logger.info("[init_maintenance] pass completed") + else: + logger.warning("[init_maintenance] pass completed with failures") + except Exception: + logger.error("[init_maintenance] unexpected worker error", exc_info=True) + shutdown_event.wait(timeout=interval) + + thread = threading.Thread(target=_worker, name="init-maintenance", daemon=True) + thread.start() + logger.info( + "[init_maintenance] worker started interval=%.1fm run_on_start=%s", + interval / 60.0, + run_on_start, + ) + return shutdown_event + + +__all__ = [ + "DEFAULT_INTERVAL_MINUTES", + "run_init_maintenance_once", + "start_init_maintenance_worker", +] diff --git a/scripts/watch_index_core/paths.py b/scripts/watch_index_core/paths.py new file mode 100644 index 00000000..2e76cfb9 --- /dev/null +++ b/scripts/watch_index_core/paths.py @@ -0,0 +1,36 @@ +"""Path classification helpers shared by watcher components.""" + +from __future__ import annotations + +from pathlib import Path + +from scripts.workspace_state import ( + _get_global_state_dir, + INTERNAL_STATE_TOP_LEVEL_DIRS, +) + + +def is_internal_metadata_path(path: Path) -> bool: + """Return True when path points into watcher/internal metadata trees.""" + try: + # Deliberately match internal segments anywhere in the path to prevent + # indexing of nested metadata mirrors (for example in replicated roots). 
+ if any(part in INTERNAL_STATE_TOP_LEVEL_DIRS for part in path.parts): + return True + global_state_dir = _get_global_state_dir() + if global_state_dir is not None and path.is_relative_to(global_state_dir): + return True + except (OSError, ValueError): + return False + return False + + +def is_internal_top_level_path(path: Path, root: Path) -> bool: + """Return True when path's top-level segment under root is internal metadata.""" + try: + rel = path.resolve().relative_to(root.resolve()) + except Exception: + return False + if not rel.parts: + return False + return rel.parts[0] in INTERNAL_STATE_TOP_LEVEL_DIRS diff --git a/scripts/watch_index_core/processor.py b/scripts/watch_index_core/processor.py index 45e9db7e..822089e7 100644 --- a/scripts/watch_index_core/processor.py +++ b/scripts/watch_index_core/processor.py @@ -3,28 +3,46 @@ from __future__ import annotations import hashlib +import json import os import subprocess import sys +import atexit +import threading +import time +from collections import deque +from concurrent.futures import Future, ThreadPoolExecutor from datetime import datetime from pathlib import Path from typing import Dict, List, Optional +from qdrant_client import models + import scripts.ingest_code as idx +from scripts.pseudo_config import effective_pseudo_mode +from scripts.ingest.graph_edges import ( + normalize_caller_path as _normalize_graph_caller_path, +) from scripts.workspace_state import ( + _normalize_cache_key_path, _extract_repo_name_from_path, get_cached_file_hash, + list_pending_index_journal_entries, get_workspace_state, is_staging_enabled, log_watcher_activity as _log_activity, persist_indexing_config, remove_cached_file, + set_cached_file_hash, set_indexing_progress as _update_progress, set_indexing_started as _set_status_indexing, + update_index_journal_entry_status, update_indexing_status, ) +from . 
import config as watch_config +from .rename import _rename_in_store +from .paths import is_internal_metadata_path -from .config import QDRANT_URL, ROOT, ROOT_DIR, LOGGER as logger from .utils import ( _detect_repo_for_file, _get_collection_for_file, @@ -33,43 +51,294 @@ safe_log_error, ) +logger = watch_config.LOGGER + class _SkipUnchanged(Exception): """Sentinel exception to skip unchanged files in the watch loop.""" + def __init__(self, *, text: Optional[str] = None, file_hash: str = "") -> None: + super().__init__("unchanged") + self.text = text + self.file_hash = file_hash -def _process_git_history_manifest( + +def _is_internal_ignored_path(path: Path) -> bool: + return is_internal_metadata_path(path) + + +def _staging_requires_subprocess(state: Optional[Dict[str, object]]) -> bool: + """Return True only when dual-root staging is actually active for this repo.""" + if not (is_staging_enabled() and isinstance(state, dict)): + return False + + staging = state.get("staging") + if isinstance(staging, dict) and staging: + return True + + active_slug = str(state.get("active_repo_slug") or "").strip() + serving_slug = str(state.get("serving_repo_slug") or "").strip() + if serving_slug.endswith("_old"): + return True + if active_slug and serving_slug and active_slug != serving_slug: + return True + return False + + +def _env_int(name: str, default: int) -> int: + try: + raw = str(os.environ.get(name, str(default))).strip() + val = int(raw) + return val if val > 0 else default + except Exception: + return default + + +_GIT_HISTORY_MAX_WORKERS = _env_int("WATCH_GIT_HISTORY_MAX_WORKERS", 1) +_GIT_HISTORY_TIMEOUT_SECONDS = _env_int("WATCH_GIT_HISTORY_TIMEOUT_SECONDS", 0) +_GIT_HISTORY_EXECUTOR = ThreadPoolExecutor( + max_workers=_GIT_HISTORY_MAX_WORKERS, + thread_name_prefix="git-history", +) + + +def _shutdown_git_history_executor() -> None: + try: + _GIT_HISTORY_EXECUTOR.shutdown(wait=False) + except Exception: + pass + + +atexit.register(_shutdown_git_history_executor) 
+_GIT_HISTORY_INFLIGHT: set[str] = set() +_GIT_HISTORY_INFLIGHT_LOCK = threading.Lock() + + +def _manifest_key(p: Path) -> str: + try: + return str(p.resolve()) + except Exception: + return str(p) + + +def _manifest_stats(p: Path) -> tuple[str, int]: + run_id = "unknown" + commit_count = -1 + try: + with p.open("r", encoding="utf-8") as fh: + data = json.load(fh) + if isinstance(data, dict): + commits = data.get("commits") or [] + if isinstance(commits, list): + commit_count = len(commits) + name = p.name + run_id = name[:-5] if name.endswith(".json") else name + except Exception: + pass + return run_id, commit_count + + +def _run_git_history_ingest( p: Path, collection: str, repo_name: Optional[str], env_snapshot: Optional[Dict[str, str]] = None, ) -> None: - try: - script = ROOT_DIR / "scripts" / "ingest_history.py" - if not script.exists(): - return - cmd = [sys.executable or "python3", str(script), "--manifest-json", str(p)] - env = _build_subprocess_env(collection, repo_name, env_snapshot) + script = watch_config.ROOT_DIR / "scripts" / "ingest_history.py" + if not script.exists(): + raise RuntimeError(f"[git_history_manifest] ingest script missing: {script}") + + cmd = [sys.executable or "python3", "-m", "scripts.ingest_history", "--manifest-json", str(p)] + env = _build_subprocess_env(collection, repo_name, env_snapshot) + started = time.monotonic() + timeout = _GIT_HISTORY_TIMEOUT_SECONDS if _GIT_HISTORY_TIMEOUT_SECONDS > 0 else None + stdout_tail: deque[str] = deque(maxlen=20) + stderr_tail: deque[str] = deque(maxlen=20) + tail_lock = threading.Lock() + + def _tail_snapshot(tail: deque[str], limit: int = 5) -> str: + with tail_lock: + return " | ".join(list(tail)[-limit:]) + + def _stream_pipe(pipe, label: str, tail: deque[str], lock: threading.Lock) -> None: try: - print( - f"[git_history_manifest] launching ingest_history.py for {p} " - f"collection={collection} repo={repo_name}" - ) + for raw in iter(pipe.readline, ""): + line = (raw or "").rstrip() + if 
not line: + continue + with lock: + tail.append(line) + logger.info("[git_history_manifest][%s] %s", label, line) except Exception: pass - # Use subprocess.run for better error observability. - # NOTE: This blocks until ingest_history.py completes. If history ingestion - # is slow, this may need revisiting (e.g., revert to Popen fire-and-forget - # or run in a separate thread) to avoid blocking the watcher. - result = subprocess.run(cmd, env=env, capture_output=True, text=True, check=False) - if result.returncode != 0: - logger.warning( - "[git_history_manifest] ingest_history.py failed for %s: exit=%d stderr=%s", - p, result.returncode, (result.stderr or "")[:500], + finally: + try: + pipe.close() + except Exception: + pass + + proc: Optional[subprocess.Popen] = None + try: + proc = subprocess.Popen( + cmd, + cwd=str(watch_config.ROOT_DIR), + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + t_out = threading.Thread( + target=_stream_pipe, + args=(proc.stdout, "stdout", stdout_tail, tail_lock), + daemon=True, + ) + t_err = threading.Thread( + target=_stream_pipe, + args=(proc.stderr, "stderr", stderr_tail, tail_lock), + daemon=True, + ) + t_out.start() + t_err.start() + + deadline = (started + timeout) if timeout else None + timed_out = False + while True: + code = proc.poll() + if code is not None: + break + if deadline and time.monotonic() >= deadline: + timed_out = True + try: + proc.kill() + except Exception: + pass + break + time.sleep(0.2) + + # Ensure threads flush trailing output after process exit/kill. 
+ t_out.join(timeout=1.0) + t_err.join(timeout=1.0) + + if timed_out: + elapsed_ms = int((time.monotonic() - started) * 1000) + error_msg = ( + f"[git_history_manifest] ingest_history.py timeout for {p} after {elapsed_ms}ms " + f"(timeout={_GIT_HISTORY_TIMEOUT_SECONDS}s)" ) + if stderr_tail: + error_msg += f" stderr={_tail_snapshot(stderr_tail)}" + logger.warning(error_msg) + raise RuntimeError(error_msg) + + returncode = proc.wait(timeout=1.0) except Exception as e: - logger.warning("[git_history_manifest] error processing %s: %s", p, e) - return + logger.warning("[git_history_manifest] subprocess error for %s: %s", p, e) + try: + if proc and proc.poll() is None: + proc.kill() + except Exception: + pass + raise RuntimeError(f"[git_history_manifest] subprocess error for {p}: {e}") from e + + elapsed_ms = int((time.monotonic() - started) * 1000) + if returncode != 0: + error_msg = ( + f"[git_history_manifest] ingest_history.py failed for {p}: exit={returncode} " + f"elapsed_ms={elapsed_ms} stderr={_tail_snapshot(stderr_tail)}" + ) + logger.warning(error_msg) + raise RuntimeError(error_msg) + + logger.info( + "[git_history_manifest] completed for %s: exit=0 elapsed_ms=%d", + p, + elapsed_ms, + ) + if stdout_tail: + logger.info( + "[git_history_manifest] stdout tail for %s: %s", + p, + _tail_snapshot(stdout_tail), + ) + if stderr_tail: + logger.warning( + "[git_history_manifest] stderr tail for %s: %s", + p, + _tail_snapshot(stderr_tail), + ) + + +def _on_git_history_done(manifest_path: Path, collection: str, repo_name: Optional[str], future: Future) -> None: + manifest_key = _manifest_key(manifest_path) + with _GIT_HISTORY_INFLIGHT_LOCK: + _GIT_HISTORY_INFLIGHT.discard(manifest_key) + remaining = len(_GIT_HISTORY_INFLIGHT) + logger.info("[git_history_manifest] in-flight remaining=%d", remaining) + try: + future.result() + # Mark journal as done after successful completion + repo_path = _detect_repo_for_file(manifest_path) + if repo_path: + repo_key = str(repo_path) + 
_mark_journal_done(manifest_path, repo_key, repo_name) + logger.info("[git_history_manifest] marked journal as done: %s", manifest_path) + except Exception as e: + repo_path = _detect_repo_for_file(manifest_path) + repo_key = str(repo_path) if repo_path else "" + if repo_key: + _mark_journal_failed( + manifest_path, + repo_key, + repo_name, + f"git history worker failed for collection '{collection}': {e}", + ) + logger.warning( + "[git_history_manifest] worker crashed for %s (collection=%s, repo_key=%s): %s", + manifest_key, + collection, + repo_key or "", + e, + exc_info=True, + ) + + +def _process_git_history_manifest( + p: Path, + collection: str, + repo_name: Optional[str], + env_snapshot: Optional[Dict[str, str]] = None, +) -> None: + key = _manifest_key(p) + run_id, commit_count = _manifest_stats(p) + queued = 0 + with _GIT_HISTORY_INFLIGHT_LOCK: + if key in _GIT_HISTORY_INFLIGHT: + logger.info( + "[git_history_manifest] skip duplicate in-flight manifest: %s run_id=%s", + p, + run_id, + ) + return + _GIT_HISTORY_INFLIGHT.add(key) + queued = len(_GIT_HISTORY_INFLIGHT) + logger.info( + "[git_history_manifest] queued ingest_history.py for %s run_id=%s commits=%d collection=%s repo=%s in_flight=%d", + p, + run_id, + commit_count, + collection, + repo_name, + queued, + ) + future = _GIT_HISTORY_EXECUTOR.submit( + _run_git_history_ingest, + p, + collection, + repo_name, + env_snapshot, + ) + future.add_done_callback(lambda fut, manifest_path=p, coll=collection, rn=repo_name: _on_git_history_done(manifest_path, coll, rn, fut)) def _advance_progress( @@ -92,6 +361,216 @@ def _advance_progress( pass +def _mark_journal_done(path: Path, repo_key: str, repo_name: Optional[str]) -> None: + try: + update_index_journal_entry_status( + str(path), + status="done", + workspace_path=repo_key, + repo_name=repo_name, + ) + except Exception: + pass + + +def _mark_journal_failed( + path: Path, + repo_key: str, + repo_name: Optional[str], + error: str, +) -> None: + try: + 
update_index_journal_entry_status( + str(path), + status="failed", + error=error, + workspace_path=repo_key, + repo_name=repo_name, + remove_on_done=False, + ) + except Exception: + pass + + +def _path_has_indexed_points(client, collection: str, path: Path) -> Optional[bool]: + try: + filt = models.Filter( + must=[ + models.FieldCondition( + key="metadata.path", match=models.MatchValue(value=str(path)) + ) + ] + ) + points, _ = client.scroll( + collection_name=collection, + scroll_filter=filt, + with_payload=False, + with_vectors=False, + limit=1, + ) + return bool(points) + except Exception: + return None + + +def _verify_delete_committed(client, collection: str, path: Path) -> bool: + has_points = _path_has_indexed_points(client, collection, path) + return has_points is False + + +def _path_has_graph_edges(client, collection: str, path: Path) -> Optional[bool]: + graph_collection = f"{collection}_graph" + try: + # Graph edges normalize paths (Windows -> POSIX separators). Verification must + # query using the same normalization to avoid false "deleted" reports. + raw_path = str(path) + candidates: list[str] = [] + try: + norm_path = str(_normalize_graph_caller_path(raw_path) or "").strip() + if norm_path: + candidates.append(norm_path) + except Exception: + pass + # Back-compat: also consider the raw string and a slash-normalized form in case + # older data was written without normalization. 
+ raw_slash = raw_path.replace("\\", "/").strip() + for v in (raw_slash, raw_path.strip()): + if v and v not in candidates: + candidates.append(v) + + match_obj = ( + models.MatchAny(any=candidates) + if len(candidates) > 1 + else models.MatchValue(value=(candidates[0] if candidates else raw_path)) + ) + filt = models.Filter( + must=[ + models.FieldCondition( + key="caller_path", match=match_obj + ) + ] + ) + points, _ = client.scroll( + collection_name=graph_collection, + scroll_filter=filt, + with_payload=False, + with_vectors=False, + limit=1, + ) + return bool(points) + except Exception as e: + # Missing graph collection means there are no graph edges to verify. + err = str(e).lower() + if "404" in err or "not found" in err or "doesn't exist" in err: + return False + return None + + +def _verify_graph_delete_committed(client, collection: str, path: Path) -> bool: + has_edges = _path_has_graph_edges(client, collection, path) + return has_edges is False + + +def _verify_upsert_committed( + client, + collection: str, + path: Path, + repo_name: Optional[str], + expected_file_hash: Optional[str], + source_text: Optional[str] = None, +) -> bool: + indexed_hash = str( + idx.get_indexed_file_hash(client, collection, str(path)) or "" + ).strip() + expected_hash = str(expected_file_hash or "").strip() + if expected_hash: + if bool(indexed_hash) and indexed_hash == expected_hash: + return True + # Empty/whitespace-only files can legitimately have no indexed points/hash. 
+ try: + if source_text is not None and not source_text.strip(): + has_points = _path_has_indexed_points(client, collection, path) + return has_points is False + except Exception: + pass + return False + has_points = _path_has_indexed_points(client, collection, path) + return has_points is True + + +def _verify_and_update_journal_for_upsert( + p: Path, + client, + collection: str, + repo_key: str, + repo_name: Optional[str], + journal_content_hash: str, + *, + text: Optional[str] = None, + file_hash: Optional[str] = None, +) -> None: + source_text = text + expected_hash = str(file_hash or "").strip() + if source_text is None or not expected_hash: + read_text, read_hash = _read_text_and_sha1(p) + if source_text is None: + source_text = read_text + if not expected_hash: + expected_hash = read_hash + expected_hash = expected_hash or journal_content_hash + if _verify_upsert_committed( + client, + collection, + p, + repo_name, + expected_hash or None, + source_text=source_text, + ): + _mark_journal_done(p, repo_key, repo_name) + else: + _mark_journal_failed( + p, + repo_key, + repo_name, + "upsert_verification_failed", + ) + + +def _finalize_journal_after_index_attempt( + path: Path, + client, + collection: str | None, + repo_key: str, + repo_name: Optional[str], + *, + force_upsert: bool, + journal_content_hash: str, + text: Optional[str] = None, + file_hash: Optional[str] = None, + default_error: Optional[str] = None, + skip_verify_reason: Optional[str] = None, +) -> None: + if force_upsert and client is not None and collection is not None: + # If another worker currently owns this file lock, leave the journal entry + # pending for retry instead of recording a false verification failure. 
+ if skip_verify_reason == "file_locked": + return + _verify_and_update_journal_for_upsert( + path, + client, + collection, + repo_key, + repo_name, + journal_content_hash, + text=text, + file_hash=file_hash, + ) + elif default_error: + _mark_journal_failed(path, repo_key, repo_name, default_error) + else: + _mark_journal_done(path, repo_key, repo_name) + + def _build_subprocess_env( collection: str | None, repo_name: str | None, @@ -105,8 +584,8 @@ def _build_subprocess_env( pass if collection: env["COLLECTION_NAME"] = collection - if QDRANT_URL: - env["QDRANT_URL"] = QDRANT_URL + if watch_config.QDRANT_URL: + env["QDRANT_URL"] = watch_config.QDRANT_URL if repo_name: env["REPO_NAME"] = repo_name return env @@ -114,6 +593,7 @@ def _build_subprocess_env( def _maybe_handle_staging_file( path: Path, + client, collection: str | None, repo_name: str | None, repo_key: str, @@ -121,27 +601,46 @@ def _maybe_handle_staging_file( state_env: Optional[Dict[str, str]], repo_progress: Dict[str, int], started_at: str, + *, + force_upsert: bool = False, + journal_content_hash: str = "", ) -> bool: - if not (is_staging_enabled() and state_env and collection): + if not (state_env and collection): return False - _text, file_hash = _read_text_and_sha1(path) + source_text, file_hash = _read_text_and_sha1(path) if file_hash: try: cached_hash = get_cached_file_hash(str(path), repo_name) if repo_name else None except Exception: cached_hash = None if cached_hash and cached_hash == file_hash: + if force_upsert and client is not None: + if _verify_upsert_committed( + client, + collection, + path, + repo_name, + file_hash or journal_content_hash or None, + source_text=source_text, + ): + safe_print(f"[skip_unchanged] {path} (hash match)") + _log_activity(repo_key, "skipped", path, {"reason": "hash_unchanged"}) + _mark_journal_done(path, repo_key, repo_name) + _advance_progress(repo_progress, repo_key, repo_files, started_at, path) + return True # Fast path: skip if content hash matches cached 
hash (file unchanged) # Safety: startup health check clears stale cache per-repo - safe_print(f"[skip_unchanged] {path} (hash match)") - _log_activity(repo_key, "skipped", path, {"reason": "hash_unchanged"}) - _advance_progress(repo_progress, repo_key, repo_files, started_at, path) - return True + if not force_upsert: + safe_print(f"[skip_unchanged] {path} (hash match)") + _log_activity(repo_key, "skipped", path, {"reason": "hash_unchanged"}) + _advance_progress(repo_progress, repo_key, repo_files, started_at, path) + return True cmd = [ sys.executable or "python3", - str(ROOT_DIR / "scripts" / "ingest_code.py"), + "-m", + "scripts.ingest_code", "--root", str(path), "--no-skip-unchanged", @@ -178,6 +677,19 @@ def _maybe_handle_staging_file( ) else: safe_print(f"[indexed_subprocess] {path} -> {collection}") + _finalize_journal_after_index_attempt( + path, + client, + collection, + repo_key, + repo_name, + force_upsert=force_upsert, + journal_content_hash=journal_content_hash, + text=source_text, + file_hash=file_hash, + ) + if result.returncode != 0 and force_upsert: + _mark_journal_failed(path, repo_key, repo_name, "subprocess_index_failed") _advance_progress(repo_progress, repo_key, repo_files, started_at, path) return True @@ -218,18 +730,92 @@ def _process_paths( pass repo_progress: Dict[str, int] = {key: 0 for key in repo_groups.keys()} + repo_pending_journal_ops: Dict[str, Dict[str, Dict[str, str]]] = {} + repo_move_source_for_dest: Dict[str, Dict[str, str]] = {} + move_dest_keys: set[str] = set() + move_source_keys: set[str] = set() + for repo_path in repo_groups.keys(): + try: + repo_name = _extract_repo_name_from_path(repo_path) + entries = list_pending_index_journal_entries(repo_path, repo_name) + repo_pending_journal_ops[repo_path] = {} + upserts_by_hash: Dict[str, List[str]] = {} + deletes_by_hash: Dict[str, List[str]] = {} + for rec in entries: + path_key = _normalize_cache_key_path(str(rec.get("path") or "")) + op_type = str(rec.get("op_type") or 
"").strip().lower() + content_hash = str(rec.get("content_hash") or "").strip().lower() + if not path_key: + continue + repo_pending_journal_ops[repo_path][path_key] = { + "op_type": op_type, + "content_hash": content_hash, + } + if not content_hash: + continue + if op_type == "upsert": + upserts_by_hash.setdefault(content_hash, []).append(path_key) + elif op_type == "delete": + deletes_by_hash.setdefault(content_hash, []).append(path_key) + pairs: Dict[str, str] = {} + for content_hash, dest_paths in upserts_by_hash.items(): + src_paths = deletes_by_hash.get(content_hash) or [] + if not src_paths: + continue + src_idx = 0 + for dest_key in dest_paths: + while src_idx < len(src_paths) and src_paths[src_idx] == dest_key: + src_idx += 1 + if src_idx >= len(src_paths): + break + src_key = src_paths[src_idx] + src_idx += 1 + pairs[dest_key] = src_key + move_dest_keys.add(dest_key) + move_source_keys.add(src_key) + repo_move_source_for_dest[repo_path] = pairs + except Exception: + repo_pending_journal_ops[repo_path] = {} + repo_move_source_for_dest[repo_path] = {} + + unique_paths = sorted( + unique_paths, + key=lambda p: ( + 0 + if _normalize_cache_key_path(str(p)) in move_dest_keys + else (2 if _normalize_cache_key_path(str(p)) in move_source_keys else 1), + str(p), + ), + ) + completed_move_sources: set[str] = set() for p in unique_paths: repo_path = _detect_repo_for_file(p) or Path(workspace_path) repo_key = str(repo_path) repo_files = repo_groups.get(repo_key, []) repo_name = _extract_repo_name_from_path(repo_key) + path_key = _normalize_cache_key_path(str(p)) + if path_key in completed_move_sources: + _advance_progress(repo_progress, repo_key, repo_files, started_at, p) + continue + journal_rec = repo_pending_journal_ops.get(repo_key, {}).get(path_key, {}) + journal_op = str(journal_rec.get("op_type") or "").strip().lower() + force_delete = journal_op == "delete" + force_upsert = journal_op == "upsert" + journal_content_hash = str(journal_rec.get("content_hash") 
or "").strip().lower() + if _is_internal_ignored_path(p): + _log_activity(repo_key, "skipped", p, {"reason": "internal_ignored_path"}) + # Internal metadata paths should never drive indexing or collection creation. + # If they entered the journal via drift repair, mark done and drop. + _mark_journal_done(p, repo_key, repo_name) + _advance_progress(repo_progress, repo_key, repo_files, started_at, p) + continue collection = _get_collection_for_file(p) state_env: Optional[Dict[str, str]] = None try: st = get_workspace_state(repo_key, repo_name) if get_workspace_state else None if isinstance(st, dict): - if is_staging_enabled(): + if _staging_requires_subprocess(st): state_env = st.get("indexing_env") except Exception: state_env = None @@ -240,31 +826,118 @@ def _process_paths( p, collection, repo_name, - env_snapshot=(state_env if is_staging_enabled() else None), + env_snapshot=state_env, ) except Exception as exc: safe_print(f"[commit_ingest_error] {p}: {exc}") _advance_progress(repo_progress, repo_key, repo_files, started_at, p) continue - if not p.exists(): + if force_upsert and not p.exists(): + _log_activity(repo_key, "skipped", p, {"reason": "upsert_missing_file"}) + _mark_journal_failed( + p, + repo_key, + repo_name, + "upsert_missing_file", + ) + _advance_progress(repo_progress, repo_key, repo_files, started_at, p) + continue + + if force_upsert and client is not None and collection is not None: + move_src_key = repo_move_source_for_dest.get(repo_key, {}).get(path_key) + if move_src_key: + move_src_path = Path(move_src_key) + src_collection = _get_collection_for_file(move_src_path) + try: + moved_count, renamed_hash = _rename_in_store( + client, + src_collection, + move_src_path, + p, + collection, + ) + except Exception: + moved_count, renamed_hash = -1, None + if moved_count and moved_count > 0: + try: + if repo_name: + remove_cached_file(str(move_src_path), repo_name) + except Exception: + pass + final_hash = renamed_hash or journal_content_hash + try: + if 
repo_name and final_hash: + set_cached_file_hash(str(p), final_hash, repo_name) + except Exception: + pass + _log_activity( + repo_key, + "moved", + p, + {"from": str(move_src_path), "chunks": int(moved_count)}, + ) + _mark_journal_done(p, repo_key, repo_name) + _mark_journal_done(move_src_path, repo_key, repo_name) + completed_move_sources.add(move_src_key) + _advance_progress(repo_progress, repo_key, repo_files, started_at, p) + continue + + if force_delete or not p.exists(): + deleted_ok = False if client is not None: try: idx.delete_points_by_path(client, collection, str(p)) + try: + idx.delete_graph_edges_by_path( + client, + collection, + caller_path=str(p), + repo=repo_name, + ) + # Repo tags can drift over time, so always follow a repo-scoped + # delete with a path-only sweep to remove any stale rows left under + # an older/default repo tag. + if repo_name: + idx.delete_graph_edges_by_path( + client, + collection, + caller_path=str(p), + repo=None, + ) + except Exception as graph_exc: + safe_print(f"[deleted:graph_failed] {p} -> {collection}: {graph_exc}") safe_print(f"[deleted] {p} -> {collection}") + deleted_ok = True except Exception: - pass + deleted_ok = False + if deleted_ok and client is not None and collection is not None: + deleted_ok = _verify_delete_committed(client, collection, p) + if deleted_ok and client is not None and collection is not None: + verify_graph_delete = get_boolean_env("WATCH_VERIFY_GRAPH_DELETE", True) + if verify_graph_delete: + deleted_ok = _verify_graph_delete_committed(client, collection, p) try: if repo_name: remove_cached_file(str(p), repo_name) except Exception: pass _log_activity(repo_key, "deleted", p) + if deleted_ok: + _mark_journal_done(p, repo_key, repo_name) + else: + _mark_journal_failed( + p, + repo_key, + repo_name, + "delete_points_or_graph_failed", + ) _advance_progress(repo_progress, repo_key, repo_files, started_at, p) continue if _maybe_handle_staging_file( p, + client, collection, repo_name, repo_key, @@ 
-272,17 +945,39 @@ def _process_paths( state_env, repo_progress, started_at, + force_upsert=force_upsert, + journal_content_hash=journal_content_hash, ): continue if client is not None and model is not None: try: + verify_context: Dict[str, Optional[str]] = {} ok = _run_indexing_strategy( - p, client, model, collection, vector_name, model_dim, repo_name + p, + client, + model, + collection, + vector_name, + model_dim, + repo_name, + force_upsert=force_upsert, + verify_context=verify_context if force_upsert else None, ) - except _SkipUnchanged: + except _SkipUnchanged as exc: status = "skipped" safe_print(f"[{status}] {p} -> {collection}") _log_activity(repo_key, "skipped", p, {"reason": "hash_unchanged"}) + _finalize_journal_after_index_attempt( + p, + client, + collection, + repo_key, + repo_name, + force_upsert=force_upsert, + journal_content_hash=journal_content_hash, + text=exc.text, + file_hash=exc.file_hash, + ) _advance_progress(repo_progress, repo_key, repo_files, started_at, p) continue except Exception: @@ -295,6 +990,7 @@ def _process_paths( "file": str(p), }, ) + _mark_journal_failed(p, repo_key, repo_name, "indexing_error") _advance_progress(repo_progress, repo_key, repo_files, started_at, p) continue @@ -306,10 +1002,35 @@ def _process_paths( except Exception: size = None _log_activity(repo_key, "indexed", p, {"file_size": size}) + _finalize_journal_after_index_attempt( + p, + client, + collection, + repo_key, + repo_name, + force_upsert=force_upsert, + journal_content_hash=journal_content_hash, + text=verify_context.get("text"), + file_hash=verify_context.get("file_hash"), + skip_verify_reason=verify_context.get("skip_verify_reason"), + ) else: _log_activity( repo_key, "skipped", p, {"reason": "no-change-or-error"} ) + _finalize_journal_after_index_attempt( + p, + client, + collection, + repo_key, + repo_name, + force_upsert=force_upsert, + journal_content_hash=journal_content_hash, + text=verify_context.get("text"), + 
file_hash=verify_context.get("file_hash"), + default_error="no_change_or_error", + skip_verify_reason=verify_context.get("skip_verify_reason"), + ) _advance_progress(repo_progress, repo_key, repo_files, started_at, p) else: safe_print(f"Not processing locally: {p}") @@ -333,7 +1054,7 @@ def _read_text_and_sha1(path: Path) -> tuple[Optional[str], str]: text = path.read_text(encoding="utf-8", errors="ignore") except Exception: text = None - if not text: + if text is None: return text, "" try: file_hash = hashlib.sha1(text.encode("utf-8", errors="ignore")).hexdigest() @@ -350,69 +1071,111 @@ def _run_indexing_strategy( vector_name: str, model_dim: int, repo_name: str | None, + force_upsert: bool = False, + *, + verify_context: Optional[Dict[str, Optional[str]]] = None, ) -> bool: if collection is None: return False - try: - idx.ensure_collection_and_indexes_once(client, collection, model_dim, vector_name) - except Exception: - pass text, file_hash = _read_text_and_sha1(path) + if verify_context is not None: + verify_context["text"] = text + verify_context["file_hash"] = file_hash + verify_context["skip_verify_reason"] = None ok = False if text is not None: try: language = idx.detect_language(path) except Exception: language = "" + try: + is_text_like = bool(idx.is_text_like_language(language)) + except Exception: + is_text_like = False if file_hash: try: cached_hash = get_cached_file_hash(str(path), repo_name) if repo_name else None except Exception: cached_hash = None - if cached_hash and cached_hash == file_hash: + if cached_hash and cached_hash == file_hash and not force_upsert: ok = True - raise _SkipUnchanged() - try: - use_smart, smart_reason = idx.should_use_smart_reindexing(str(path), file_hash) - except Exception: - use_smart, smart_reason = False, "smart_check_failed" - # Bootstrap: if we have no symbol cache yet, still run smart path once - bootstrap = smart_reason == "no_cached_symbols" - if use_smart or bootstrap: - msg_kind = ( - "smart reindexing" - if 
use_smart - else "bootstrap (no_cached_symbols) for smart reindex" - ) - safe_print( - f"[SMART_REINDEX][watcher] Using {msg_kind} for {path} ({smart_reason})" - ) + raise _SkipUnchanged(text=text, file_hash=file_hash) + + # Repair upserts must materialize points when a path is missing in Qdrant. + # Smart reindex can return "skipped" for unchanged symbols, which is valid + # only when points already exist. + force_full_reindex = False + if force_upsert and client is not None: try: - status = idx.process_file_with_smart_reindexing( - path, - text, - language, - client, - collection, - repo_name, - model, - vector_name, + existing_hash = str( + idx.get_indexed_file_hash(client, collection, str(path)) or "" + ).strip() + except Exception: + existing_hash = "" + if not existing_hash: + has_points = _path_has_indexed_points(client, collection, path) + if has_points is not True: + force_full_reindex = True + + if not is_text_like: + try: + use_smart, smart_reason = idx.should_use_smart_reindexing(str(path), file_hash) + except Exception: + use_smart, smart_reason = False, "smart_check_failed" + # Bootstrap: if we have no symbol cache yet, still run smart path once + bootstrap = smart_reason == "no_cached_symbols" + if (use_smart or bootstrap) and not force_full_reindex: + msg_kind = ( + "smart reindexing" + if use_smart + else "bootstrap (no_cached_symbols) for smart reindex" ) - ok = status in ("success", "skipped") - except Exception as exc: safe_print( - f"[SMART_REINDEX][watcher] Smart reindexing failed for {path}: {exc}" + f"[SMART_REINDEX][watcher] Using {msg_kind} for {path} ({smart_reason})" ) - ok = False - else: - safe_print( - f"[SMART_REINDEX][watcher] Using full reindexing for {path} ({smart_reason})" - ) - # Fallback: full single-file reindex. Pseudo/tags are inlined by default; - # when PSEUDO_DEFER_TO_WORKER=1 we run base-only and rely on backfill. 
+ try: + status = idx.process_file_with_smart_reindexing( + path, + text, + language, + client, + collection, + repo_name, + model, + vector_name, + model_dim=model_dim, + ) + ok = status in ("success", "skipped") + except Exception as exc: + safe_print( + f"[SMART_REINDEX][watcher] Smart reindexing failed for {path}: {exc}" + ) + ok = False + else: + if force_full_reindex: + safe_print( + f"[SMART_REINDEX][watcher] Forcing full reindex for {path} " + "(force_upsert_missing_points)" + ) + safe_print( + f"[SMART_REINDEX][watcher] Using full reindexing for {path} ({smart_reason})" + ) + # Fallback: full single-file reindex. Pseudo/tags are inlined by default; + # when PSEUDO_DEFER_TO_WORKER=1 we run base-only and rely on backfill. if not ok: - pseudo_mode = "off" if get_boolean_env("PSEUDO_DEFER_TO_WORKER") else "full" + try: + idx.ensure_collection_and_indexes_once( + client, collection, model_dim, vector_name + ) + except Exception: + pass + # PSEUDO_DEFER_TO_WORKER is a foreground/background semantics knob; it should + # only disable inline pseudo/tags generation when the backfill worker is enabled. 
+ pseudo_mode = effective_pseudo_mode( + defer_to_worker=get_boolean_env("PSEUDO_DEFER_TO_WORKER"), + backfill_enabled=get_boolean_env("PSEUDO_BACKFILL_ENABLED"), + ) ok = idx.index_single_file( client, model, @@ -423,7 +1186,16 @@ def _run_indexing_strategy( skip_unchanged=False, pseudo_mode=pseudo_mode, repo_name_for_cache=repo_name, + preloaded_text=text, + preloaded_file_hash=file_hash, + preloaded_language=language if text is not None else None, ) + if force_upsert and not ok and verify_context is not None: + try: + if idx.is_file_locked(str(path)): + verify_context["skip_verify_reason"] = "file_locked" + except Exception: + pass return ok diff --git a/scripts/watch_index_core/pseudo.py b/scripts/watch_index_core/pseudo.py index dc7bb0a8..b4b573f5 100644 --- a/scripts/watch_index_core/pseudo.py +++ b/scripts/watch_index_core/pseudo.py @@ -8,6 +8,7 @@ from typing import Optional import scripts.ingest_code as idx +from . import config as watch_config from .utils import get_boolean_env from scripts.workspace_state import ( _cross_process_lock, @@ -17,8 +18,6 @@ is_multi_repo_mode, ) -from .config import ROOT - logger = logging.getLogger(__name__) @@ -27,6 +26,8 @@ def _start_pseudo_backfill_worker( default_collection: str, model_dim: int, vector_name: str, + *, + allow_default_collection_fallback: bool = True, ) -> Optional[threading.Event]: """Start a daemon thread that periodically backfills pseudo/tags. @@ -34,7 +35,12 @@ def _start_pseudo_backfill_worker( or None if the worker was not started (disabled via env). """ - if not get_boolean_env("PSEUDO_DEFER_TO_WORKER"): + # This worker is controlled by PSEUDO_BACKFILL_ENABLED (pseudo/tags) and/or + # GRAPH_EDGES_BACKFILL (graph edges). PSEUDO_DEFER_TO_WORKER only controls + # whether the foreground index path generates pseudo inline. 
+ pseudo_backfill_enabled = get_boolean_env("PSEUDO_BACKFILL_ENABLED") + graph_backfill_enabled = get_boolean_env("GRAPH_EDGES_BACKFILL") + if not (pseudo_backfill_enabled or graph_backfill_enabled): return None try: @@ -49,20 +55,36 @@ def _start_pseudo_backfill_worker( max_points = 256 if max_points <= 0: max_points = 1 + try: + graph_max_files = int( + os.environ.get("GRAPH_EDGES_BACKFILL_MAX_FILES", "128") or 128 + ) + except Exception: + graph_max_files = 128 + if graph_max_files <= 0: + graph_max_files = 1 shutdown_event = threading.Event() def _worker() -> None: while not shutdown_event.is_set(): try: + pseudo_backfill_on = get_boolean_env("PSEUDO_BACKFILL_ENABLED") + graph_backfill_on = get_boolean_env("GRAPH_EDGES_BACKFILL") try: - mappings = get_collection_mappings(search_root=str(ROOT)) + mappings = get_collection_mappings(search_root=str(watch_config.ROOT)) except Exception: mappings = [] if not mappings: - mappings = [ - {"repo_name": None, "collection_name": default_collection}, - ] + # Do not fall back to the default collection unless startup explicitly + # allowed the watcher to touch it. This keeps background backfill from + # recreating collections that the caller intentionally left alone. 
+ if is_multi_repo_mode() or not allow_default_collection_fallback: + mappings = [] + else: + mappings = [ + {"repo_name": None, "collection_name": default_collection}, + ] for mapping in mappings: if shutdown_event.is_set(): break @@ -74,21 +96,52 @@ def _worker() -> None: if is_multi_repo_mode() and repo_name: state_dir = _get_repo_state_dir(repo_name) else: - state_dir = _get_global_state_dir(str(ROOT)) + state_dir = _get_global_state_dir(str(watch_config.ROOT)) lock_path = state_dir / "pseudo.lock" with _cross_process_lock(lock_path): - processed = idx.pseudo_backfill_tick( - client, - coll, - repo_name=repo_name, - max_points=max_points, - dim=model_dim, - vector_name=vector_name, - ) - if processed: - logger.info( - "[pseudo_backfill] repo=%s collection=%s processed=%d", - repo_name or "default", coll, processed, + if pseudo_backfill_on: + processed = idx.pseudo_backfill_tick( + client, + coll, + repo_name=repo_name, + max_points=max_points, + dim=model_dim, + vector_name=vector_name, + ) + if processed: + logger.info( + "[pseudo_backfill] repo=%s collection=%s processed=%d", + repo_name or "default", + coll, + processed, + ) + # Optional: backfill graph edge collection from main points. + # Controlled separately because it may scan large collections over time. + # Run under its own lock to avoid blocking pseudo/tag backfill workers. 
+ if graph_backfill_on: + try: + graph_lock_path = state_dir / "graph_edges.lock" + with _cross_process_lock(graph_lock_path): + files_done = idx.graph_edges_backfill_tick( + client, + coll, + repo_name=repo_name, + max_files=graph_max_files, + ) + if files_done: + logger.info( + "[graph_backfill] repo=%s collection=%s files=%d", + repo_name or "default", + coll, + files_done, + ) + except Exception as exc: + logger.error( + "[graph_backfill] error repo=%s collection=%s: %s", + repo_name or "default", + coll, + exc, + exc_info=True, ) except Exception as exc: logger.error( @@ -110,4 +163,3 @@ def _worker() -> None: __all__ = ["_start_pseudo_backfill_worker"] - diff --git a/scripts/watch_index_core/queue.py b/scripts/watch_index_core/queue.py index ede8835b..ca420842 100644 --- a/scripts/watch_index_core/queue.py +++ b/scripts/watch_index_core/queue.py @@ -3,10 +3,11 @@ from __future__ import annotations import threading +import time from pathlib import Path from typing import Callable, Iterable, List, Set -from .config import DELAY_SECS, LOGGER +from .config import DELAY_SECS, LOGGER, RECENT_FINGERPRINT_TTL_SECS class ChangeQueue: @@ -16,15 +17,21 @@ def __init__(self, process_cb: Callable[[List[Path]], None]): self._lock = threading.Lock() self._paths: Set[Path] = set() self._pending: Set[Path] = set() + self._forced_paths: Set[Path] = set() + self._pending_forced: Set[Path] = set() self._timer: threading.Timer | None = None self._process_cb = process_cb # Serialize processing to avoid concurrent use of TextEmbedding/QdrantClient self._processing_lock = threading.Lock() + self._recent_fingerprints: dict[Path, tuple[tuple[int, int], float]] = {} - def add(self, p: Path) -> None: + def add(self, p: Path, *, force: bool = False) -> None: with self._lock: + already_queued = p in self._paths self._paths.add(p) - if self._timer is not None: + if force: + self._forced_paths.add(p) + if self._timer is not None and not already_queued: try: self._timer.cancel() except 
Exception as exc: @@ -32,21 +39,101 @@ def add(self, p: Path) -> None: "Failed to cancel timer in ChangeQueue.add", extra={"error": str(exc)}, ) - self._timer = threading.Timer(DELAY_SECS, self._flush) - self._timer.daemon = True - self._timer.start() + if self._timer is None or not already_queued: + self._timer = threading.Timer(DELAY_SECS, self._flush) + self._timer.daemon = True + self._timer.start() + + def stats(self) -> dict[str, int | bool]: + with self._lock: + return { + "queued": len(self._paths), + "pending": len(self._pending), + "forced": len(self._forced_paths), + "pending_forced": len(self._pending_forced), + "processing": self._processing_lock.locked(), + } + + def _fingerprint_path(self, p: Path) -> tuple[int, int] | None: + try: + st = p.stat() + return ( + int(getattr(st, "st_size", 0)), + int(getattr(st, "st_mtime_ns", int(st.st_mtime * 1e9))), + ) + except Exception: + return None + + def _filter_recent_paths( + self, + paths: Iterable[Path], + *, + forced_paths: Iterable[Path] | None = None, + ) -> list[Path]: + ttl = float(RECENT_FINGERPRINT_TTL_SECS) + forced = set(forced_paths or []) + if ttl <= 0: + return list(paths) + + now = time.time() + keep: list[Path] = [] + for p in paths: + if p in forced: + keep.append(p) + continue + fp = self._fingerprint_path(p) + if fp is None: + keep.append(p) + continue + prev = self._recent_fingerprints.get(p) + if prev is not None: + prev_fp, prev_ts = prev + if prev_fp == fp and (now - prev_ts) < ttl: + continue + keep.append(p) + return keep + + def _mark_recent_paths(self, paths: Iterable[Path]) -> None: + ttl = float(RECENT_FINGERPRINT_TTL_SECS) + if ttl <= 0: + return + now = time.time() + for p in paths: + fp = self._fingerprint_path(p) + if fp is None: + continue + self._recent_fingerprints[p] = (fp, now) + # Keep at least a 1s grace for small TTLs while using a proportional + # buffer for larger TTLs so stale handled fingerprints age out cleanly. 
+ cutoff = now - max(ttl * 2.0, ttl + 1.0) + stale = [p for p, (_fp, ts) in self._recent_fingerprints.items() if ts < cutoff] + for p in stale: + self._recent_fingerprints.pop(p, None) + + def _drain_pending(self) -> tuple[list[Path], Set[Path]] | None: + with self._lock: + if not self._pending: + return None + todo = list(self._pending) + todo_forced = {p for p in todo if p in self._pending_forced} + self._pending.clear() + self._pending_forced.clear() + return todo, todo_forced def _flush(self) -> None: # Grab current batch with self._lock: paths = list(self._paths) + forced_paths = {p for p in paths if p in self._forced_paths} self._paths.clear() + self._forced_paths.difference_update(paths) self._timer = None # Try to run the processor exclusively; if busy, queue and return if not self._processing_lock.acquire(blocking=False): with self._lock: self._pending.update(paths) + self._pending_forced.update(forced_paths) if self._timer is None: # schedule a follow-up flush to pick up pending when free self._timer = threading.Timer(DELAY_SECS, self._flush) @@ -56,15 +143,24 @@ def _flush(self) -> None: try: # Per-file locking in index_single_file handles indexer/watcher coordination todo: Iterable[Path] = paths + todo_forced: Set[Path] = set(forced_paths) while True: + filtered_todo = self._filter_recent_paths(todo, forced_paths=todo_forced) + if not filtered_todo: + pending = self._drain_pending() + if pending is None: + break + todo, todo_forced = pending + continue try: - self._process_cb(list(todo)) + self._process_cb(list(filtered_todo)) + self._mark_recent_paths(filtered_todo) except Exception as exc: # Log processing error via structured logging try: LOGGER.error( "Processing batch failed in ChangeQueue._flush", - extra={"error": str(exc), "batch_size": len(list(todo))}, + extra={"error": str(exc), "batch_size": len(filtered_todo)}, exc_info=True, ) except Exception as inner_exc: # pragma: no cover - logging fallback @@ -79,11 +175,10 @@ def _flush(self) -> 
None: except Exception: pass # Last resort: can't even print # drain any pending accumulated during processing - with self._lock: - if not self._pending: - break - todo = list(self._pending) - self._pending.clear() + pending = self._drain_pending() + if pending is None: + break + todo, todo_forced = pending finally: self._processing_lock.release() diff --git a/scripts/watch_index_core/utils.py b/scripts/watch_index_core/utils.py index 999daa5a..5f4086fd 100644 --- a/scripts/watch_index_core/utils.py +++ b/scripts/watch_index_core/utils.py @@ -8,7 +8,8 @@ from watchdog.observers import Observer import scripts.ingest_code as idx -from .config import LOGGER, ROOT, default_collection_name +from . import config as watch_config +from .config import LOGGER, default_collection_name from scripts.workspace_state import ( _extract_repo_name_from_path, PLACEHOLDER_COLLECTION_NAMES, @@ -93,13 +94,14 @@ def create_observer(use_polling: bool, observer_cls: Type[Observer] = Observer) def _detect_repo_for_file(file_path: Path) -> Optional[Path]: """Detect repository root for a file under WATCH root.""" + root = watch_config.ROOT try: - rel_path = file_path.resolve().relative_to(ROOT.resolve()) + rel_path = file_path.resolve().relative_to(root.resolve()) except Exception: return None if not rel_path.parts: - return ROOT - return ROOT / rel_path.parts[0] + return root + return root / rel_path.parts[0] def _repo_name_or_none(repo_path: Optional[Path]) -> Optional[str]: diff --git a/scripts/workspace_state.py b/scripts/workspace_state.py index b0cb28df..4651ccfa 100644 --- a/scripts/workspace_state.py +++ b/scripts/workspace_state.py @@ -9,6 +9,7 @@ - Multi-repo support with per-repo state files """ import json +import logging import os import re import uuid @@ -22,6 +23,8 @@ _CANONICAL_SLUG_RE = re.compile(r"^.+-[0-9a-f]{16}$") _SLUGGED_REPO_RE = re.compile(r"^.+-[0-9a-f]{16}(?:_old)?$") +INTERNAL_STATE_TOP_LEVEL_DIRS = frozenset({".codebase", ".git", "__pycache__"}) +logger = 
logging.getLogger(__name__) _managed_slug_cache_lock = threading.Lock() _managed_slug_cache: set[str] = set() _managed_slug_cache_neg: set[str] = set() @@ -112,7 +115,7 @@ def _server_managed_slug_from_path(path: Path) -> Optional[str]: return None work_dir = Path(os.environ.get("WORK_DIR") or os.environ.get("WORKDIR") or "/work") - marker = work_dir / ".codebase" / "repos" / slug / ".ctxce_managed_upload" + marker = work_dir / STATE_DIRNAME / "repos" / slug / ".ctxce_managed_upload" try: is_managed = marker.exists() except OSError: @@ -134,7 +137,8 @@ def _server_managed_slug_from_path(path: Path) -> Optional[str]: STATE_DIRNAME = ".codebase" STATE_FILENAME = "state.json" CACHE_FILENAME = "cache.json" -PLACEHOLDER_COLLECTION_NAMES = {"", "default-collection", "my-collection"} +INDEX_JOURNAL_FILENAME = "index_journal.json" +PLACEHOLDER_COLLECTION_NAMES = {"", "codebase"} class IndexingProgress(TypedDict, total=False): files_processed: int @@ -184,6 +188,37 @@ class StagingInfo(TypedDict, total=False): repo_name: Optional[str] +class MaintenanceInfo(TypedDict, total=False): + last_empty_dir_sweep_at: Optional[str] + last_consistency_audit_at: Optional[str] + last_consistency_audit_summary: Optional[Dict[str, Any]] + + +class IndexJournalRecord(TypedDict, total=False): + path: str + op_type: str + content_hash: Optional[str] + status: str + attempts: int + created_at: str + updated_at: str + last_error: Optional[str] + + +def _index_journal_retry_delay_seconds() -> float: + try: + return max(0.0, float(os.environ.get("INDEX_JOURNAL_RETRY_DELAY_SECS", "5") or 5)) + except Exception: + return 5.0 + + +def _index_journal_max_attempts() -> int: + try: + return max(0, int(os.environ.get("INDEX_JOURNAL_MAX_ATTEMPTS", "0") or 0)) + except Exception: + return 0 + + class WorkspaceState(TypedDict, total=False): created_at: str updated_at: str @@ -204,6 +239,7 @@ class WorkspaceState(TypedDict, total=False): active_repo_slug: Optional[str] serving_repo_slug: Optional[str] 
staging: Optional[StagingInfo] + maintenance: Optional[MaintenanceInfo] def is_multi_repo_mode() -> bool: """Check if multi-repo mode is enabled.""" @@ -226,6 +262,16 @@ def logical_repo_reuse_enabled() -> bool: "on", } + +def bindmount_repo_detection_enabled() -> bool: + """Allow git-based repo inference for bindmount-style deployments.""" + return os.environ.get("CTXCE_BINDMOUNT_REPO_DETECTION", "").strip().lower() in { + "1", + "true", + "yes", + "on", + } + _state_lock = threading.Lock() # Track last-used timestamps for cleanup of idle workspace locks _state_locks: Dict[str, threading.RLock] = {} @@ -233,7 +279,26 @@ def logical_repo_reuse_enabled() -> bool: def _resolve_workspace_root() -> str: """Determine the default workspace root path.""" - return os.environ.get("WORKSPACE_PATH") or os.environ.get("WATCH_ROOT") or "/work" + return ( + os.environ.get("CTXCE_METADATA_ROOT") + or os.environ.get("WORKSPACE_PATH") + or os.environ.get("WATCH_ROOT") + or "/work" + ) + + +def _configured_workspace_roots() -> List[Path]: + roots: List[Path] = [] + for key in ("CTXCE_METADATA_ROOT", "WORKSPACE_PATH", "WATCH_ROOT", "WORK_DIR", "WORKDIR"): + raw = (os.environ.get(key) or "").strip() + if not raw: + continue + try: + roots.append(Path(raw).resolve()) + except Exception: + roots.append(Path(raw)) + return roots + def _resolve_repo_context( workspace_path: Optional[str] = None, @@ -247,14 +312,47 @@ def _resolve_repo_context( return resolved_workspace, repo_name if workspace_path: - detected = _detect_repo_name_from_path(Path(workspace_path)) - if detected: - return resolved_workspace, detected + try: + requested = Path(workspace_path).resolve() + workspace_root = Path(_resolve_workspace_root()).resolve() + except Exception: + requested = Path(workspace_path) + workspace_root = Path(_resolve_workspace_root()) + if requested != workspace_root: + if any(requested == root for root in _configured_workspace_roots()): + return resolved_workspace, None + detected = 
_detect_repo_name_from_path(requested) + if detected: + return resolved_workspace, detected return resolved_workspace, None return resolved_workspace, repo_name + +def _get_repo_workspace_dir( + repo_name: str, + workspace_path: Optional[str] = None, +) -> Path: + try: + base_dir = Path(workspace_path or _resolve_workspace_root()).resolve() + except Exception: + base_dir = Path(workspace_path or _resolve_workspace_root()).absolute() + if base_dir.name == repo_name: + return base_dir + host_index_path = (os.environ.get("HOST_INDEX_PATH") or "").strip() + if host_index_path: + host_index_root = Path(host_index_path) + if not host_index_root.is_absolute(): + host_index_root = base_dir / host_index_root + candidate = host_index_root.resolve() / repo_name + if candidate.exists() or (candidate / STATE_DIRNAME).exists(): + return candidate + dev_workspace_candidate = base_dir / "dev-workspace" / repo_name + if dev_workspace_candidate.exists() or (dev_workspace_candidate / STATE_DIRNAME).exists(): + return dev_workspace_candidate + return base_dir / repo_name + def _get_state_lock(workspace_path: Optional[str] = None, repo_name: Optional[str] = None) -> threading.RLock: """Get or create a lock for the workspace or repo state and track usage.""" if repo_name and is_multi_repo_mode(): @@ -268,13 +366,52 @@ def _get_state_lock(workspace_path: Optional[str] = None, repo_name: Optional[st _state_lock_last_used[key] = time.time() return _state_locks[key] -def _get_repo_state_dir(repo_name: str) -> Path: +def _get_repo_state_dir( + repo_name: str, + workspace_path: Optional[str] = None, +) -> Path: """Get the state directory for a repository.""" - base_dir = Path(os.environ.get("WORKSPACE_PATH") or os.environ.get("WATCH_ROOT") or "/work") + workspace_root = Path(_resolve_workspace_root()).resolve() + base_dir = Path(workspace_path or str(workspace_root)).resolve() + global_repo_state_dir = workspace_root / STATE_DIRNAME / "repos" / repo_name if is_multi_repo_mode(): - return 
base_dir / STATE_DIRNAME / "repos" / repo_name + # Canonical multi-repo metadata layout is shared under workspace root. + return global_repo_state_dir return base_dir / STATE_DIRNAME + +def _is_repo_local_metadata_path(path: Path) -> bool: + try: + parts = path.resolve().parts + except Exception: + parts = path.parts + try: + idx = parts.index(STATE_DIRNAME) + except ValueError: + return False + if idx > 0 and _SLUGGED_REPO_RE.match(parts[idx - 1] or ""): + return True + if "repos" in parts: + ridx = parts.index("repos") + if ridx + 1 < len(parts) and _SLUGGED_REPO_RE.match(parts[ridx + 1] or ""): + return True + return False + + +def _apply_runtime_metadata_mode(path: Path) -> None: + try: + is_dir = path.is_dir() + except Exception: + is_dir = False + if _is_repo_local_metadata_path(path): + mode = 0o777 if is_dir else 0o666 + else: + mode = 0o775 if is_dir else 0o664 + try: + os.chmod(path, mode) + except Exception: + pass + def _get_state_path(workspace_path: str) -> Path: """Get the path to the state.json file for a workspace.""" workspace = Path(workspace_path).resolve() @@ -592,15 +729,13 @@ def _git_remote_repo_name(repo_path: Path) -> Optional[str]: def _detect_repo_name_from_path(path: Path) -> str: - """Detect repository name from path using git remote origin URL. + """Detect repository name from managed upload/workspace path structure. - This ensures consistency with how the MCP server detects repos during search. Priority: - 1. Fast-path for server-managed uploads and workspace-relative paths - 2. Git remote origin URL (canonical repo name like 'Context-Engine') - 3. Git toplevel directory name (folder name like 'Context-Engine-hash') - 4. Walk up to find .git and return that folder name - 5. Return parent folder name as fallback + 1. Server-managed upload slug markers + 2. Workspace-relative first path segment + 3. Bindmount git inference when CTXCE_BINDMOUNT_REPO_DETECTION=1 + 4. 
Structure/name fallback """ slug = _server_managed_slug_from_path(path) if slug: @@ -623,29 +758,29 @@ def _detect_repo_name_from_path(path: Path) -> str: rel = resolved.relative_to(ws_root) if rel.parts: candidate = rel.parts[0] - if candidate not in {".codebase", ".git", "__pycache__"}: + if candidate not in INTERNAL_STATE_TOP_LEVEL_DIRS: return candidate except Exception: pass - try: - base = path if path.is_dir() else path.parent - git_name = _git_remote_repo_name(base) - if git_name: - return git_name - except Exception: - pass - try: - # Walk up to find .git - cur = path if path.is_dir() else path.parent - for p in [cur] + list(cur.parents): - try: - if (p / ".git").exists(): - return p.name - except Exception: - continue - except Exception: - pass + if bindmount_repo_detection_enabled(): + try: + base = path if path.is_dir() else path.parent + git_name = _git_remote_repo_name(base) + if git_name: + return git_name + except Exception: + pass + try: + cur = path if path.is_dir() else path.parent + for p in [cur] + list(cur.parents): + try: + if (p / ".git").exists(): + return p.name + except Exception: + continue + except Exception: + pass try: structure_name = _detect_repo_name_from_path_by_structure(path) @@ -672,12 +807,7 @@ def _atomic_write_state(state_path: Path, state: WorkspaceState) -> None: with open(temp_path, 'w', encoding='utf-8') as f: json.dump(state, f, indent=2, ensure_ascii=False) temp_path.replace(state_path) - # Ensure state/cache files are group-writable so multiple processes - # (upload service, watcher, indexer) can update them. 
- try: - os.chmod(state_path, 0o664) - except PermissionError: - pass + _apply_runtime_metadata_mode(state_path) except Exception: # Clean up temp file if something went wrong try: @@ -705,7 +835,7 @@ def get_workspace_state( lock_scope_path: Path if is_multi_repo_mode() and repo_name: - state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, workspace_path) try: ws_root = Path(_resolve_workspace_root()) ws_dir = ws_root / repo_name @@ -717,12 +847,7 @@ def get_workspace_state( except Exception: return {} state_dir.mkdir(parents=True, exist_ok=True) - # Ensure repo state dir is group-writable so root upload service and - # non-root watcher/indexer processes can both write state/cache files. - try: - os.chmod(state_dir, 0o775) - except Exception: - pass + _apply_runtime_metadata_mode(state_dir) state_path = state_dir / STATE_FILENAME lock_scope_path = state_dir else: @@ -802,7 +927,7 @@ def update_workspace_state( # Allow updates when the repo state dir exists, even if the workspace # directory is not present (e.g. dev-remote simulations where only # .codebase state is persisted). 
- state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, workspace_path) if not (ws_root / repo_name).exists() and not state_dir.exists(): return {} except Exception: @@ -823,8 +948,9 @@ def update_workspace_state( state["updated_at"] = datetime.now().isoformat() if is_multi_repo_mode() and repo_name: - state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, workspace_path) state_dir.mkdir(parents=True, exist_ok=True) + _apply_runtime_metadata_mode(state_dir) state_path = state_dir / STATE_FILENAME else: try: @@ -1245,8 +1371,9 @@ def log_activity( return except Exception: return - state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, workspace_path) state_dir.mkdir(parents=True, exist_ok=True) + _apply_runtime_metadata_mode(state_dir) state_path = state_dir / STATE_FILENAME lock_path = state_path.with_suffix(".lock") @@ -1331,6 +1458,37 @@ def _collection_name_for_repo_slug(normalized_repo: str, *, is_old_slug: bool) - return None +def _coerce_collection_repo_name(repo_name: Optional[str]) -> Optional[str]: + if not repo_name: + return None + + value = str(repo_name).strip() + if not value: + return None + + if "/" not in value and "\\" not in value: + return value + + try: + path = Path(value).resolve() + except Exception: + path = Path(value) + + try: + workspace_root = Path(_resolve_workspace_root()).resolve() + except Exception: + workspace_root = Path(_resolve_workspace_root()) + + try: + if path == workspace_root: + return None + except Exception: + pass + + detected = _extract_repo_name_from_path(str(path)) + return detected or None + + def get_collection_name(repo_name: Optional[str] = None) -> str: """Get collection name for repository or workspace. @@ -1338,7 +1496,7 @@ def get_collection_name(repo_name: Optional[str] = None) -> str: 1. 
Explicit COLLECTION_NAME env var - master override when set to a real value (if repo_name is an *_old clone, append _old to the override unless already present) 2. Derive from repo slug (including *_old suffix handling) - 3. Fallback: "global-collection" + 3. Fallback: DEFAULT_COLLECTION, COLLECTION_NAME, or "codebase" This ensures COLLECTION_NAME works as a master override in both local dev and container environments, while still allowing deterministic derivation @@ -1355,6 +1513,7 @@ def get_collection_name(repo_name: Optional[str] = None) -> str: pass return env_coll + repo_name = _coerce_collection_repo_name(repo_name) normalized = _normalize_repo_name_for_collection(repo_name) if repo_name else None is_old_slug = False try: @@ -1369,8 +1528,7 @@ def get_collection_name(repo_name: Optional[str] = None) -> str: if derived: return derived - # Default fallback - return "global-collection" + return os.environ.get("DEFAULT_COLLECTION") or os.environ.get("COLLECTION_NAME") or "codebase" def _detect_repo_name_from_path_by_structure(path: Path) -> str: """Detect repository name from path structure (fallback when git is unavailable).""" @@ -1405,7 +1563,7 @@ def _detect_repo_name_from_path_by_structure(path: Path) -> str: continue repo_name = rel_path.parts[0] - if repo_name in (".codebase", ".git", "__pycache__"): + if repo_name in INTERNAL_STATE_TOP_LEVEL_DIRS: continue repo_path = base / repo_name @@ -1425,7 +1583,8 @@ def _normalize_repo_slug(candidate: Optional[str]) -> Optional[str]: def _extract_repo_name_from_path(workspace_path: str) -> str: """Extract repository slug or canonical name from workspace path. - Accepts canonical slugs (repo-hash), `_old` slugs, and falls back to git remote name. + Accepts managed upload slugs and workspace-relative repo paths. Git-based + bindmount inference is opt-in via CTXCE_BINDMOUNT_REPO_DETECTION=1. 
""" if not workspace_path: return "" @@ -1440,14 +1599,30 @@ def _extract_repo_name_from_path(workspace_path: str) -> str: return slug try: - repo_path = path if path.is_dir() else path.parent - if (repo_path / ".git").exists(): - name = _git_remote_repo_name(repo_path) - if name: - return name + workspace_root = Path(_resolve_workspace_root()).resolve() + except Exception: + workspace_root = Path(_resolve_workspace_root()) + + try: + rel = path.relative_to(workspace_root) + if not rel.parts: + return "" + candidate = rel.parts[0] + if candidate not in INTERNAL_STATE_TOP_LEVEL_DIRS: + return candidate except Exception: pass + if bindmount_repo_detection_enabled(): + try: + repo_path = path if path.is_dir() else path.parent + if (repo_path / ".git").exists(): + name = _git_remote_repo_name(repo_path) + if name: + return name + except Exception: + pass + try: candidate = _normalize_repo_slug(path.name) if candidate: @@ -1575,10 +1750,330 @@ def _write_cache(workspace_path: str, cache: Dict[str, Any]) -> None: pass -def get_cached_file_hash(file_path: str, repo_name: Optional[str] = None) -> str: +def _get_index_journal_path( + workspace_path: Optional[str] = None, repo_name: Optional[str] = None +) -> Path: + workspace_path, repo_name = _resolve_repo_context(workspace_path, repo_name) + if repo_name: + state_dir = _get_repo_state_dir(repo_name, workspace_path) + else: + state_dir = _get_global_state_dir(workspace_path) + return state_dir / INDEX_JOURNAL_FILENAME + + +def _read_index_journal_file_uncached(journal_path: Path) -> Dict[str, Any]: + try: + with journal_path.open("r", encoding="utf-8-sig") as f: + obj = json.load(f) + if isinstance(obj, dict): + operations = obj.get("operations", {}) + if isinstance(operations, dict): + return obj + except (OSError, json.JSONDecodeError, ValueError): + pass + now = datetime.now().isoformat() + return {"version": 1, "operations": {}, "created_at": now, "updated_at": now} + + +def _write_index_journal( + workspace_path: 
Optional[str], + repo_name: Optional[str], + journal: Dict[str, Any], +) -> None: + workspace_path, repo_name = _resolve_repo_context(workspace_path, repo_name) + lock = _get_state_lock(workspace_path, repo_name) + with lock: + journal_path = _get_index_journal_path(workspace_path, repo_name) + journal_path.parent.mkdir(parents=True, exist_ok=True) + _apply_runtime_metadata_mode(journal_path.parent) + lock_path = journal_path.with_suffix(journal_path.suffix + ".lock") + with _cross_process_lock(lock_path): + tmp = journal_path.with_suffix(f".tmp.{uuid.uuid4().hex[:8]}") + try: + with open(tmp, "w", encoding="utf-8") as f: + json.dump(journal, f, ensure_ascii=False, indent=2) + tmp.replace(journal_path) + _apply_runtime_metadata_mode(journal_path) + finally: + try: + tmp.unlink(missing_ok=True) + except Exception: + pass + + +def _update_index_journal( + workspace_path: Optional[str], + repo_name: Optional[str], + mutator, +) -> Dict[str, Any]: + workspace_path, repo_name = _resolve_repo_context(workspace_path, repo_name) + lock = _get_state_lock(workspace_path, repo_name) + with lock: + journal_path = _get_index_journal_path(workspace_path, repo_name) + journal_path.parent.mkdir(parents=True, exist_ok=True) + _apply_runtime_metadata_mode(journal_path.parent) + lock_path = journal_path.with_suffix(journal_path.suffix + ".lock") + with _cross_process_lock(lock_path): + journal = _read_index_journal_file_uncached(journal_path) + mutator(journal) + journal["updated_at"] = datetime.now().isoformat() + tmp = journal_path.with_suffix(f".tmp.{uuid.uuid4().hex[:8]}") + try: + with open(tmp, "w", encoding="utf-8") as f: + json.dump(journal, f, ensure_ascii=False, indent=2) + tmp.replace(journal_path) + _apply_runtime_metadata_mode(journal_path) + finally: + try: + tmp.unlink(missing_ok=True) + except Exception: + pass + return journal + + +def upsert_index_journal_entries( + entries: List[Dict[str, Any]], + *, + workspace_path: Optional[str] = None, + repo_name: 
Optional[str] = None, +) -> Dict[str, Any]: + """Persist or replace repo-scoped index journal entries keyed by normalized path.""" + normalized_entries: List[IndexJournalRecord] = [] + now = datetime.now().isoformat() + valid_statuses = {"pending", "in_progress", "failed", "done"} + for entry in entries or []: + path = _normalize_cache_key_path(str(entry.get("path") or "")) + op_type = str(entry.get("op_type") or "").strip().lower() + if not path or op_type not in {"upsert", "delete"}: + continue + content_hash = str(entry.get("content_hash") or "").strip() or None + status = str(entry.get("status") or "pending").strip().lower() + if status not in valid_statuses: + status = "pending" + try: + attempts = int(entry.get("attempts", 0) or 0) + except Exception: + attempts = 0 + if attempts < 0: + attempts = 0 + last_error = entry.get("last_error") + if last_error is not None: + last_error = str(last_error) + normalized_entries.append( + { + "path": path, + "op_type": op_type, + "content_hash": content_hash, + "status": status, + "attempts": attempts, + "created_at": str(entry.get("created_at") or now), + "updated_at": str(entry.get("updated_at") or now), + "last_error": last_error, + } + ) + + def _mutate(journal: Dict[str, Any]) -> None: + ops = journal.setdefault("operations", {}) + if not isinstance(ops, dict): + ops = {} + journal["operations"] = ops + for entry in normalized_entries: + ops[entry["path"]] = entry + + return _update_index_journal(workspace_path, repo_name, _mutate) + + +def clear_index_journal_entries( + *, + workspace_path: Optional[str] = None, + repo_name: Optional[str] = None, +) -> int: + """Remove all operations from a workspace/repo index journal.""" + removed = 0 + + def _mutate(journal: Dict[str, Any]) -> None: + nonlocal removed + ops = journal.get("operations", {}) + if isinstance(ops, dict): + removed = len(ops) + journal["operations"] = {} + + _update_index_journal(workspace_path, repo_name, _mutate) + return removed + + +def 
list_pending_index_journal_entries( + workspace_path: Optional[str] = None, + repo_name: Optional[str] = None, +) -> List[IndexJournalRecord]: + """Return watcher-retryable journal records for a workspace or specific repo.""" + workspace_path, repo_name = _resolve_repo_context(workspace_path, repo_name) + retry_delay = _index_journal_retry_delay_seconds() + max_attempts = _index_journal_max_attempts() + now = datetime.now() + + def _read_repo_journal_entries( + target_repo_name: Optional[str], + *, + target_workspace_path: Optional[str] = None, + ) -> List[IndexJournalRecord]: + journal = _read_index_journal_file_uncached( + _get_index_journal_path(target_workspace_path or workspace_path, target_repo_name) + ) + merged_ops = journal.get("operations", {}) + if not isinstance(merged_ops, dict): + merged_ops = {} + result: List[IndexJournalRecord] = [] + for rec in merged_ops.values(): + if not isinstance(rec, dict): + continue + status = str(rec.get("status") or "pending").strip().lower() + if status not in {"pending", "failed"}: + continue + attempts_raw = rec.get("attempts") + try: + attempts = int(attempts_raw or 0) + except (ValueError, TypeError): + attempts = 0 + logger.warning( + "workspace_state::invalid_journal_attempts", + extra={"attempts": attempts_raw, "path": str(rec.get("path") or "")}, + ) + if max_attempts > 0 and attempts >= max_attempts: + continue + if status == "failed" and retry_delay > 0: + updated_at = str(rec.get("updated_at") or "").strip() + if updated_at: + try: + last = datetime.fromisoformat(updated_at) + if (now - last).total_seconds() < retry_delay: + continue + except Exception: + pass + p = _normalize_cache_key_path(str(rec.get("path") or "")) + op_type = str(rec.get("op_type") or "").strip().lower() + if not p or op_type not in {"upsert", "delete"}: + continue + result.append( + { + "path": p, + "op_type": op_type, + "content_hash": str(rec.get("content_hash") or "").strip() or None, + "status": status, + "attempts": attempts, + 
"created_at": str(rec.get("created_at") or ""), + "updated_at": str(rec.get("updated_at") or ""), + "last_error": str(rec.get("last_error") or "").strip() or None, + } + ) + return result + + if repo_name: + return _read_repo_journal_entries(repo_name) + + result: List[IndexJournalRecord] = [] + root_path = Path(workspace_path or _resolve_workspace_root()).resolve() + repo_candidates: set[str] = set() + multi_repo_mode = is_multi_repo_mode() + try: + for repo_root in root_path.iterdir(): + if not repo_root.is_dir(): + continue + if repo_root.name in INTERNAL_STATE_TOP_LEVEL_DIRS: + continue + if (not multi_repo_mode) and (not _SLUGGED_REPO_RE.match(repo_root.name)): + continue + repo_candidates.add(repo_root.name) + except Exception: + pass + + try: + repos_state_root = root_path / STATE_DIRNAME / "repos" + if repos_state_root.exists(): + for state_dir in repos_state_root.iterdir(): + if not state_dir.is_dir(): + continue + repo_candidates.add(state_dir.name) + except Exception: + pass + + try: + metadata_repos_root = Path(_resolve_workspace_root()).resolve() / STATE_DIRNAME / "repos" + if metadata_repos_root != root_path / STATE_DIRNAME / "repos" and metadata_repos_root.exists(): + for state_dir in metadata_repos_root.iterdir(): + if not state_dir.is_dir(): + continue + repo_candidates.add(state_dir.name) + except Exception: + pass + + for candidate in sorted(repo_candidates): + candidate_workspace_path: Optional[str] = None + if not multi_repo_mode: + candidate_workspace_path = str(root_path / candidate) + result.extend( + _read_repo_journal_entries( + candidate, + target_workspace_path=candidate_workspace_path, + ) + ) + + if result: + return result + return _read_repo_journal_entries(None) + + +def update_index_journal_entry_status( + path: str, + *, + status: str, + error: Optional[str] = None, + workspace_path: Optional[str] = None, + repo_name: Optional[str] = None, + remove_on_done: bool = True, +) -> Dict[str, Any]: + """Update or clear a repo-scoped 
journal entry after processing.""" + normalized_path = _normalize_cache_key_path(path) + now = datetime.now().isoformat() + + def _mutate(journal: Dict[str, Any]) -> None: + ops = journal.setdefault("operations", {}) + if not isinstance(ops, dict): + ops = {} + journal["operations"] = ops + rec = ops.get(normalized_path) + if not isinstance(rec, dict): + return + if status == "done" and remove_on_done: + ops.pop(normalized_path, None) + return + rec["status"] = status + rec["updated_at"] = now + attempts_raw = rec.get("attempts") + try: + attempts = int(attempts_raw or 0) + except (ValueError, TypeError): + attempts = 0 + logger.warning( + "workspace_state::invalid_journal_attempts", + extra={"attempts": attempts_raw, "path": normalized_path}, + ) + rec["attempts"] = attempts + 1 + rec["last_error"] = str(error or "").strip() or None + ops[normalized_path] = rec + + return _update_index_journal(workspace_path, repo_name, _mutate) + + +def get_cached_file_hash( + file_path: str, + repo_name: Optional[str] = None, + metadata_root: Optional[str] = None, +) -> str: """Get cached file hash for tracking changes.""" + root = metadata_root or _resolve_workspace_root() if is_multi_repo_mode() and repo_name: - state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, root) cache_path = state_dir / CACHE_FILENAME cache = _read_cache_file_cached(cache_path) @@ -1589,19 +2084,23 @@ def get_cached_file_hash(file_path: str, repo_name: Optional[str] = None) -> str return str(val.get("hash") or "") return str(val or "") else: - cache = _read_cache_cached(_resolve_workspace_root()) + cache = _read_cache_cached(root) fp = _normalize_cache_key_path(file_path) val = cache.get("file_hashes", {}).get(fp, "") if isinstance(val, dict): return str(val.get("hash") or "") return str(val or "") - return "" - -def set_cached_file_hash(file_path: str, file_hash: str, repo_name: Optional[str] = None) -> None: +def set_cached_file_hash( + file_path: str, + file_hash: 
str, + repo_name: Optional[str] = None, + metadata_root: Optional[str] = None, +) -> None: """Set cached file hash for tracking changes.""" fp = _normalize_cache_key_path(file_path) + root = metadata_root or _resolve_workspace_root() st_size: Optional[int] = None st_mtime: Optional[int] = None @@ -1615,14 +2114,15 @@ def set_cached_file_hash(file_path: str, file_hash: str, repo_name: Optional[str if is_multi_repo_mode() and repo_name: try: - ws_root = Path(_resolve_workspace_root()) + ws_root = Path(root) if not (ws_root / repo_name).exists(): return except Exception: return - state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, str(ws_root)) cache_path = state_dir / CACHE_FILENAME state_dir.mkdir(parents=True, exist_ok=True) + _apply_runtime_metadata_mode(state_dir) if cache_path.exists(): cache = _read_cache_file_cached(cache_path) @@ -1659,7 +2159,7 @@ def set_cached_file_hash(file_path: str, file_hash: str, repo_name: Optional[str _memoize_cache_obj(cache_path, cache) return - cache = _read_cache_cached(_resolve_workspace_root()) + cache = _read_cache_cached(root) existing = cache.get("file_hashes", {}).get(fp) if isinstance(existing, dict) and st_size is not None and st_mtime is not None: if ( @@ -1683,14 +2183,14 @@ def set_cached_file_hash(file_path: str, file_hash: str, repo_name: Optional[str pass cache.setdefault("file_hashes", {})[fp] = entry cache["updated_at"] = datetime.now().isoformat() - _write_cache(_resolve_workspace_root(), cache) - _memoize_cache_obj(_get_cache_path(_resolve_workspace_root()), cache) + _write_cache(root, cache) + _memoize_cache_obj(_get_cache_path(root), cache) def get_cached_file_meta(file_path: str, repo_name: Optional[str] = None) -> Dict[str, Any]: fp = _normalize_cache_key_path(file_path) if is_multi_repo_mode() and repo_name: - state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, _resolve_workspace_root()) cache_path = state_dir / CACHE_FILENAME cache = 
_read_cache_file_cached(cache_path) @@ -1711,10 +2211,15 @@ def get_cached_file_meta(file_path: str, repo_name: Optional[str] = None) -> Dic return {} -def remove_cached_file(file_path: str, repo_name: Optional[str] = None) -> None: +def remove_cached_file( + file_path: str, + repo_name: Optional[str] = None, + metadata_root: Optional[str] = None, +) -> None: """Remove file entry from cache.""" + root = metadata_root or _resolve_workspace_root() if is_multi_repo_mode() and repo_name: - state_dir = _get_repo_state_dir(repo_name) + state_dir = _get_repo_state_dir(repo_name, root) cache_path = state_dir / CACHE_FILENAME if cache_path.exists(): @@ -1730,13 +2235,13 @@ def remove_cached_file(file_path: str, repo_name: Optional[str] = None) -> None: _memoize_cache_obj(cache_path, cache) return - cache = _read_cache_cached(_resolve_workspace_root()) + cache = _read_cache_cached(root) fp = _normalize_cache_key_path(file_path) if fp in cache.get("file_hashes", {}): cache["file_hashes"].pop(fp, None) cache["updated_at"] = datetime.now().isoformat() - _write_cache(_resolve_workspace_root(), cache) - _memoize_cache_obj(_get_cache_path(_resolve_workspace_root()), cache) + _write_cache(root, cache) + _memoize_cache_obj(_get_cache_path(root), cache) def cleanup_old_cache_locks(max_idle_seconds: int = 900) -> int: @@ -1780,42 +2285,65 @@ def cleanup_old_cache_locks(max_idle_seconds: int = 900) -> int: def get_collection_mappings(search_root: Optional[str] = None) -> List[Dict[str, Any]]: - """Enumerate collection mappings with origin metadata.""" + """Enumerate collection mappings with origin metadata. + + `search_root` may point at either workspace root (`/work`) or codebase root + (`/work/.codebase`). 
+ """ root_path = Path(search_root or _resolve_workspace_root()).resolve() + if root_path.name == STATE_DIRNAME: + workspace_root = root_path.parent + codebase_root = root_path + else: + workspace_root = root_path + codebase_root = root_path / STATE_DIRNAME mappings: List[Dict[str, Any]] = [] try: if is_multi_repo_mode(): - repos_root = root_path / STATE_DIRNAME / "repos" + seen_state_files: set[str] = set() + + def _append_repo_mapping(repo_name: str, state_path: Path) -> None: + if not state_path.exists(): + return + try: + state_key = str(state_path.resolve()) + except Exception: + state_key = str(state_path) + if state_key in seen_state_files: + return + seen_state_files.add(state_key) + + try: + with open(state_path, "r", encoding="utf-8-sig") as f: + state = json.load(f) or {} + except Exception as e: + print(f"[workspace_state] Failed to read repo state from {state_path}: {e}") + return + + origin = state.get("origin", {}) or {} + repo_workspace_dir = _get_repo_workspace_dir(repo_name, str(workspace_root)) + mappings.append( + { + "repo_name": repo_name, + "collection_name": state.get("qdrant_collection") + or get_collection_name(repo_name), + "container_path": origin.get("container_path") + or str(repo_workspace_dir.resolve()), + "source_path": origin.get("source_path"), + "state_file": str(state_path), + "updated_at": state.get("updated_at"), + } + ) + + # Shared metadata root (`/.codebase/repos//state.json`) + repos_root = codebase_root / "repos" if repos_root.exists(): for repo_dir in sorted(p for p in repos_root.iterdir() if p.is_dir()): - repo_name = repo_dir.name - state_path = repo_dir / STATE_FILENAME - if not state_path.exists(): - continue - try: - with open(state_path, "r", encoding="utf-8-sig") as f: - state = json.load(f) or {} - except Exception as e: - print(f"[workspace_state] Failed to read repo state from {state_path}: {e}") - continue - - origin = state.get("origin", {}) or {} - mappings.append( - { - "repo_name": repo_name, - 
"collection_name": state.get("qdrant_collection") - or get_collection_name(repo_name), - "container_path": origin.get("container_path") - or str((Path(_resolve_workspace_root()) / repo_name).resolve()), - "source_path": origin.get("source_path"), - "state_file": str(state_path), - "updated_at": state.get("updated_at"), - } - ) + _append_repo_mapping(repo_dir.name, repo_dir / STATE_FILENAME) else: - state_path = root_path / STATE_DIRNAME / STATE_FILENAME + state_path = codebase_root / STATE_FILENAME if state_path.exists(): try: with open(state_path, "r", encoding="utf-8-sig") as f: @@ -1824,14 +2352,14 @@ def get_collection_mappings(search_root: Optional[str] = None) -> List[Dict[str, state = {} origin = state.get("origin", {}) or {} - repo_name = origin.get("repo_name") or Path(root_path).name + repo_name = origin.get("repo_name") or Path(workspace_root).name mappings.append( { "repo_name": repo_name, "collection_name": state.get("qdrant_collection") or get_collection_name(repo_name), "container_path": origin.get("container_path") - or str(root_path), + or str(workspace_root), "source_path": origin.get("source_path"), "state_file": str(state_path), "updated_at": state.get("updated_at"), @@ -2116,6 +2644,8 @@ def set_cached_symbols(file_path: str, symbols: dict, file_hash: str) -> None: """Save symbol metadata for a file. 
Extends existing to include pseudo data.""" cache_path = _get_symbol_cache_path(file_path) cache_path.parent.mkdir(parents=True, exist_ok=True) + _apply_runtime_metadata_mode(cache_path.parent) + temp_path = cache_path.with_suffix(f".tmp.{uuid.uuid4().hex[:8]}") try: cache_data = { @@ -2125,18 +2655,16 @@ def set_cached_symbols(file_path: str, symbols: dict, file_hash: str) -> None: "symbols": symbols } - with open(cache_path, 'w', encoding='utf-8') as f: + with open(temp_path, 'w', encoding='utf-8') as f: json.dump(cache_data, f, indent=2) - - # Ensure symbol cache files are group-writable so both indexer and - # watcher processes (potentially different users sharing a group) - # can update them on shared volumes. - try: - os.chmod(cache_path, 0o664) - except PermissionError: - pass + temp_path.replace(cache_path) + _apply_runtime_metadata_mode(cache_path) except Exception as e: print(f"[SYMBOL_CACHE_WARNING] Failed to save symbol cache for {file_path}: {e}") + try: + temp_path.unlink(missing_ok=True) + except Exception: + pass def get_cached_pseudo(file_path: str, symbol_id: str) -> tuple[str, list[str]]: @@ -2236,7 +2764,7 @@ def clear_symbol_cache( target_dirs: List[Path] = [] if is_multi_repo_mode() and repo_name: - target_dirs.append(_get_repo_state_dir(repo_name) / "symbols") + target_dirs.append(_get_repo_state_dir(repo_name, workspace_path) / "symbols") else: try: cache_parent = _get_cache_path(workspace_root).parent @@ -2288,6 +2816,53 @@ def compare_symbol_changes(old_symbols: dict, new_symbols: dict) -> tuple[list, unchanged = [] changed = [] + # Primary key should not be absolute start_line alone; leading comments/import + # shifts can move every symbol without changing their bodies. Prefer exact id + # first, then fall back to stable metadata matching. 
+ old_symbols = old_symbols or {} + new_symbols = new_symbols or {} + remaining_old_by_exact = dict(old_symbols) + remaining_old_by_signature: Dict[tuple[str, str, str], list[str]] = {} + remaining_old_by_name_kind: Dict[tuple[str, str], list[str]] = {} + + for old_symbol_id, old_info in remaining_old_by_exact.items(): + kind = str(old_info.get("type") or "") + name = str(old_info.get("name") or "") + content_hash = str(old_info.get("content_hash") or "") + if kind and name and content_hash: + remaining_old_by_signature.setdefault((kind, name, content_hash), []).append( + old_symbol_id + ) + if kind and name: + remaining_old_by_name_kind.setdefault((kind, name), []).append(old_symbol_id) + + def _consume_old_symbol(old_id: str, old_info: dict) -> None: + remaining_old_by_exact.pop(old_id, None) + + old_kind = str(old_info.get("type") or "") + old_name = str(old_info.get("name") or "") + old_hash = str(old_info.get("content_hash") or "") + + if old_kind and old_name and old_hash: + sig = (old_kind, old_name, old_hash) + sig_ids = remaining_old_by_signature.get(sig) or [] + if old_id in sig_ids: + sig_ids.remove(old_id) + if sig_ids: + remaining_old_by_signature[sig] = sig_ids + else: + remaining_old_by_signature.pop(sig, None) + + if old_kind and old_name: + nk = (old_kind, old_name) + nk_ids = remaining_old_by_name_kind.get(nk) or [] + if old_id in nk_ids: + nk_ids.remove(old_id) + if nk_ids: + remaining_old_by_name_kind[nk] = nk_ids + else: + remaining_old_by_name_kind.pop(nk, None) + for symbol_id, symbol_info in new_symbols.items(): if symbol_id in old_symbols: old_info = old_symbols[symbol_id] @@ -2296,6 +2871,26 @@ def compare_symbol_changes(old_symbols: dict, new_symbols: dict) -> tuple[list, unchanged.append(symbol_id) else: changed.append(symbol_id) + _consume_old_symbol(symbol_id, old_info) + continue + + kind = str(symbol_info.get("type") or "") + name = str(symbol_info.get("name") or "") + content_hash = str(symbol_info.get("content_hash") or "") + 
signature = (kind, name, content_hash) + matched_old_ids = remaining_old_by_signature.get(signature) or [] + if matched_old_ids: + old_id = matched_old_ids.pop(0) + if not matched_old_ids: + remaining_old_by_signature.pop(signature, None) + _consume_old_symbol(old_id, old_symbols.get(old_id, {})) + unchanged.append(symbol_id) + continue + + # Same logical symbol name/type exists but content differs: changed. + if kind and name and remaining_old_by_name_kind.get((kind, name)): + remaining_old_by_name_kind.pop((kind, name), None) + changed.append(symbol_id) else: # New symbol changed.append(symbol_id) @@ -2537,6 +3132,3 @@ def _list_workspaces_from_qdrant(seen_paths: set) -> List[Dict[str, Any]]: pass return workspaces - - -# Add missing functions that callers expect (already defined above) \ No newline at end of file diff --git a/skills/context-engine/SKILL.md b/skills/context-engine/SKILL.md index 50ee67b6..c139ca8c 100644 --- a/skills/context-engine/SKILL.md +++ b/skills/context-engine/SKILL.md @@ -1,11 +1,11 @@ --- name: context-engine -description: Codebase search and context retrieval for any programming language. Hybrid semantic/lexical search with neural reranking. Use for code lookup, finding implementations, understanding codebases, Q&A grounded in source code, and persistent memory across sessions. +description: Codebase search and context retrieval for any programming language. Use for code lookup, finding implementations, understanding codebases, Q&A grounded in source code, and persistent memory across sessions. --- # Context-Engine -Search and retrieve code context from any codebase using hybrid vector search (semantic + lexical) with neural reranking. +Search and retrieve code context from any codebase using the configured retrieval mode. `repo_search` is the canonical code search tool; dense, fusion, and reranking behavior depends on deployment settings. ## Decision Tree: Choosing the Right Tool @@ -14,8 +14,8 @@ What do you need? 
| +-- Find code locations/implementations | | - | +-- Simple query --> info_request - | +-- Need filters/control --> repo_search + | +-- Any query --> repo_search + | +-- Need file-type focus --> repo_search with profile | +-- Understand how something works | | @@ -29,13 +29,13 @@ What do you need? | +-- Find specific file types | | - | +-- Test files --> search_tests_for - | +-- Config files --> search_config_for + | +-- Test files --> repo_search with profile="tests" + | +-- Config files --> repo_search with profile="config" | +-- Find relationships | | - | +-- Who calls this function --> search_callers_for - | +-- Who imports this module --> search_importers_for + | +-- Who calls this function --> symbol_graph query_type="callers" + | +-- Who imports this module --> symbol_graph query_type="importers" | +-- Symbol graph navigation (callers/defs/importers) --> symbol_graph | +-- Git history --> search_commits_for @@ -47,7 +47,7 @@ What do you need? ## Primary Search: repo_search -Use `repo_search` (or its alias `code_search`) for most code lookups. Reranking is ON by default. +Use `repo_search` for code lookups. Retrieval mode and reranking are controlled by deployment configuration and per-call arguments. ```json { @@ -106,24 +106,10 @@ Use `repo: "*"` to search all indexed repos. - `ext` - File extension - `repo` - Repository filter for multi-repo setups - `case` - Case-sensitive matching - -## Simple Lookup: info_request - -Use `info_request` for natural language queries with minimal parameters: - -```json -{ - "info_request": "how does user authentication work" -} -``` - -Add explanations: -```json -{ - "info_request": "database connection pooling", - "include_explanation": true -} -``` +- `profile` - Focus common scopes: + - `"tests"` - Test files + - `"config"` - Configuration files + - `"code"` - Source-code extensions ## Q&A with Citations: context_answer @@ -206,28 +192,27 @@ Find structurally similar code patterns across all languages. 
Accepts **either** The `query_signature` encodes control flow: `L` (loops), `B` (branches), `T` (try/except), `M` (match). -## Specialized Search Tools +## Focused Search Profiles -**search_tests_for** - Find test files: -```json -{"query": "UserService", "limit": 10} -``` +Use `repo_search.profile` instead of separate focused tools. -**search_config_for** - Find config files: +**Test files**: ```json -{"query": "database connection", "limit": 5} +{"query": "UserService", "profile": "tests", "limit": 10} ``` -**search_callers_for** - Find callers of a symbol: +**Config files**: ```json -{"query": "processPayment", "language": "typescript"} +{"query": "database connection", "profile": "config", "limit": 5} ``` -**search_importers_for** - Find importers: +**Source-code files**: ```json -{"query": "utils/helpers", "limit": 10} +{"query": "imports qdrant client", "profile": "code", "limit": 10} ``` +For caller/importer relationships, prefer `symbol_graph` when you know the symbol. Use `repo_search` for exploratory prose queries. + **symbol_graph** - Symbol graph navigation (callers / definition / importers): ```json {"symbol": "ASTAnalyzer", "query_type": "definition", "limit": 10} @@ -344,10 +329,7 @@ With recreate (drops existing data): Set via `output_format` parameter. -## Aliases and Compat Wrappers - -**Aliases:** -- `code_search` = `repo_search` (identical behavior) +## Compat Wrappers **Cross-server tools:** - `memory_store` / `memory_find` — Memory server tools for persistent knowledge diff --git a/templates/admin/acl.html b/templates/admin/acl.html index 952a0ce9..656ed773 100644 --- a/templates/admin/acl.html +++ b/templates/admin/acl.html @@ -1,6 +1,40 @@ {% extends "admin/base.html" %} {% block content %} + {% set qp = request.query_params %} + {% if qp and ((qp.get("copied") and qp.get("new")) or qp.get("deleted") or qp.get("journal_cleared")) %} +
+ {% if qp.get("copied") and qp.get("new") %} +
+ Copied collection {{ qp.get("copied") }}{{ qp.get("new") }}. + {% if qp.get("graph_copied") == "1" %} + (graph clone copied) + {% elif qp.get("graph_copied") == "0" %} + (graph clone not copied; will rebuild/backfill) + {% endif %} +
+ {% endif %} + {% if qp.get("deleted") %} +
+ Deleted collection {{ qp.get("deleted") }}. + {% if qp.get("graph_deleted") == "1" %} + (graph clone deleted) + {% elif qp.get("graph_deleted") == "0" %} + (graph clone not deleted or missing) + {% endif %} +
+ {% endif %} + {% if qp.get("journal_cleared") %} +
+ Cleared index journal for collection {{ qp.get("journal_cleared") }}. + {% if qp.get("journal_removed") %} + ({{ qp.get("journal_removed") }} entries removed) + {% endif %} +
+ {% endif %} +
+ {% endif %} +

Users

@@ -168,6 +202,10 @@

Collections

+
+ + +
{% endif %} {% if deletion_enabled %}
@@ -395,6 +433,12 @@

Grant Collection Access

+
+ + +
` : ""; const deleteHtml = deletionEnabled diff --git a/tests/conftest.py b/tests/conftest.py index b64ff9a9..a2586d83 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,33 @@ os.environ.setdefault("PATTERN_VECTORS", "1") +_INTEGRATION_TEST_FILES = { + "test_collection_memory_backup_restore.py", + "test_integration_qdrant.py", + "test_subprocess_hybrid_smoke.py", + "test_tier2_fallback.py", +} + + +def pytest_addoption(parser): + parser.addoption( + "--run-integration", + action="store_true", + default=False, + help="collect tests that start real services such as Qdrant", + ) + + +def pytest_ignore_collect(collection_path, config): + if config.getoption("--run-integration", default=False): + return False + try: + name = Path(str(collection_path)).name + except Exception: + name = "" + return name in _INTEGRATION_TEST_FILES + + @pytest.fixture(scope="session", autouse=True) def _ensure_mcp_imported(): """Ensure mcp package is properly imported before any tests run. @@ -86,7 +113,7 @@ def qdrant_url(): pytest.skip("testcontainers not available and QDRANT_URL not set") container = ( - DockerContainer("qdrant/qdrant:latest") + DockerContainer("qdrant/qdrant:v1.15.4") .with_env("TESTCONTAINERS_RYUK_DISABLED", "true") .with_env("TESTCONTAINERS_RYUK_TIMEOUT", "0") .with_exposed_ports(6333) diff --git a/tests/test_admin_collection_delete.py b/tests/test_admin_collection_delete.py index ad42807e..1f0b9b4b 100644 --- a/tests/test_admin_collection_delete.py +++ b/tests/test_admin_collection_delete.py @@ -1,4 +1,6 @@ import importlib +import sys +import types import pytest from fastapi.testclient import TestClient @@ -22,11 +24,6 @@ def _fake_render_admin_error(_request, title, message, back_href="/admin", statu monkeypatch.setattr(srv, "render_admin_error", _fake_render_admin_error) - def _should_not_be_called(**_kwargs): - raise AssertionError("delete_collection_everywhere should not be called when env gate is off") - - monkeypatch.setattr(srv, 
"delete_collection_everywhere", _should_not_be_called) - client = TestClient(srv.app) resp = client.post("/admin/collections/delete", data={"collection": "c1", "delete_fs": ""}) assert resp.status_code == 403 @@ -52,6 +49,72 @@ def test_admin_role_gate_blocks_non_admin(monkeypatch): assert resp.json().get("detail") == "Admin required" +@pytest.mark.unit +def test_delete_redirect_includes_graph_deleted_param(monkeypatch): + monkeypatch.setenv("CTXCE_AUTH_ENABLED", "1") + monkeypatch.setenv("CTXCE_ADMIN_COLLECTION_DELETE_ENABLED", "1") + + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + + monkeypatch.setattr(srv, "_require_admin_session", lambda _req: {"user_id": "admin"}) + + def _fake_delete_collection_everywhere(**_kwargs): + return {"qdrant_deleted": True, "qdrant_graph_deleted": True} + + monkeypatch.setitem( + sys.modules, + "scripts.collection_admin", + types.SimpleNamespace(delete_collection_everywhere=_fake_delete_collection_everywhere), + ) + + client = TestClient(srv.app) + resp = client.post("/admin/collections/delete", data={"collection": "c1", "delete_fs": ""}, follow_redirects=False) + assert resp.status_code == 302 + loc = resp.headers.get("location") or "" + assert "deleted=c1" in loc + assert "graph_deleted=1" in loc + + +@pytest.mark.unit +def test_clear_journal_endpoint_clears_mapped_collection(monkeypatch): + monkeypatch.setenv("CTXCE_AUTH_ENABLED", "1") + + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + + calls = {} + monkeypatch.setattr(srv, "_require_admin_session", lambda _req: {"user_id": "admin"}) + monkeypatch.setitem( + sys.modules, + "scripts.indexing_admin", + types.SimpleNamespace( + resolve_collection_root=lambda **_kwargs: ( + "/work/repo-1234567890abcdef", + "repo-1234567890abcdef", + ) + ), + ) + + def _clear_index_journal_entries(**kwargs): + calls.update(kwargs) + return 3 + + monkeypatch.setattr(srv, "clear_index_journal_entries", 
_clear_index_journal_entries) + + client = TestClient(srv.app) + resp = client.post("/admin/collections/clear-journal", data={"collection": "c1"}, follow_redirects=False) + + assert resp.status_code == 302 + loc = resp.headers.get("location") or "" + assert "journal_cleared=c1" in loc + assert "journal_removed=3" in loc + assert calls == { + "workspace_path": "/work/repo-1234567890abcdef", + "repo_name": "repo-1234567890abcdef", + } + + @pytest.mark.unit def test_collection_admin_refuses_when_env_disabled(monkeypatch): monkeypatch.setenv("CTXCE_ADMIN_COLLECTION_DELETE_ENABLED", "0") diff --git a/tests/test_admin_ui.py b/tests/test_admin_ui.py new file mode 100644 index 00000000..8e068d8c --- /dev/null +++ b/tests/test_admin_ui.py @@ -0,0 +1,36 @@ +from starlette.requests import Request + +from scripts.admin_ui import ( + render_admin_acl, + render_admin_bootstrap, + render_admin_error, + render_admin_login, +) + + +def _request(path: str = "/admin/login") -> Request: + return Request( + { + "type": "http", + "method": "GET", + "path": path, + "headers": [], + "query_string": b"", + "server": ("testserver", 80), + "scheme": "http", + "client": ("127.0.0.1", 12345), + } + ) + + +def test_admin_templates_render_with_request_first_api(): + request = _request() + + responses = [ + render_admin_login(request), + render_admin_bootstrap(request), + render_admin_acl(request, users=[], collections=[], grants={}), + render_admin_error(request, title="Error", message="Something failed"), + ] + + assert [response.status_code for response in responses] == [200, 200, 200, 400] diff --git a/tests/test_cache_deduplication.py b/tests/test_cache_deduplication.py index 8180e755..5346b537 100644 --- a/tests/test_cache_deduplication.py +++ b/tests/test_cache_deduplication.py @@ -87,13 +87,13 @@ def test_cache_ttl_expiration(self): """Test TTL-based expiration.""" cache = UnifiedCache("test_ttl", max_size=10, eviction_policy=EvictionPolicy.TTL, default_ttl=0.1) - # Set value with short 
TTL - cache.set("key1", "value1", ttl=0.1) - self.assertEqual(cache.get("key1"), "value1") + with patch("scripts.cache_manager.time.time") as fake_time: + fake_time.return_value = 1000.0 + cache.set("key1", "value1", ttl=0.1) + self.assertEqual(cache.get("key1"), "value1") - # Wait for expiration - time.sleep(0.2) - self.assertIsNone(cache.get("key1")) # Should be expired + fake_time.return_value = 1000.2 + self.assertIsNone(cache.get("key1")) # Should be expired def test_cache_statistics(self): """Test cache statistics tracking.""" @@ -149,27 +149,29 @@ def test_cached_decorator(self): """Test the cached decorator.""" call_count = 0 - @cached("test_decorator", ttl=1.0) - def expensive_function(x): - nonlocal call_count - call_count += 1 - return x * 2 + with patch("scripts.cache_manager.time.time") as fake_time: + fake_time.return_value = 1000.0 + + @cached("test_decorator", ttl=1.0) + def expensive_function(x): + nonlocal call_count + call_count += 1 + return x * 2 - # First call should compute - result1 = expensive_function(5) - self.assertEqual(result1, 10) - self.assertEqual(call_count, 1) + # First call should compute + result1 = expensive_function(5) + self.assertEqual(result1, 10) + self.assertEqual(call_count, 1) - # Second call should use cache - result2 = expensive_function(5) - self.assertEqual(result2, 10) - self.assertEqual(call_count, 1) # Should not increase + # Second call should use cache + result2 = expensive_function(5) + self.assertEqual(result2, 10) + self.assertEqual(call_count, 1) # Should not increase - # Wait for expiration and call again - time.sleep(1.1) - result3 = expensive_function(5) - self.assertEqual(result3, 10) - self.assertEqual(call_count, 2) # Should recompute + fake_time.return_value = 1001.1 + result3 = expensive_function(5) + self.assertEqual(result3, 10) + self.assertEqual(call_count, 2) # Should recompute class TestRequestDeduplication(unittest.TestCase): @@ -304,42 +306,46 @@ def test_deduplication_ttl(self): request = 
{'queries': ['test'], 'limit': 10} - # First request should be unique - is_dup1, fp1 = deduplicator.is_duplicate(request) - self.assertFalse(is_dup1) + with patch("scripts.deduplication.time.time") as fake_time: + fake_time.return_value = 1000.0 - # Wait for expiration - time.sleep(0.2) + # First request should be unique + is_dup1, fp1 = deduplicator.is_duplicate(request) + self.assertFalse(is_dup1) - # Same request should be unique again after expiration - is_dup2, fp2 = deduplicator.is_duplicate(request) - self.assertFalse(is_dup2) + fake_time.return_value = 1000.2 + + # Same request should be unique again after expiration + is_dup2, fp2 = deduplicator.is_duplicate(request) + self.assertFalse(is_dup2) def test_deduplicate_request_decorator(self): """Test the deduplicate_request decorator.""" call_count = 0 - @deduplicate_request(ttl=1.0) - def expensive_search(query): - nonlocal call_count - call_count += 1 - return f"search_result_for_{query}" - - # First call should execute - result1 = expensive_search("test") - self.assertEqual(result1, "search_result_for_test") - self.assertEqual(call_count, 1) - - # Second identical call should be deduplicated - result2 = expensive_search("test") - self.assertIsNone(result2) # Decorator returns None for duplicates - self.assertEqual(call_count, 1) # Should not increase - - # Wait for expiration and call again - time.sleep(1.1) - result3 = expensive_search("test") - self.assertEqual(result3, "search_result_for_test") - self.assertEqual(call_count, 2) + with patch("scripts.deduplication.time.time") as fake_time: + fake_time.return_value = 1000.0 + + @deduplicate_request(ttl=1.0) + def expensive_search(query): + nonlocal call_count + call_count += 1 + return f"search_result_for_{query}" + + # First call should execute + result1 = expensive_search("test") + self.assertEqual(result1, "search_result_for_test") + self.assertEqual(call_count, 1) + + # Second identical call should be deduplicated + result2 = expensive_search("test") 
+ self.assertIsNone(result2) # Decorator returns None for duplicates + self.assertEqual(call_count, 1) # Should not increase + + fake_time.return_value = 1001.1 + result3 = expensive_search("test") + self.assertEqual(result3, "search_result_for_test") + self.assertEqual(call_count, 2) class TestCacheIntegration(unittest.TestCase): diff --git a/tests/test_change_history_for_path.py b/tests/test_change_history_for_path.py index 52be7592..7d822150 100644 --- a/tests/test_change_history_for_path.py +++ b/tests/test_change_history_for_path.py @@ -15,6 +15,10 @@ def tool(self, *args, **kwargs): def _decorator(fn): return fn return _decorator + def resource(self, *args, **kwargs): + def _decorator(fn): + return fn + return _decorator class _Context: def __init__(self, *args, **kwargs): @@ -98,4 +102,3 @@ async def test_change_history_strict_match_under_work(monkeypatch): assert summary.get("ingested_min") == 90 assert summary.get("ingested_max") == 115 assert summary.get("churn_count_max") == 5 - diff --git a/tests/test_collection_memory_backup_restore.py b/tests/test_collection_memory_backup_restore.py index c9e74921..cfdaf63e 100644 --- a/tests/test_collection_memory_backup_restore.py +++ b/tests/test_collection_memory_backup_restore.py @@ -78,14 +78,14 @@ def test_memory_backup_restore_happy_path(qdrant_container, monkeypatch): - The collection should be updated (if possible) without recreation. - Existing points should remain intact. 
""" - os.environ["QDRANT_URL"] = qdrant_container + monkeypatch.setenv("QDRANT_URL", qdrant_container) collection = f"test-mem-{uuid.uuid4().hex[:8]}" client = _create_collection_with_memory(qdrant_container, collection, dim=8) # Force ReFRAG on so ensure_collection tries to add MINI_VECTOR_NAME - os.environ["REFRAG_MODE"] = "1" - os.environ.pop("STRICT_MEMORY_RESTORE", None) + monkeypatch.setenv("REFRAG_MODE", "1") + monkeypatch.delenv("STRICT_MEMORY_RESTORE", raising=False) # Run ensure_collection: this should trigger backup + recreate + restore ing.ensure_collection(client, collection, dim=8, vector_name="code") @@ -107,13 +107,13 @@ def test_memory_backup_restore_happy_path(qdrant_container, monkeypatch): def test_memory_restore_strict_mode_no_recreate(qdrant_container, monkeypatch): """STRICT_MEMORY_RESTORE should not trigger errors when no recreate occurs.""" - os.environ["QDRANT_URL"] = qdrant_container + monkeypatch.setenv("QDRANT_URL", qdrant_container) collection = f"test-mem-strict-{uuid.uuid4().hex[:8]}" client = _create_collection_with_memory(qdrant_container, collection, dim=8) - os.environ["REFRAG_MODE"] = "1" - os.environ["STRICT_MEMORY_RESTORE"] = "1" + monkeypatch.setenv("REFRAG_MODE", "1") + monkeypatch.setenv("STRICT_MEMORY_RESTORE", "1") # Patch subprocess.run to: # - allow the real memory_backup.py to run @@ -141,13 +141,13 @@ def test_memory_backup_failure_tolerant_mode_no_recreate(qdrant_container, monke """If backup fails but STRICT_MEMORY_RESTORE is not set, ensure_collection should still proceed without destructive recreation. 
""" - os.environ["QDRANT_URL"] = qdrant_container + monkeypatch.setenv("QDRANT_URL", qdrant_container) collection = f"test-mem-backup-fail-{uuid.uuid4().hex[:8]}" client = _create_collection_with_memory(qdrant_container, collection, dim=8) - os.environ["REFRAG_MODE"] = "1" - os.environ.pop("STRICT_MEMORY_RESTORE", None) + monkeypatch.setenv("REFRAG_MODE", "1") + monkeypatch.delenv("STRICT_MEMORY_RESTORE", raising=False) # Patch subprocess.run so memory_backup.py fails, but everything else runs normally orig_run = subprocess.run @@ -176,14 +176,14 @@ def fake_run(args, **kwargs): # type: ignore[override] assert "2" in ids -def test_memory_backup_and_restore_scripts_roundtrip(qdrant_container, tmp_path): +def test_memory_backup_and_restore_scripts_roundtrip(qdrant_container, tmp_path, monkeypatch): """Directly exercise memory_backup.export_memories and memory_restore.restore_memories without going through ensure_collection. This confirms that the backup file contains the expected memory and that restore_memories can recreate it in a fresh collection. 
""" - os.environ["QDRANT_URL"] = qdrant_container + monkeypatch.setenv("QDRANT_URL", qdrant_container) collection = f"test-mem-scripts-{uuid.uuid4().hex[:8]}" client = _create_collection_with_memory(qdrant_container, collection, dim=8) diff --git a/tests/test_concurrency_service.py b/tests/test_concurrency_service.py index a923163a..1e5e61de 100644 --- a/tests/test_concurrency_service.py +++ b/tests/test_concurrency_service.py @@ -10,6 +10,8 @@ async def test_repo_search_concurrent(monkeypatch): # In-process, fast stubbed hybrid search and model monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") + monkeypatch.setenv("RERANKER_ENABLED", "0") monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) import scripts.hybrid_search as hy diff --git a/tests/test_context_answer.py b/tests/test_context_answer.py index 595207b9..c7bf2e2f 100644 --- a/tests/test_context_answer.py +++ b/tests/test_context_answer.py @@ -1,9 +1,58 @@ -import importlib +import asyncio +import sys +import threading import types import pytest +from scripts.mcp_impl.context_answer import ( + _ca_prepare_filters_and_retrieve, + _context_answer_impl, +) + + +def _retrieval_result(items, **overrides): + result = { + "items": items, + "eff_language": None, + "eff_path_glob": None, + "eff_not_glob": None, + "override_under": False, + "sym_arg": None, + "cwd_root": "/work", + "path_regex": None, + "ext": None, + "kind": None, + "case": None, + } + result.update(overrides) + return result + + +def _install_fake_hybrid(monkeypatch, run_hybrid_search): + fake = types.ModuleType("scripts.hybrid_search") + fake.run_hybrid_search = run_hybrid_search + fake.lang_matches_path = lambda language, path: True + fake._merge_and_budget_spans = lambda items: list(items or []) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", fake) + return fake + + +def _run_context_answer(retrieval_fn=None, **kwargs): + return 
asyncio.get_event_loop().run_until_complete( + _context_answer_impl( + **kwargs, + get_embedding_model_fn=lambda *a, **k: None, + env_lock=threading.Lock(), + prepare_filters_and_retrieve_fn=retrieval_fn or _ca_prepare_filters_and_retrieve, + ) + ) + -srv = importlib.import_module("scripts.mcp_indexer_server") +def _isolate_context_answer_unit(monkeypatch): + monkeypatch.setenv("REFRAG_RUNTIME", "llamacpp") + monkeypatch.setenv("CTX_MULTI_COLLECTION", "0") + monkeypatch.setenv("CTX_DOC_PASS", "0") + monkeypatch.setenv("CTX_DOC_TOP_FALLBACK", "0") def _fake_items(): @@ -29,14 +78,7 @@ def _fake_items(): @pytest.mark.service def test_context_answer_happy_path(monkeypatch): - # Mock embedding model to avoid loading real model - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: None) - - # Fake retrieval output (already budgeted) - import scripts.hybrid_search as hs - - monkeypatch.setattr(hs, "run_hybrid_search", lambda **k: _fake_items()) - + _isolate_context_answer_unit(monkeypatch) # Fake decoder import scripts.refrag_llamacpp as ref @@ -51,8 +93,11 @@ def generate_with_soft_embeddings(self, prompt: str, max_tokens: int = 256, **kw monkeypatch.setattr(ref, "LlamaCppRefragClient", FakeLlama) monkeypatch.setattr(ref, "is_decoder_enabled", lambda: True) - out = srv.asyncio.get_event_loop().run_until_complete( - srv.context_answer(query="how to do x", limit=2, per_path=1) + out = _run_context_answer( + retrieval_fn=lambda **_kwargs: _retrieval_result(_fake_items()), + query="how to do x", + limit=2, + per_path=1, ) assert isinstance(out, dict) @@ -63,12 +108,12 @@ def generate_with_soft_embeddings(self, prompt: str, max_tokens: int = 256, **kw def test_context_answer_decoder_disabled(monkeypatch): - # Mock embedding model to avoid loading real model - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: None) - - import scripts.hybrid_search as hs - - monkeypatch.setattr(hs, "run_hybrid_search", lambda **k: _fake_items()) + 
_isolate_context_answer_unit(monkeypatch) + monkeypatch.setenv("REFRAG_MODE", "0") + monkeypatch.setenv("REFRAG_GATE_FIRST", "0") + monkeypatch.setenv("REFRAG_RUNTIME", "llamacpp") + monkeypatch.setenv("CTX_CLIENT_DEADLINE_SEC", "178") + monkeypatch.setenv("CTX_DEADLINE_MARGIN_SEC", "6") import scripts.refrag_llamacpp as ref @@ -82,8 +127,10 @@ def generate_with_soft_embeddings(self, *a, **k): monkeypatch.setattr(ref, "LlamaCppRefragClient", FakeLlama) monkeypatch.setattr(ref, "is_decoder_enabled", lambda: False) - out = srv.asyncio.get_event_loop().run_until_complete( - srv.context_answer(query="how to do y", limit=1) + out = _run_context_answer( + retrieval_fn=lambda **_kwargs: _retrieval_result(_fake_items()), + query="how to do y", + limit=1, ) assert "error" in out @@ -91,11 +138,7 @@ def generate_with_soft_embeddings(self, *a, **k): def test_context_answer_prefers_identifier_spans(monkeypatch): - # Mock embedding model to avoid loading real model - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: None) - - import scripts.hybrid_search as hs - + _isolate_context_answer_unit(monkeypatch) def _items(): return [ { @@ -105,6 +148,7 @@ def _items(): "start_line": 10, "end_line": 16, "text": "def helper():\n return 42\n", + "span_budgeted": True, }, { "score": 0.8, @@ -113,11 +157,10 @@ def _items(): "start_line": 5, "end_line": 9, "text": "RRF_K = 60\n", + "span_budgeted": True, }, ] - monkeypatch.setattr(hs, "run_hybrid_search", lambda **k: _items()) - import scripts.refrag_llamacpp as ref class FakeLlama: @@ -130,8 +173,11 @@ def generate_with_soft_embeddings(self, prompt: str, max_tokens: int = 256, **kw monkeypatch.setattr(ref, "LlamaCppRefragClient", FakeLlama) monkeypatch.setattr(ref, "is_decoder_enabled", lambda: True) - out = srv.asyncio.get_event_loop().run_until_complete( - srv.context_answer(query="what is RRF_K in hybrid_search.py?", limit=1, per_path=1) + out = _run_context_answer( + retrieval_fn=lambda **_kwargs: 
_retrieval_result(_items()), + query="what is RRF_K in hybrid_search.py?", + limit=1, + per_path=1, ) cits = out.get("citations") or [] @@ -141,11 +187,7 @@ def generate_with_soft_embeddings(self, prompt: str, max_tokens: int = 256, **kw def test_context_answer_tier2_retry_without_gating(monkeypatch): """Tier 2 should retry run_hybrid_search with relaxed filters when Tier 1 yields zero.""" - # Mock embedding model to avoid loading real model - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: None) - - import scripts.hybrid_search as hs - + _isolate_context_answer_unit(monkeypatch) calls = [] def _run_hybrid_search(**kwargs): @@ -165,7 +207,7 @@ def _run_hybrid_search(**kwargs): # All other calls (tier1/usage/targeted search) yield no hits return [] - monkeypatch.setattr(hs, "run_hybrid_search", _run_hybrid_search) + _install_fake_hybrid(monkeypatch, _run_hybrid_search) import scripts.refrag_llamacpp as ref @@ -179,9 +221,7 @@ def generate_with_soft_embeddings(self, *a, **kw): monkeypatch.setattr(ref, "LlamaCppRefragClient", FakeLlama) monkeypatch.setattr(ref, "is_decoder_enabled", lambda: True) - out = srv.asyncio.get_event_loop().run_until_complete( - srv.context_answer(query="RRF_K", limit=1, per_path=1) - ) + out = _run_context_answer(query="RRF_K", limit=1, per_path=1) # Ensure Tier 2 was invoked (run_hybrid_search called twice) assert len(calls) >= 3, "Tier 2 fallback should re-run hybrid search" @@ -200,9 +240,7 @@ def generate_with_soft_embeddings(self, *a, **kw): def test_context_answer_env_lock_release_on_retrieval_exception(monkeypatch): - # Mock embedding model to avoid loading real model - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: None) - + _isolate_context_answer_unit(monkeypatch) import os # Force retrieval to raise and ensure env/lock are restored prev = {k: os.environ.get(k) for k in ( @@ -212,16 +250,22 @@ def test_context_answer_env_lock_release_on_retrieval_exception(monkeypatch): def _raise_retrieval(*a, 
**k): raise RuntimeError("boom") - monkeypatch.setattr(srv, "_ca_prepare_filters_and_retrieve", _raise_retrieval) - - out = srv.asyncio.get_event_loop().run_until_complete( - srv.context_answer(query="x", limit=1, per_path=1) + lock = threading.Lock() + out = asyncio.get_event_loop().run_until_complete( + _context_answer_impl( + query="x", + limit=1, + per_path=1, + get_embedding_model_fn=lambda *a, **k: None, + env_lock=lock, + prepare_filters_and_retrieve_fn=_raise_retrieval, + ) ) assert "error" in out # Lock should be free after failure - assert srv._ENV_LOCK.acquire(blocking=False), "_ENV_LOCK should be released on exception" - srv._ENV_LOCK.release() + assert lock.acquire(blocking=False), "context_answer env lock should be released on exception" + lock.release() # Env should be restored for k, v in prev.items(): @@ -243,12 +287,19 @@ def _fake_retrieval(*a, **k): "case": None, } - monkeypatch.setattr(srv, "_ca_prepare_filters_and_retrieve", _fake_retrieval) - import scripts.refrag_llamacpp as ref + + _install_fake_hybrid(monkeypatch, lambda **k: []) monkeypatch.setattr(ref, "is_decoder_enabled", lambda: False) - out2 = srv.asyncio.get_event_loop().run_until_complete( - srv.context_answer(query="x", limit=1, per_path=1) + out2 = asyncio.get_event_loop().run_until_complete( + _context_answer_impl( + query="x", + limit=1, + per_path=1, + get_embedding_model_fn=lambda *a, **k: None, + env_lock=lock, + prepare_filters_and_retrieve_fn=_fake_retrieval, + ) ) assert isinstance(out2, dict) diff --git a/tests/test_context_answer_fallback.py b/tests/test_context_answer_fallback.py index d3c0f599..af95291d 100644 --- a/tests/test_context_answer_fallback.py +++ b/tests/test_context_answer_fallback.py @@ -1,28 +1,53 @@ import asyncio +import sys +import types import pytest -from unittest.mock import patch + +from scripts.mcp_impl.context_answer import _context_answer_impl + @pytest.mark.asyncio -async def test_context_answer_has_no_filesystem_fallback_when_no_hits(): 
+async def test_context_answer_has_no_filesystem_fallback_when_no_hits(monkeypatch): """When retrieval yields no spans, we do NOT glob or read the host filesystem. Citations may be empty, and that's expected. """ - from scripts.mcp_indexer_server import context_answer - import scripts.mcp_indexer_server as srv + fake_hybrid = types.ModuleType("scripts.hybrid_search") + fake_hybrid.run_hybrid_search = lambda **k: [] + fake_hybrid.lang_matches_path = lambda language, path: True + fake_hybrid._merge_and_budget_spans = lambda items: list(items or []) + + monkeypatch.setenv("REFRAG_RUNTIME", "llamacpp") + monkeypatch.setenv("CTX_MULTI_COLLECTION", "0") + monkeypatch.setenv("CTX_DOC_PASS", "0") + monkeypatch.setenv("CTX_DOC_TOP_FALLBACK", "0") + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", fake_hybrid) - # Mock embedding model to avoid loading real model - with patch.object(srv, "_get_embedding_model", return_value=None): - out = await context_answer( - query="Describe module roles", - limit=3, - per_path=1, - include_snippet=True, - path_glob=["scripts/hybrid_search.py"], - # Force a very unlikely match to simulate empty retrieval - language="nonexistentlang", - ) - assert isinstance(out, dict) - # No fallback: citations can be empty - cits = out.get("citations") or [] - assert len(cits) == 0 + def _empty_retrieval(**_kwargs): + return { + "items": [], + "eff_language": "nonexistentlang", + "eff_path_glob": ["scripts/hybrid_search.py"], + "eff_not_glob": [], + "override_under": None, + "sym_arg": None, + "cwd_root": "/work", + "path_regex": None, + "ext": None, + "kind": None, + "case": None, + } + out = await _context_answer_impl( + query="Describe module roles", + limit=3, + per_path=1, + include_snippet=True, + path_glob=["scripts/hybrid_search.py"], + language="nonexistentlang", + get_embedding_model_fn=lambda *_args, **_kwargs: None, + prepare_filters_and_retrieve_fn=_empty_retrieval, + ) + assert isinstance(out, dict) + # No fallback: citations can be 
empty + cits = out.get("citations") or [] + assert len(cits) == 0 diff --git a/tests/test_context_answer_path_mention.py b/tests/test_context_answer_path_mention.py index 59299b8d..edc7b281 100644 --- a/tests/test_context_answer_path_mention.py +++ b/tests/test_context_answer_path_mention.py @@ -1,17 +1,47 @@ -import importlib +import asyncio +import sys +import threading +import types import pytest -srv = importlib.import_module("scripts.mcp_indexer_server") +from scripts.mcp_impl.context_answer import ( + _ca_prepare_filters_and_retrieve, + _context_answer_impl, +) + + +def _run_context_answer(**kwargs): + return asyncio.get_event_loop().run_until_complete( + _context_answer_impl( + **kwargs, + get_embedding_model_fn=lambda *a, **k: None, + env_lock=threading.Lock(), + prepare_filters_and_retrieve_fn=_ca_prepare_filters_and_retrieve, + ) + ) + + +def _install_fake_hybrid(monkeypatch, run_hybrid_search): + fake = types.ModuleType("scripts.hybrid_search") + fake.run_hybrid_search = run_hybrid_search + fake.lang_matches_path = lambda language, path: True + fake._merge_and_budget_spans = lambda items: list(items or []) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", fake) + return fake + + +def _isolate_context_answer_unit(monkeypatch): + monkeypatch.setenv("REFRAG_RUNTIME", "llamacpp") + monkeypatch.setenv("CTX_MULTI_COLLECTION", "0") + monkeypatch.setenv("CTX_DOC_PASS", "0") + monkeypatch.setenv("CTX_DOC_TOP_FALLBACK", "0") @pytest.mark.service def test_context_answer_path_mention_fallback(monkeypatch): - # Mock embedding model to avoid loading real model - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: None) - + _isolate_context_answer_unit(monkeypatch) # Force retrieval to return nothing so path-mention fallback engages - import scripts.hybrid_search as hs - monkeypatch.setattr(hs, "run_hybrid_search", lambda **k: []) + _install_fake_hybrid(monkeypatch, lambda **k: []) import scripts.refrag_llamacpp as ref @@ -30,9 +60,7 @@ def 
generate_with_soft_embeddings(self, prompt: str, max_tokens: int = 64, **kw) # Mention an actual file in this repo so fallback can find it q = "explain something in scripts/hybrid_search.py" - out = srv.asyncio.get_event_loop().run_until_complete( - srv.context_answer(query=q, limit=3, per_path=2) - ) + out = _run_context_answer(query=q, limit=3, per_path=2) assert isinstance(out, dict) cits = out.get("citations") or [] assert len(cits) >= 1 @@ -40,4 +68,3 @@ def generate_with_soft_embeddings(self, prompt: str, max_tokens: int = 64, **kw) p = cits[0].get("path") or "" rp = cits[0].get("rel_path") or "" assert p.endswith("scripts/hybrid_search.py") or rp.endswith("scripts/hybrid_search.py") - diff --git a/tests/test_ctx_cli.py b/tests/test_ctx_cli.py new file mode 100644 index 00000000..6ab6e8e9 --- /dev/null +++ b/tests/test_ctx_cli.py @@ -0,0 +1,67 @@ +import sys + +import scripts.ctx as ctx + + +def test_parse_mcp_response_prefers_structured_content(): + payload = { + "result": { + "content": [{"type": "text", "text": '{"result":{"results":[]}}'}], + "structuredContent": { + "result": {"results": [{"path": "structured.py"}], "total": 1} + }, + } + } + + assert ctx.parse_mcp_response(payload) == { + "results": [{"path": "structured.py"}], + "total": 1, + } + + +def test_parse_mcp_response_unwraps_text_result_payload(): + payload = { + "result": { + "content": [ + { + "type": "text", + "text": '{"result":{"results":[{"path":"text.py"}],"total":1}}', + } + ] + } + } + + assert ctx.parse_mcp_response(payload) == { + "results": [{"path": "text.py"}], + "total": 1, + } + + +def test_main_with_context_appends_supporting_context(monkeypatch, capsys): + monkeypatch.setattr( + sys, + "argv", + ["ctx.py", "--with-context", "where is dense search?"], + ) + monkeypatch.setattr( + ctx, + "fetch_context", + lambda *a, **k: ( + "- /work/scripts/hybrid_search.py:428-565 (run_pure_dense_search)", + "", + ), + ) + monkeypatch.setattr(ctx, "rewrite_prompt", lambda *a, **k: 
"rewritten prompt") + monkeypatch.setattr( + ctx, + "extract_allowed_citations", + lambda *a, **k: ({"/work/scripts/hybrid_search.py"}, {}), + ) + monkeypatch.setattr(ctx, "sanitize_citations", lambda text, *_: text) + + ctx.main() + + out = capsys.readouterr().out + assert "rewritten prompt" in out + assert "Supporting context:" in out + assert "run_pure_dense_search" in out diff --git a/tests/test_env_behavior.py b/tests/test_env_behavior.py index e87cc8db..220a02bf 100644 --- a/tests/test_env_behavior.py +++ b/tests/test_env_behavior.py @@ -1,9 +1,7 @@ import importlib -import types -import os import pytest -srv = importlib.import_module("scripts.mcp_indexer_server") +search_impl = importlib.import_module("scripts.mcp_impl.search") @pytest.mark.service @@ -11,19 +9,21 @@ def test_rerank_timeout_floor_and_env_defaults(monkeypatch): # Force rerank via env default when arg not provided monkeypatch.setenv("RERANKER_ENABLED", "1") monkeypatch.setenv("RERANK_IN_PROCESS", "0") + monkeypatch.setenv("HYBRID_IN_PROCESS", "0") + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") # Floor 1500ms; client asks 200ms -> effective >= 1500ms -> 1.5s monkeypatch.setenv("RERANK_TIMEOUT_FLOOR_MS", "1500") # Fix default timeout for test determinism (CI may set a higher value) monkeypatch.setenv("RERANKER_TIMEOUT_MS", "200") - # Fake _run_async to capture calls + # Fake subprocess runner to capture hybrid + rerank calls without loading the MCP facade. 
calls = [] async def fake_run(cmd, env=None, timeout=None): calls.append({"cmd": cmd, "timeout": timeout}) - # Distinguish hybrid vs rerank by script name - if any("rerank_local.py" in str(x) for x in cmd): + # Distinguish hybrid vs rerank by module name + if "scripts.rerank_tools.local" in " ".join(map(str, cmd)): # Return something that looks like rerank stdout return { "ok": True, @@ -40,17 +40,21 @@ async def fake_run(cmd, env=None, timeout=None): "code": 0, } - monkeypatch.setattr(srv, "_run_async", fake_run) - # Call repo_search with no rerank_enabled arg to pick env default - res = srv.asyncio.get_event_loop().run_until_complete( - srv.repo_search(query="foo", limit=3, per_path=1) + res = search_impl.asyncio.get_event_loop().run_until_complete( + search_impl._repo_search_impl( + query="foo", + limit=3, + per_path=1, + run_async_fn=fake_run, + require_auth_session_fn=lambda session: session, + ) ) assert any( - "rerank_local.py" in " ".join(map(str, c["cmd"])) for c in calls + "scripts.rerank_tools.local" in " ".join(map(str, c["cmd"])) for c in calls ), "rerank subprocess should be invoked" # find rerank call - rc = next(c for c in calls if any("rerank_local.py" in str(x) for x in c["cmd"])) + rc = next(c for c in calls if "scripts.rerank_tools.local" in " ".join(map(str, c["cmd"]))) assert rc["timeout"] >= 1.5 and rc["timeout"] <= 2.0 assert res["used_rerank"] is True diff --git a/tests/test_error_paths.py b/tests/test_error_paths.py index 129e2eb4..24765ea3 100644 --- a/tests/test_error_paths.py +++ b/tests/test_error_paths.py @@ -1,24 +1,34 @@ -import os -import asyncio +import importlib +import sys import types import pytest -import scripts.mcp_indexer_server as srv +search_impl = importlib.import_module("scripts.mcp_impl.search") @pytest.mark.service def test_repo_search_malformed_jsonl_subprocess(monkeypatch): # Force subprocess path and simulate malformed JSONL stdout monkeypatch.setenv("HYBRID_IN_PROCESS", "0") + 
monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") + monkeypatch.setenv("RERANKER_ENABLED", "0") + + fake_hybrid = types.ModuleType("scripts.hybrid_search") + fake_hybrid.run_hybrid_search = lambda *a, **k: [] + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", fake_hybrid) async def fake_run(cmd, **kwargs): # Simulate subprocess failure with malformed output return {"ok": False, "code": 1, "stdout": "not-json\n", "stderr": "malformed"} - monkeypatch.setattr(srv, "_run_async", fake_run) - - res = srv.asyncio.get_event_loop().run_until_complete( - srv.repo_search(queries=["x"], limit=1, compact=False) + res = search_impl.asyncio.get_event_loop().run_until_complete( + search_impl._repo_search_impl( + queries=["x"], + limit=1, + compact=False, + run_async_fn=fake_run, + require_auth_session_fn=lambda session: session, + ) ) assert res.get("ok") is False @@ -30,26 +40,30 @@ def test_repo_search_inproc_qdrant_failure_fallback_and_fail(monkeypatch): # In-process hybrid raises (simulating Qdrant connectivity failure), # subprocess fallback also fails. 
monkeypatch.setenv("HYBRID_IN_PROCESS", "1") - - # Avoid real model load - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") + monkeypatch.setenv("RERANKER_ENABLED", "0") # Cause in-process path to fail - import scripts.hybrid_search as hy - def boom(*a, **k): raise ConnectionError("qdrant down") - monkeypatch.setattr(hy, "run_hybrid_search", boom) + fake_hybrid = types.ModuleType("scripts.hybrid_search") + fake_hybrid.run_hybrid_search = boom + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", fake_hybrid) # And make the subprocess fallback fail too async def fake_run(cmd, **kwargs): return {"ok": False, "code": 1, "stdout": "", "stderr": "qdrant unreachable"} - monkeypatch.setattr(srv, "_run_async", fake_run) - - res = srv.asyncio.get_event_loop().run_until_complete( - srv.repo_search(queries=["x"], limit=1, compact=True) + res = search_impl.asyncio.get_event_loop().run_until_complete( + search_impl._repo_search_impl( + queries=["x"], + limit=1, + compact=True, + get_embedding_model_fn=lambda *a, **k: object(), + run_async_fn=fake_run, + require_auth_session_fn=lambda session: session, + ) ) assert res.get("ok") is False diff --git a/tests/test_fname_boost.py b/tests/test_fname_boost.py index fcca42df..a3ff2f4c 100644 --- a/tests/test_fname_boost.py +++ b/tests/test_fname_boost.py @@ -6,7 +6,7 @@ - Position weighting (filename > directory) - Common token penalties """ -from scripts.rerank_recursive import ( +from scripts.rerank_recursive.utils import ( _compute_fname_boost, _split_identifier, _normalize_token, diff --git a/tests/test_glm_model_config.py b/tests/test_glm_model_config.py index 44baac1a..7823b46c 100644 --- a/tests/test_glm_model_config.py +++ b/tests/test_glm_model_config.py @@ -1,8 +1,20 @@ """Tests for GLM model version configuration and backwards compatibility.""" import os +import sys +from types import SimpleNamespace from unittest.mock import patch, 
MagicMock +def _install_fake_openai(monkeypatch): + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] + mock_client.chat.completions.create.return_value = mock_response + openai_module = SimpleNamespace(OpenAI=MagicMock(return_value=mock_client)) + monkeypatch.setitem(sys.modules, "openai", openai_module) + return mock_client + + class TestGLMModelConfig: """Test GLM model version detection and configuration.""" @@ -81,16 +93,10 @@ def test_glm45_config_values(self): class TestGLMRefragClientModelSelection: """Test GLMRefragClient model selection logic.""" - @patch("openai.OpenAI") - def test_default_model_is_glm46(self, mock_openai_class): + def test_default_model_is_glm46(self, monkeypatch): """Test that default model is glm-4.6.""" from scripts.refrag_glm import GLMRefragClient - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) # Remove GLM_MODEL from env to test default, keep GLM_API_KEY env_copy = os.environ.copy() @@ -106,16 +112,10 @@ def test_default_model_is_glm46(self, mock_openai_class): assert call_kwargs["model"] == "glm-4.6" @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL": "glm-4.6"}, clear=False) - @patch("openai.OpenAI") - def test_env_model_override(self, mock_openai_class): + def test_env_model_override(self, monkeypatch): """Test that GLM_MODEL env var overrides default.""" from scripts.refrag_glm import GLMRefragClient - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = 
_install_fake_openai(monkeypatch) client = GLMRefragClient() client.generate_with_soft_embeddings("test prompt") @@ -124,16 +124,10 @@ def test_env_model_override(self, mock_openai_class): assert call_kwargs["model"] == "glm-4.6" @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL_FAST": "glm-4.5"}, clear=False) - @patch("openai.OpenAI") - def test_fast_model_with_disable_thinking(self, mock_openai_class): + def test_fast_model_with_disable_thinking(self, monkeypatch): """Test that disable_thinking uses GLM_MODEL_FAST.""" from scripts.refrag_glm import GLMRefragClient - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) client = GLMRefragClient() client.generate_with_soft_embeddings("test prompt", disable_thinking=True) @@ -146,16 +140,10 @@ class TestGLMToolStreamSupport: """Test GLM-4.7 tool_stream feature support.""" @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL": "glm-4.7"}, clear=False) - @patch("openai.OpenAI") - def test_tool_stream_enabled_for_glm47(self, mock_openai_class): + def test_tool_stream_enabled_for_glm47(self, monkeypatch): """Test that tool_stream is enabled for GLM-4.7 when requested.""" from scripts.refrag_glm import GLMRefragClient - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) client = GLMRefragClient() tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}] @@ -166,16 +154,10 @@ def test_tool_stream_enabled_for_glm47(self, mock_openai_class): assert call_kwargs.get("extra_body", 
{}).get("tool_stream") is True @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL": "glm-4.6"}, clear=False) - @patch("openai.OpenAI") - def test_tool_stream_not_enabled_for_glm46(self, mock_openai_class): + def test_tool_stream_not_enabled_for_glm46(self, monkeypatch): """Test that tool_stream is NOT enabled for GLM-4.6.""" from scripts.refrag_glm import GLMRefragClient - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) client = GLMRefragClient() tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}] @@ -191,16 +173,10 @@ class TestGLMThinkingSupport: """Test GLM thinking/reasoning support.""" @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL": "glm-4.7"}, clear=False) - @patch("openai.OpenAI") - def test_enable_thinking_for_glm47(self, mock_openai_class): + def test_enable_thinking_for_glm47(self, monkeypatch): """Test that thinking can be explicitly enabled for GLM-4.7.""" from scripts.refrag_glm import GLMRefragClient - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) client = GLMRefragClient() client.generate_with_soft_embeddings("test prompt", enable_thinking=True) @@ -209,16 +185,10 @@ def test_enable_thinking_for_glm47(self, mock_openai_class): assert call_kwargs.get("extra_body", {}).get("thinking") == {"type": "enabled"} @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL": "glm-4.5"}, clear=False) - @patch("openai.OpenAI") - def test_thinking_not_set_for_glm45(self, mock_openai_class): + def 
test_thinking_not_set_for_glm45(self, monkeypatch): """Test that thinking is NOT set for GLM-4.5 (no thinking support).""" from scripts.refrag_glm import GLMRefragClient - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) client = GLMRefragClient() client.generate_with_soft_embeddings("test prompt", enable_thinking=True) @@ -234,16 +204,10 @@ class TestGLMMaxTokensLimit: """Test max_tokens limiting based on model capabilities.""" @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL": "glm-4.7"}, clear=False) - @patch("openai.OpenAI") - def test_max_tokens_capped_to_model_limit(self, mock_openai_class): + def test_max_tokens_capped_to_model_limit(self, monkeypatch): """Test that max_tokens is capped to model's max_output_tokens.""" from scripts.refrag_glm import GLMRefragClient, GLM_MODEL_CONFIGS - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) client = GLMRefragClient() # Request more than GLM-4.7 can output (131072) @@ -253,16 +217,10 @@ def test_max_tokens_capped_to_model_limit(self, mock_openai_class): assert call_kwargs["max_tokens"] <= GLM_MODEL_CONFIGS["glm-4.7"]["max_output_tokens"] @patch.dict(os.environ, {"GLM_API_KEY": "test-key", "GLM_MODEL": "glm-4.5"}, clear=False) - @patch("openai.OpenAI") - def test_max_tokens_uses_smaller_limit_for_glm45(self, mock_openai_class): + def test_max_tokens_uses_smaller_limit_for_glm45(self, monkeypatch): """Test that GLM-4.5 uses its smaller max_output limit.""" from scripts.refrag_glm import GLMRefragClient, 
GLM_MODEL_CONFIGS - - mock_client = MagicMock() - mock_response = MagicMock() - mock_response.choices = [MagicMock(message=MagicMock(content="test response"))] - mock_client.chat.completions.create.return_value = mock_response - mock_openai_class.return_value = mock_client + mock_client = _install_fake_openai(monkeypatch) client = GLMRefragClient() # Request more than GLM-4.5 can output (8192) diff --git a/tests/test_globs_and_snippet.py b/tests/test_globs_and_snippet.py index 4c30ff48..d52abc1e 100644 --- a/tests/test_globs_and_snippet.py +++ b/tests/test_globs_and_snippet.py @@ -1,14 +1,53 @@ import importlib +import asyncio import builtins import json import types from pathlib import Path +from types import SimpleNamespace import pytest # Import targets hyb = importlib.import_module("scripts.hybrid_search") -srv = importlib.import_module("scripts.mcp_indexer_server") +search_impl = importlib.import_module("scripts.mcp_impl.search") + + +@pytest.fixture(autouse=True) +def fake_qdrant_models(monkeypatch): + hybrid_qdrant = importlib.import_module("scripts.hybrid.qdrant") + + class FakeModels: + class SearchParams: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class QuantizationSearchParams: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class Filter: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class FieldCondition: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class MatchValue: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class MatchAny: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class SparseVector: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + monkeypatch.setattr(hyb, "models", FakeModels) + monkeypatch.setattr(hybrid_qdrant, "models", FakeModels) class _Pt: @@ -33,6 +72,19 @@ class FakeQdrant: def __init__(self, points): self._points = points + def get_collection(self, collection): + return SimpleNamespace( + 
config=SimpleNamespace( + params=SimpleNamespace( + vectors={ + "unit-test": SimpleNamespace(size=8), + "dense": SimpleNamespace(size=8), + }, + sparse_vectors={}, + ) + ) + ) + # dense_query tries query_points first, then search on exception def query_points(self, **kwargs): return _QP(self._points) @@ -116,6 +168,32 @@ def test_run_hybrid_search_slugged_path_globs(monkeypatch): assert "/work/other/docs/readme.md" not in paths +@pytest.mark.unit +def test_run_hybrid_search_under_recursive_scope(monkeypatch): + pts = [ + _Pt("1", "/work/repo/space/ship/a.py"), + _Pt("2", "/work/repo/direct/tools/b.py"), + ] + monkeypatch.setattr(hyb, "get_qdrant_client", lambda *a, **k: FakeQdrant(pts)) + monkeypatch.setattr(hyb, "return_qdrant_client", lambda *a, **k: None) + monkeypatch.setenv("EMBEDDING_MODEL", "unit-test") + monkeypatch.setenv("QDRANT_URL", "http://localhost:6333") + monkeypatch.setattr(hyb, "TextEmbedding", lambda *a, **k: FakeEmbed()) + monkeypatch.setattr(hyb, "_get_embedding_model", lambda *a, **k: FakeEmbed()) + + items = hyb.run_hybrid_search( + queries=["rotate heading"], + limit=10, + per_path=2, + under="space", + expand=False, + model=FakeEmbed(), + ) + paths = {it.get("path") for it in items} + assert "/work/repo/space/ship/a.py" in paths + assert "/work/repo/direct/tools/b.py" not in paths + + @pytest.mark.unit def test_dense_query_preserves_collection_on_filter_drop(monkeypatch): calls = [] @@ -147,6 +225,32 @@ def query_points(self, **kwargs): assert calls[1].get("query_filter") is None or calls[1].get("filter") is None +@pytest.mark.unit +def test_run_pure_dense_search_honors_per_path_cap(): + points = [ + SimpleNamespace( + score=0.99, + payload={"metadata": {"path": "/work/repo/a.py", "start_line": 1, "end_line": 2}}, + ), + SimpleNamespace( + score=0.98, + payload={"metadata": {"path": "/work/repo/a.py", "start_line": 10, "end_line": 11}}, + ), + SimpleNamespace( + score=0.97, + payload={"metadata": {"path": "/work/repo/b.py", "start_line": 3, 
"end_line": 4}}, + ), + ] + + items = hyb._shape_dense_points( + points, + limit=2, + per_path=1, + ) + + assert [item["path"] for item in items] == ["/work/repo/a.py", "/work/repo/b.py"] + + @pytest.mark.unit def test_collection_prefers_env_over_state(monkeypatch, tmp_path): # State file should be ignored when COLLECTION_NAME env var is set @@ -162,18 +266,9 @@ def test_collection_prefers_env_over_state(monkeypatch, tmp_path): @pytest.mark.unit def test_repo_search_snippet_strict_cap_after_highlight(monkeypatch): - # Stub run_hybrid_search to emit a single result with a known path and range - async def fake_run(**kwargs): - return {"results": [{"path": "/work/f.txt", "start_line": 1, "end_line": 1}]} - # Force in-process shaping to trigger snippet code path monkeypatch.setenv("HYBRID_IN_PROCESS", "1") - # Monkeypatch srv.hybrid_search.run_hybrid_search result pathing via repo_search flow - monkeypatch.setattr( - srv, "_tokens_from_queries", lambda q: ["foo"] - ) # ensure highlight runs - # Fake open for the specific /work path big_line = "foo " * 1000 # large content to exceed cap _orig_open = builtins.open @@ -184,9 +279,9 @@ def fake_open(path, *a, **k): return _orig_open(path, *a, **k) # pragma: no cover # Ensure sandbox passes - monkeypatch.setattr(srv.os.path, "isabs", lambda p: True) - monkeypatch.setattr(srv.os.path, "realpath", lambda p: "/work/f.txt") - monkeypatch.setenv("MCP_SNIPPET_MAX_BYTES", "64") + monkeypatch.setattr(search_impl.os.path, "isabs", lambda p: True) + monkeypatch.setattr(search_impl.os.path, "realpath", lambda p: "/work/f.txt") + monkeypatch.setattr(search_impl, "SNIPPET_MAX_BYTES", 64) # Stub hybrid_search.run_hybrid_search to return a single item import sys @@ -202,13 +297,21 @@ def run_hybrid_search(**kwargs): import io - # Patch open builtin used by server + # Patch open builtin used by repo_search implementation. 
monkeypatch.setattr(builtins, "open", fake_open) # Execute - res = srv.asyncio.get_event_loop().run_until_complete( - srv.repo_search( - query="foo", include_snippet=True, highlight_snippet=True, context_lines=0 + res = asyncio.run( + search_impl._repo_search_impl( + query="foo", + mode="hybrid", + include_snippet=True, + highlight_snippet=True, + context_lines=0, + get_embedding_model_fn=lambda _name: object(), + require_auth_session_fn=lambda session: session, + do_highlight_snippet_fn=lambda snippet, _tokens: snippet, + run_async_fn=lambda *_a, **_k: {"ok": True, "code": 0, "stdout": "", "stderr": ""}, ) ) snip = res["results"][0].get("snippet", "") @@ -218,7 +321,7 @@ def run_hybrid_search(**kwargs): @pytest.mark.unit def test_repo_search_docstring_clean(): - doc = srv.repo_search.__doc__ + doc = search_impl._repo_search_impl.__doc__ assert doc and "Zero-config code search" in doc # Ensure stray inline pseudo-code is not embedded in docstring assert "Accept common alias keys from clients" not in doc diff --git a/tests/test_golden_structure.py b/tests/test_golden_structure.py index ab19917b..c8ac4d99 100644 --- a/tests/test_golden_structure.py +++ b/tests/test_golden_structure.py @@ -1,5 +1,6 @@ import os import json +import asyncio from pathlib import Path import pytest @@ -22,8 +23,8 @@ def stub(*a, **k): monkeypatch.setattr(hy, "run_hybrid_search", stub) - res = srv.asyncio.get_event_loop().run_until_complete( - srv.repo_search(queries=["q"], limit=2, compact=True) + res = asyncio.run( + srv.repo_search(queries=["q"], limit=2, compact=True, mode="hybrid") ) # Normalize subset: path/start_line/end_line/symbol only diff --git a/tests/test_graph_delete_verification_normalization.py b/tests/test_graph_delete_verification_normalization.py new file mode 100644 index 00000000..5531de7b --- /dev/null +++ b/tests/test_graph_delete_verification_normalization.py @@ -0,0 +1,45 @@ +from pathlib import PureWindowsPath + + +def 
test_graph_delete_verification_normalizes_caller_path(): + # Unit-level guard: watcher delete verification must query graph edges using + # the same path normalization as graph edge writes/deletes (Windows -> POSIX). + from scripts.watch_index_core import processor as proc + + captured = {} + + class DummyClient: + def scroll( + self, + *, + collection_name, + scroll_filter, + with_payload=False, + with_vectors=False, + limit=1, + ): + captured["collection_name"] = collection_name + captured["filter"] = scroll_filter + return ([], None) + + client = DummyClient() + path = PureWindowsPath(r"C:\repo\foo.py") + + has_edges = proc._path_has_graph_edges(client, "base_collection", path) + assert has_edges is False + + flt = captured["filter"] + assert flt is not None + assert getattr(flt, "must", None) + cond = flt.must[0] + assert cond.key == "caller_path" + + match = cond.match + values = [] + if hasattr(match, "any") and match.any is not None: + values = list(match.any) + elif hasattr(match, "value"): + values = [match.value] + + assert "C:/repo/foo.py" in values + diff --git a/tests/test_health_check.py b/tests/test_health_check.py new file mode 100644 index 00000000..0dceda58 --- /dev/null +++ b/tests/test_health_check.py @@ -0,0 +1,85 @@ +from types import SimpleNamespace + +import pytest + + +pytestmark = pytest.mark.unit + + +class _FakeVec: + def __init__(self, size): + self.size = size + + +class _FakeCollections: + def __init__(self, names): + self.collections = [SimpleNamespace(name=name) for name in names] + + +class _FakeClient: + collection_names = ["context-engine", "stale-empty"] + valid_vector_collections = {"context-engine"} + + def __init__(self, *args, **kwargs): + self.checked = [] + + def get_collections(self): + return _FakeCollections(self.collection_names) + + def get_collection(self, name): + self.checked.append(name) + vectors = ( + {"fast-bge-base-en-v1.5": _FakeVec(768)} + if name in self.valid_vector_collections + else {} + ) + return 
SimpleNamespace( + config=SimpleNamespace( + params=SimpleNamespace(vectors=vectors), + hnsw_config=SimpleNamespace(m=16, ef_construct=256), + ) + ) + + def query_points(self, *args, **kwargs): + return SimpleNamespace(points=[]) + + +class _FakeEmbedding: + def embed(self, texts): + for _ in texts: + yield SimpleNamespace(tolist=lambda: [0.0] * 768) + + +def test_health_check_checks_all_collections_without_crashing_on_mismatch(monkeypatch, capsys): + import scripts.health_check as health_check + + monkeypatch.setenv("COLLECTION_NAME", "context-engine") + monkeypatch.setenv("EMBEDDING_MODEL", "fast/bge-base-en-v1.5") + monkeypatch.setattr(health_check, "QdrantClient", _FakeClient) + monkeypatch.setattr(health_check, "get_embedding_model", lambda *_: _FakeEmbedding()) + monkeypatch.setattr(health_check, "get_model_dimension", lambda *_: 768) + monkeypatch.setattr(health_check, "ensure_collections", lambda *_: 0) + + health_check.main() + + output = capsys.readouterr().out + assert "Checking collection: context-engine" in output + assert "Checking collection: stale-empty" in output + assert "Skipping vector query for stale-empty" in output + + +def test_health_check_missing_named_vector_does_not_keyerror(monkeypatch, capsys): + import scripts.health_check as health_check + + monkeypatch.setenv("COLLECTION_NAME", "stale-empty") + monkeypatch.setenv("EMBEDDING_MODEL", "fast/bge-base-en-v1.5") + monkeypatch.setattr(health_check, "QdrantClient", _FakeClient) + monkeypatch.setattr(health_check, "get_embedding_model", lambda *_: _FakeEmbedding()) + monkeypatch.setattr(health_check, "get_model_dimension", lambda *_: 768) + monkeypatch.setattr(health_check, "ensure_collections", lambda *_: 0) + + health_check.main() + + output = capsys.readouterr().out + assert "Expected vector name present" in output + assert "Skipping vector query for stale-empty" in output diff --git a/tests/test_hybrid_cache_bm25.py b/tests/test_hybrid_cache_bm25.py index 9dea510d..35ab245c 100644 --- 
a/tests/test_hybrid_cache_bm25.py +++ b/tests/test_hybrid_cache_bm25.py @@ -6,6 +6,23 @@ hyb = importlib.import_module("scripts.hybrid_search") +@pytest.fixture(autouse=True) +def fake_qdrant_models(monkeypatch): + hybrid_qdrant = importlib.import_module("scripts.hybrid.qdrant") + + class FakeModels: + class SearchParams: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class QuantizationSearchParams: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + monkeypatch.setattr(hyb, "models", FakeModels) + monkeypatch.setattr(hybrid_qdrant, "models", FakeModels) + + class _Pt: def __init__(self, pid, path, code=""): self.id = pid @@ -39,6 +56,28 @@ def search(self, **kwargs): self.calls += 1 return self._points + def get_collection(self, collection): + return type( + "CollectionInfo", + (), + { + "config": type( + "Config", + (), + { + "params": type( + "Params", + (), + { + "vectors": {"unit-test": type("Vector", (), {"size": 8})()}, + "sparse_vectors": {}, + }, + )() + }, + )() + }, + )() + class _FakeEmbed: class _Vec: @@ -131,4 +170,3 @@ def test_lexical_bm25_boost_is_gentle_and_matches_multiplier(): # Gentle behavior: overall change should be modest (within 50%) ratio = weighted / base assert 0.5 <= ratio <= 1.5, f"BM25 weighting should be gentle, got ratio={ratio:.3f}" - diff --git a/tests/test_hybrid_cli_json.py b/tests/test_hybrid_cli_json.py index d13e4d9d..1b62779d 100644 --- a/tests/test_hybrid_cli_json.py +++ b/tests/test_hybrid_cli_json.py @@ -1,12 +1,38 @@ import json import sys +import types from types import SimpleNamespace import importlib def test_hybrid_cli_json_output(monkeypatch, capsys): + class DummyClient: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + class DummyModels(types.ModuleType): + def __getattr__(self, name): + def _factory(*args, **kwargs): + return SimpleNamespace(_model=name, args=args, **kwargs) + + return _factory + + fake_models = 
DummyModels("qdrant_client.models") + fake_qdrant = types.ModuleType("qdrant_client") + fake_qdrant.QdrantClient = DummyClient + fake_qdrant.models = fake_models + monkeypatch.setitem(sys.modules, "qdrant_client", fake_qdrant) + monkeypatch.setitem(sys.modules, "qdrant_client.models", fake_models) + + monkeypatch.setenv("HYBRID_LEXICAL_WEIGHT", "0.20") + monkeypatch.setenv("HYBRID_LEX_VECTOR_WEIGHT", "0.20") + monkeypatch.setenv("HYBRID_DENSE_WEIGHT", "1.5") + importlib.reload(importlib.import_module("scripts.hybrid.config")) + importlib.reload(importlib.import_module("scripts.hybrid.ranking")) hy = importlib.import_module("scripts.hybrid_search") + hy = importlib.reload(hy) embedder = importlib.import_module("scripts.embedder") class DummyVec: @@ -24,11 +50,6 @@ def embed(self, texts): for _ in texts: yield DummyVec() - class DummyClient: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - def fake_dense_query(client, vec_name, vector, flt, per_query, collection_name=None, query_text=None): md = { "path": "/work/pkg/a.py", diff --git a/tests/test_hybrid_ranking.py b/tests/test_hybrid_ranking.py index 5edcc949..71055cd5 100644 --- a/tests/test_hybrid_ranking.py +++ b/tests/test_hybrid_ranking.py @@ -18,10 +18,13 @@ # Fixture: Import ranking module # ============================================================================ @pytest.fixture -def ranking_module(): +def ranking_module(monkeypatch): """Import hybrid ranking module.""" import importlib + monkeypatch.setenv("HYBRID_LEXICAL_WEIGHT", "0.20") + monkeypatch.setenv("HYBRID_LEX_VECTOR_WEIGHT", "0.20") ranking = importlib.import_module("scripts.hybrid.ranking") + ranking = importlib.reload(ranking) return ranking diff --git a/tests/test_index_journal.py b/tests/test_index_journal.py new file mode 100644 index 00000000..b436ed94 --- /dev/null +++ b/tests/test_index_journal.py @@ -0,0 +1,505 @@ +#!/usr/bin/env python3 +import importlib +from pathlib import Path +from unittest.mock 
import MagicMock + +import pytest + + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def ws_module(monkeypatch, tmp_path): + ws_root = tmp_path / "work" + ws_root.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + monkeypatch.delenv("MULTI_REPO_MODE", raising=False) + ws = importlib.import_module("scripts.workspace_state") + return importlib.reload(ws) + + +def test_index_journal_roundtrip(ws_module, tmp_path): + repo_name = "repo-1234567890abcdef" + file_path = tmp_path / "work" / repo_name / "src" / "app.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + + ws_module.upsert_index_journal_entries( + [ + {"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}, + {"path": str(file_path.with_name("old.py")), "op_type": "delete"}, + ], + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + + pending = [ + str(e["path"]) + for e in ws_module.list_pending_index_journal_entries( + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + ] + assert str(file_path.resolve()) in pending + assert str((file_path.with_name("old.py")).resolve()) in pending + + ws_module.update_index_journal_entry_status( + str(file_path), + status="done", + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + pending_after = [ + str(e["path"]) + for e in ws_module.list_pending_index_journal_entries( + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + ] + assert str(file_path.resolve()) not in pending_after + assert str((file_path.with_name("old.py")).resolve()) in pending_after + + +def test_index_journal_entries_include_operation_types(ws_module, tmp_path): + repo_name = "repo-1234567890abcdef" + file_path = tmp_path / "work" / repo_name / "src" / "entry.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + + ws_module.upsert_index_journal_entries( + [ + {"path": 
str(file_path), "op_type": "upsert", "content_hash": "abc123"}, + {"path": str(file_path.with_name("gone.py")), "op_type": "delete"}, + ], + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + + entries = ws_module.list_pending_index_journal_entries( + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + by_path = {entry["path"]: entry for entry in entries} + assert by_path[str(file_path.resolve())]["op_type"] == "upsert" + assert by_path[str((file_path.with_name("gone.py")).resolve())]["op_type"] == "delete" + + +def test_index_journal_clear_entries(ws_module, tmp_path): + repo_name = "repo-1234567890abcdef" + file_path = tmp_path / "work" / repo_name / "src" / "entry.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + + ws_module.upsert_index_journal_entries( + [ + {"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}, + {"path": str(file_path.with_name("gone.py")), "op_type": "delete"}, + ], + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + + removed = ws_module.clear_index_journal_entries( + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + + assert removed == 2 + assert ( + ws_module.list_pending_index_journal_entries( + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + == [] + ) + + +def test_index_journal_aggregates_repo_scoped_entries(ws_module, tmp_path): + repo_name = "repo-1234567890abcdef" + file_path = tmp_path / "work" / repo_name / "src" / "x.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + + ws_module.upsert_index_journal_entries( + [{"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}], + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + + pending = [ + str(e["path"]) + for e in ws_module.list_pending_index_journal_entries( + workspace_path=str(tmp_path / "work") + ) + ] + assert str(file_path.resolve()) in pending + + 
+@pytest.mark.parametrize("repo_name", ["repo-1234567890abcdef", "frontend"]) +def test_index_journal_aggregates_repo_scoped_entries_in_multi_repo_mode( + monkeypatch, tmp_path, repo_name +): + ws_root = tmp_path / "work" + ws_root.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + monkeypatch.setenv("MULTI_REPO_MODE", "1") + ws_module = importlib.import_module("scripts.workspace_state") + ws_module = importlib.reload(ws_module) + + file_name = "app.ts" if repo_name == "frontend" else "multi.py" + file_path = ws_root / repo_name / "src" / file_name + file_path.parent.mkdir(parents=True, exist_ok=True) + + ws_module.upsert_index_journal_entries( + [{"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}], + workspace_path=str(ws_root / repo_name), + repo_name=repo_name, + ) + + pending = [ + str(e["path"]) + for e in ws_module.list_pending_index_journal_entries(workspace_path=str(ws_root)) + ] + assert str(file_path.resolve()) in pending + + +def test_index_journal_aggregates_split_watch_and_metadata_roots(monkeypatch, tmp_path): + watch_root = tmp_path / "work" + metadata_root = tmp_path / "metadata" + repo_name = "Context-Engine-41e67959950c8ab3" + file_path = watch_root / repo_name / "src" / "split.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + metadata_root.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("WATCH_ROOT", str(watch_root)) + monkeypatch.setenv("WORK_DIR", str(watch_root)) + monkeypatch.setenv("CTXCE_METADATA_ROOT", str(metadata_root)) + monkeypatch.setenv("MULTI_REPO_MODE", "1") + ws_module = importlib.import_module("scripts.workspace_state") + ws_module = importlib.reload(ws_module) + + ws_module.upsert_index_journal_entries( + [{"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}], + workspace_path=str(file_path.parent.parent), + repo_name=repo_name, + ) + + pending = [ + str(e["path"]) + for e in 
ws_module.list_pending_index_journal_entries(workspace_path=str(watch_root)) + ] + assert str(file_path.resolve()) in pending + + ws_module.update_index_journal_entry_status( + str(file_path), + status="done", + workspace_path=str(file_path.parent.parent), + repo_name=repo_name, + ) + pending_after = [ + str(e["path"]) + for e in ws_module.list_pending_index_journal_entries(workspace_path=str(watch_root)) + ] + assert str(file_path.resolve()) not in pending_after + + +def test_index_journal_file_is_group_writable(ws_module, tmp_path): + repo_name = "repo-1234567890abcdef" + file_path = tmp_path / "work" / repo_name / "src" / "perm.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + + ws_module.upsert_index_journal_entries( + [{"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}], + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + + journal_path = ws_module._get_index_journal_path( + str(tmp_path / "work" / repo_name), repo_name + ) + assert journal_path.exists() + assert oct(journal_path.stat().st_mode & 0o777) == "0o666" + + +def test_index_journal_failed_entry_respects_retry_delay(ws_module, monkeypatch, tmp_path): + repo_name = "repo-1234567890abcdef" + file_path = tmp_path / "work" / repo_name / "src" / "retry.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("INDEX_JOURNAL_RETRY_DELAY_SECS", "60") + + ws_module.upsert_index_journal_entries( + [{"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}], + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + ws_module.update_index_journal_entry_status( + str(file_path), + status="failed", + error="boom", + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + remove_on_done=False, + ) + + pending = [ + str(e["path"]) + for e in ws_module.list_pending_index_journal_entries( + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + ] + assert 
str(file_path.resolve()) not in pending + + +def test_index_journal_failed_entry_honors_max_attempts(ws_module, monkeypatch, tmp_path): + repo_name = "repo-1234567890abcdef" + file_path = tmp_path / "work" / repo_name / "src" / "retry2.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("INDEX_JOURNAL_RETRY_DELAY_SECS", "0") + monkeypatch.setenv("INDEX_JOURNAL_MAX_ATTEMPTS", "1") + + ws_module.upsert_index_journal_entries( + [{"path": str(file_path), "op_type": "upsert", "content_hash": "abc123"}], + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + ws_module.update_index_journal_entry_status( + str(file_path), + status="failed", + error="boom", + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + remove_on_done=False, + ) + + pending = [ + str(e["path"]) + for e in ws_module.list_pending_index_journal_entries( + workspace_path=str(tmp_path / "work" / repo_name), + repo_name=repo_name, + ) + ] + assert str(file_path.resolve()) not in pending + + +def test_processor_delete_marks_journal_done(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + missing = tmp_path / "missing.py" + assert not missing.exists() + + monkeypatch.setattr(proc_mod, "_detect_repo_for_file", lambda p: tmp_path) + monkeypatch.setattr(proc_mod, "_get_collection_for_file", lambda p: "coll") + monkeypatch.setattr(proc_mod, "_set_status_indexing", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "get_workspace_state", lambda *a, **k: {}) + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: False) + monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_extract_repo_name_from_path", lambda *_: "repo") + monkeypatch.setattr(proc_mod, 
"remove_cached_file", lambda *a, **k: None) + + delete_mock = MagicMock() + graph_delete_mock = MagicMock() + journal_mock = MagicMock() + monkeypatch.setattr(proc_mod.idx, "delete_points_by_path", delete_mock) + monkeypatch.setattr(proc_mod.idx, "delete_graph_edges_by_path", graph_delete_mock) + monkeypatch.setattr(proc_mod, "_verify_delete_committed", lambda *a, **k: True) + monkeypatch.setattr(proc_mod, "_verify_graph_delete_committed", lambda *a, **k: True) + monkeypatch.setattr(proc_mod, "update_index_journal_entry_status", journal_mock) + + proc_mod._process_paths( + [missing], + client=MagicMock(), + model=None, + vector_name="vec", + model_dim=1, + workspace_path=str(tmp_path), + ) + + delete_mock.assert_called_once() + assert graph_delete_mock.call_count == 2 + assert graph_delete_mock.call_args_list[0].kwargs["repo"] == "repo" + assert graph_delete_mock.call_args_list[1].kwargs["repo"] is None + journal_mock.assert_called_once() + assert journal_mock.call_args.kwargs["status"] == "done" + + +def test_processor_honors_delete_journal_for_existing_file(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + existing = tmp_path / "present.py" + existing.write_text("print('x')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod, "_detect_repo_for_file", lambda p: tmp_path) + monkeypatch.setattr(proc_mod, "_get_collection_for_file", lambda p: "coll") + monkeypatch.setattr(proc_mod, "_set_status_indexing", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "get_workspace_state", lambda *a, **k: {}) + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: False) + monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_extract_repo_name_from_path", lambda *_: "repo") + monkeypatch.setattr(proc_mod, 
"remove_cached_file", lambda *a, **k: None) + monkeypatch.setattr( + proc_mod, + "list_pending_index_journal_entries", + lambda *a, **k: [{"path": str(existing.resolve()), "op_type": "delete"}], + ) + + delete_mock = MagicMock() + graph_delete_mock = MagicMock() + journal_mock = MagicMock() + monkeypatch.setattr(proc_mod.idx, "delete_points_by_path", delete_mock) + monkeypatch.setattr(proc_mod.idx, "delete_graph_edges_by_path", graph_delete_mock) + monkeypatch.setattr(proc_mod, "_verify_delete_committed", lambda *a, **k: True) + monkeypatch.setattr(proc_mod, "_verify_graph_delete_committed", lambda *a, **k: True) + monkeypatch.setattr(proc_mod, "update_index_journal_entry_status", journal_mock) + + proc_mod._process_paths( + [existing], + client=MagicMock(), + model=None, + vector_name="vec", + model_dim=1, + workspace_path=str(tmp_path), + ) + + delete_mock.assert_called_once() + assert graph_delete_mock.call_count == 2 + assert graph_delete_mock.call_args_list[0].kwargs["repo"] == "repo" + assert graph_delete_mock.call_args_list[1].kwargs["repo"] is None + journal_mock.assert_called_once() + assert journal_mock.call_args.kwargs["status"] == "done" + + +def test_processor_relinks_move_journal_before_delete(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + src = tmp_path / "src.py" + dest = tmp_path / "dest.py" + dest.write_text("print('dest')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod, "_detect_repo_for_file", lambda p: tmp_path) + monkeypatch.setattr(proc_mod, "_get_collection_for_file", lambda p: "coll") + monkeypatch.setattr(proc_mod, "_set_status_indexing", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "get_workspace_state", lambda *a, **k: {}) + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: False) + 
monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_extract_repo_name_from_path", lambda *_: "repo") + monkeypatch.setattr(proc_mod, "remove_cached_file", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "set_cached_file_hash", lambda *a, **k: None) + monkeypatch.setattr( + proc_mod, + "list_pending_index_journal_entries", + lambda *a, **k: [ + {"path": str(src.resolve()), "op_type": "delete", "content_hash": "cafebabe"}, + {"path": str(dest.resolve()), "op_type": "upsert", "content_hash": "cafebabe"}, + ], + ) + + rename_mock = MagicMock(return_value=(3, "cafebabe")) + delete_mock = MagicMock() + journal_mock = MagicMock() + monkeypatch.setattr(proc_mod, "_rename_in_store", rename_mock) + monkeypatch.setattr(proc_mod.idx, "delete_points_by_path", delete_mock) + monkeypatch.setattr(proc_mod, "update_index_journal_entry_status", journal_mock) + + proc_mod._process_paths( + [src, dest], + client=MagicMock(), + model=MagicMock(), + vector_name="vec", + model_dim=1, + workspace_path=str(tmp_path), + ) + + rename_mock.assert_called_once() + delete_mock.assert_not_called() + done_paths = [call.args[0] for call in journal_mock.call_args_list if call.kwargs.get("status") == "done"] + assert str(dest.resolve()) in done_paths + assert str(src.resolve()) in done_paths + + +def test_processor_skips_internal_git_path_without_collection_resolution(monkeypatch): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + internal = Path("/work/.git/HEAD") + + monkeypatch.setattr(proc_mod, "_set_status_indexing", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "get_workspace_state", lambda *a, **k: {}) + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: False) + monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + 
monkeypatch.setattr(proc_mod, "_extract_repo_name_from_path", lambda *_: "repo") + + collection_mock = MagicMock(return_value="should-not-be-used") + journal_mock = MagicMock() + monkeypatch.setattr(proc_mod, "_get_collection_for_file", collection_mock) + monkeypatch.setattr(proc_mod, "update_index_journal_entry_status", journal_mock) + monkeypatch.setattr( + proc_mod, + "list_pending_index_journal_entries", + lambda *a, **k: [{"path": str(internal), "op_type": "delete"}], + ) + + proc_mod._process_paths( + [internal], + client=MagicMock(), + model=None, + vector_name="vec", + model_dim=1, + workspace_path="/work", + ) + + collection_mock.assert_not_called() + journal_mock.assert_called_once() + assert journal_mock.call_args.kwargs["status"] == "done" + + +def test_processor_force_upsert_empty_file_marks_done(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + empty_file = tmp_path / "pkg" / "__init__.py" + empty_file.parent.mkdir(parents=True, exist_ok=True) + empty_file.write_text("", encoding="utf-8") + + monkeypatch.setattr(proc_mod, "_detect_repo_for_file", lambda p: tmp_path) + monkeypatch.setattr(proc_mod, "_get_collection_for_file", lambda p: "coll") + monkeypatch.setattr(proc_mod, "_set_status_indexing", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "get_workspace_state", lambda *a, **k: {}) + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: False) + monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_extract_repo_name_from_path", lambda *_: "repo") + monkeypatch.setattr(proc_mod, "remove_cached_file", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_run_indexing_strategy", lambda *a, **k: False) + monkeypatch.setattr(proc_mod, "_path_has_indexed_points", lambda *a, **k: 
False) + + journal_mock = MagicMock() + monkeypatch.setattr(proc_mod, "update_index_journal_entry_status", journal_mock) + monkeypatch.setattr( + proc_mod, + "list_pending_index_journal_entries", + lambda *a, **k: [ + { + "path": str(empty_file.resolve()), + "op_type": "upsert", + "content_hash": "da39a3ee5e6b4b0d3255bfef95601890afd80709", + } + ], + ) + + proc_mod._process_paths( + [empty_file], + client=MagicMock(), + model=MagicMock(), + vector_name="vec", + model_dim=1, + workspace_path=str(tmp_path), + ) + + journal_mock.assert_called_once() + assert journal_mock.call_args.kwargs["status"] == "done" diff --git a/tests/test_ingest_cli.py b/tests/test_ingest_cli.py new file mode 100644 index 00000000..493d1a68 --- /dev/null +++ b/tests/test_ingest_cli.py @@ -0,0 +1,57 @@ +import sys +from pathlib import Path + +import pytest + + +@pytest.mark.unit +def test_cli_force_collection_disables_multi_repo_enumeration(monkeypatch, tmp_path: Path): + from scripts.ingest import cli + + # Create fake repo dirs to prove we are not enumerating them. 
+ (tmp_path / "repo_a").mkdir() + (tmp_path / "repo_b").mkdir() + + calls = [] + + def _fake_index_repo( + root, + qdrant_url, + api_key, + collection, + model_name, + recreate, + dedupe, + skip_unchanged, + pseudo_mode, + schema_mode, + ): + calls.append( + { + "root": Path(root), + "collection": collection, + "recreate": recreate, + "dedupe": dedupe, + "skip_unchanged": skip_unchanged, + } + ) + + monkeypatch.setattr(cli, "index_repo", _fake_index_repo) + monkeypatch.setattr(cli, "is_multi_repo_mode", lambda: True) + monkeypatch.setattr(cli, "get_collection_name", lambda *_: "should-not-use") + + monkeypatch.setenv("MULTI_REPO_MODE", "1") + monkeypatch.setenv("COLLECTION_NAME", "forced-collection") + monkeypatch.setenv("CTXCE_FORCE_COLLECTION_NAME", "1") + + monkeypatch.setattr( + sys, + "argv", + ["ingest_code.py", "--root", str(tmp_path)], + ) + + cli.main() + + assert len(calls) == 1 + assert calls[0]["root"] == tmp_path + assert calls[0]["collection"] == "forced-collection" diff --git a/tests/test_ingest_schema_mode.py b/tests/test_ingest_schema_mode.py index c766089b..461212e3 100644 --- a/tests/test_ingest_schema_mode.py +++ b/tests/test_ingest_schema_mode.py @@ -91,6 +91,7 @@ def test_schema_mode_validate_errors_on_missing_vectors(monkeypatch): def test_schema_mode_migrate_adds_missing_vectors_and_indexes(monkeypatch): monkeypatch.setenv("PATTERN_VECTORS", "1") monkeypatch.setattr(ingq, "LEX_SPARSE_MODE", False) + ingq.ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard("test-collection") existing_vectors = { "code": object(), @@ -122,6 +123,7 @@ def test_schema_mode_migrate_adds_missing_vectors_and_indexes(monkeypatch): def test_schema_mode_create_creates_collection_only(monkeypatch): monkeypatch.setenv("PATTERN_VECTORS", "0") monkeypatch.setattr(ingq, "LEX_SPARSE_MODE", False) + ingq.ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard("test-collection") client = FakeClient(collection_exists=False) @@ -138,3 +140,15 @@ def 
test_schema_mode_create_creates_collection_only(monkeypatch): assert any( c["field_name"] == "metadata.language" for c in client.payload_index_calls ) + + +def test_ensure_payload_indexes_memoized_per_process(): + client = FakeClient(collection_exists=True) + ingq.ENSURED_PAYLOAD_INDEX_COLLECTIONS.discard("test-collection") + + ingq.ensure_payload_indexes(client, "test-collection") + first_count = len(client.payload_index_calls) + ingq.ensure_payload_indexes(client, "test-collection") + + assert first_count == len(ingq.PAYLOAD_INDEX_FIELDS) + assert len(client.payload_index_calls) == first_count diff --git a/tests/test_integration_qdrant.py b/tests/test_integration_qdrant.py index 65cef7f5..f40442ce 100644 --- a/tests/test_integration_qdrant.py +++ b/tests/test_integration_qdrant.py @@ -1,6 +1,7 @@ import os import json import uuid +import asyncio import importlib import pytest @@ -41,11 +42,11 @@ def embed(self, texts): @pytest.mark.integration def test_index_and_search_minirepo(tmp_path, monkeypatch, qdrant_container): # Env for services - os.environ["QDRANT_URL"] = qdrant_container - os.environ["COLLECTION_NAME"] = f"test-{uuid.uuid4().hex[:8]}" - os.environ["USE_TREE_SITTER"] = "0" - os.environ["HYBRID_IN_PROCESS"] = "1" - os.environ["EMBEDDING_MODEL"] = "fake" + monkeypatch.setenv("QDRANT_URL", qdrant_container) + monkeypatch.setenv("COLLECTION_NAME", f"test-{uuid.uuid4().hex[:8]}") + monkeypatch.setenv("USE_TREE_SITTER", "0") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setenv("EMBEDDING_MODEL", "fake") # Stub embeddings everywhere (FakeEmbedder produces 32-dim vectors) monkeypatch.setattr(ing, "TextEmbedding", lambda *a, **k: FakeEmbedder("fake")) @@ -75,7 +76,7 @@ def test_index_and_search_minirepo(tmp_path, monkeypatch, qdrant_container): ) # Search directly via async function - res = srv.asyncio.get_event_loop().run_until_complete( + res = asyncio.run( srv.repo_search( queries=["def f"], limit=5, @@ -92,11 +93,11 @@ def 
test_index_and_search_minirepo(tmp_path, monkeypatch, qdrant_container): @pytest.mark.integration def test_filters_language_and_path(tmp_path, monkeypatch, qdrant_container): # Reuse container; set env - os.environ["QDRANT_URL"] = qdrant_container - os.environ.setdefault("COLLECTION_NAME", f"test-{uuid.uuid4().hex[:8]}") - os.environ["USE_TREE_SITTER"] = "0" - os.environ["HYBRID_IN_PROCESS"] = "1" - os.environ["EMBEDDING_MODEL"] = "fake" + monkeypatch.setenv("QDRANT_URL", qdrant_container) + monkeypatch.setenv("COLLECTION_NAME", f"test-{uuid.uuid4().hex[:8]}") + monkeypatch.setenv("USE_TREE_SITTER", "0") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setenv("EMBEDDING_MODEL", "fake") # Stub embeddings (FakeEmbedder produces 32-dim vectors) monkeypatch.setattr(ing, "TextEmbedding", lambda *a, **k: FakeEmbedder("fake")) @@ -127,19 +128,19 @@ def test_filters_language_and_path(tmp_path, monkeypatch, qdrant_container): f_md = str(tmp_path / "pkg" / "b.md") # Filter by language=python should bias toward .py - res1 = srv.asyncio.get_event_loop().run_until_complete( + res1 = asyncio.run( srv.repo_search(queries=["def"], limit=5, language="python", compact=False) ) assert any(f_py in (r.get("path") or "") for r in res1.get("results", [])) # Filter by ext=txt should retrieve text file - res2 = srv.asyncio.get_event_loop().run_until_complete( + res2 = asyncio.run( srv.repo_search(queries=["hello"], limit=5, ext="md", compact=False) ) assert any(f_md in (r.get("path") or "") for r in res2.get("results", [])) # Path glob to only allow pkg/*.py - res3 = srv.asyncio.get_event_loop().run_until_complete( + res3 = asyncio.run( srv.repo_search( queries=["def"], limit=5, diff --git a/tests/test_mcp_router.py b/tests/test_mcp_router.py deleted file mode 100644 index 776e0401..00000000 --- a/tests/test_mcp_router.py +++ /dev/null @@ -1,233 +0,0 @@ -#!/usr/bin/env python3 -""" -Tests for mcp_router.py - Intent classification and tool routing. 
- -Tests cover: -- Intent classification (rule-based and ML fallback) -- Plan building for various query types -- HTTP client helpers -""" -import importlib -import os - -import pytest - -pytestmark = pytest.mark.unit - - -# ============================================================================ -# Fixture: Router module import -# ============================================================================ -@pytest.fixture -def router_module(monkeypatch): - """Import mcp_router with isolated environment.""" - monkeypatch.delenv("MCP_HTTP_URL", raising=False) - monkeypatch.delenv("MCP_INDEXER_HTTP_URL", raising=False) - - router = importlib.import_module("scripts.mcp_router") - return importlib.reload(router) - - -# ============================================================================ -# Tests: Intent Constants -# ============================================================================ -class TestIntentConstants: - """Tests for intent constant definitions.""" - - def test_intent_constants_defined(self, router_module): - """All expected intent constants are defined.""" - assert router_module.INTENT_ANSWER == "answer" - assert router_module.INTENT_SEARCH == "search" - assert router_module.INTENT_INDEX == "index" - assert router_module.INTENT_PRUNE == "prune" - assert router_module.INTENT_STATUS == "status" - assert router_module.INTENT_LIST == "list" - - def test_search_specialized_intents(self, router_module): - """Specialized search intents are defined.""" - assert router_module.INTENT_SEARCH_TESTS == "search_tests" - assert router_module.INTENT_SEARCH_CONFIG == "search_config" - assert router_module.INTENT_SEARCH_CALLERS == "search_callers" - - -# ============================================================================ -# Tests: Intent Classification (Rule-based) -# ============================================================================ -class TestClassifyIntentRules: - """Tests for _classify_intent_rules function.""" - - def 
test_status_intent_patterns(self, router_module): - """Status-related queries are classified correctly.""" - status_queries = [ - "qdrant status", - "indexing status", - "collection status", - ] - for q in status_queries: - intent = router_module._classify_intent_rules(q) - assert intent == router_module.INTENT_STATUS, f"Failed for: {q}" - - def test_list_intent_patterns(self, router_module): - """List-related queries are classified correctly.""" - list_queries = [ - "list collections", - "show all collections", - ] - for q in list_queries: - intent = router_module._classify_intent_rules(q) - assert intent == router_module.INTENT_LIST, f"Failed for: {q}" - - def test_search_tests_intent(self, router_module): - """Test search queries are classified correctly.""" - queries = [ - "find tests for foo", - "search for test files", - ] - for q in queries: - intent = router_module._classify_intent_rules(q) - assert intent == router_module.INTENT_SEARCH_TESTS, f"Failed for: {q}" - - def test_search_config_intent(self, router_module): - """Config search queries are classified correctly.""" - queries = [ - "find config for database", - "where is the yaml config", - ] - for q in queries: - intent = router_module._classify_intent_rules(q) - assert intent == router_module.INTENT_SEARCH_CONFIG, f"Failed for: {q}" - - -# ============================================================================ -# Tests: High-level classify_intent -# ============================================================================ -class TestClassifyIntent: - """Tests for the main classify_intent function.""" - - def test_classify_intent_returns_intent(self, router_module): - """classify_intent returns a valid intent string.""" - intent = router_module.classify_intent("reindex the codebase") - # Should return some intent (index or answer depending on ML) - assert intent is not None - assert isinstance(intent, str) - - def test_classify_intent_status(self, router_module): - """Status queries 
classified correctly.""" - intent = router_module.classify_intent("qdrant status") - assert intent == router_module.INTENT_STATUS - - def test_classify_intent_list(self, router_module): - """List queries classified correctly.""" - intent = router_module.classify_intent("list collections") - assert intent == router_module.INTENT_LIST - - -# ============================================================================ -# Tests: Build Plan -# ============================================================================ -class TestBuildPlan: - """Tests for build_plan function.""" - - def test_build_plan_returns_list(self, router_module): - """build_plan returns a list of (tool, args) tuples.""" - # Use a query that triggers rule-based classification (avoids embedding model) - plan = router_module.build_plan("list collections") - - assert isinstance(plan, list) - assert len(plan) >= 1 - # Each item is a tuple of (tool_name, args_dict) - tool_name, args = plan[0] - assert isinstance(tool_name, str) - assert isinstance(args, dict) - - def test_build_plan_status_tool(self, router_module): - """Status queries map to qdrant_status tool.""" - plan = router_module.build_plan("qdrant status") - - tool_name, args = plan[0] - assert tool_name == "qdrant_status" - - def test_build_plan_list_tool(self, router_module): - """List queries map to qdrant_list tool.""" - plan = router_module.build_plan("list collections") - - tool_name, args = plan[0] - assert tool_name == "qdrant_list" - - def test_build_plan_search_tests_tool(self, router_module): - """Test search queries map to search_tests_for tool.""" - plan = router_module.build_plan("find tests for authentication") - - tool_name, args = plan[0] - assert tool_name == "search_tests_for" - assert "query" in args - - def test_build_plan_search_config_tool(self, router_module): - """Config search queries map to search_config_for tool.""" - plan = router_module.build_plan("find config for database") - - tool_name, args = plan[0] - assert 
tool_name == "search_config_for" - - def test_build_plan_includes_query(self, router_module): - """build_plan includes the query in args for search tools.""" - # Use a query that triggers rule-based classification (avoids embedding model) - plan = router_module.build_plan("find tests for authentication") - - tool_name, args = plan[0] - # search_tests_for includes query in args - assert "query" in args or tool_name in {"qdrant_status", "qdrant_list", "qdrant_prune"} - - -# ============================================================================ -# Tests: HTTP Helpers -# ============================================================================ -class TestHttpHelpers: - """Tests for HTTP client helper functions.""" - - def test_filter_args_removes_none(self, router_module): - """_filter_args removes None values from dict.""" - args = {"a": 1, "b": None, "c": "hello", "d": None} - filtered = router_module._filter_args(args) - - assert filtered == {"a": 1, "c": "hello"} - - def test_filter_args_preserves_false(self, router_module): - """_filter_args preserves False and 0 values.""" - args = {"a": False, "b": 0, "c": None} - filtered = router_module._filter_args(args) - - assert "a" in filtered - assert "b" in filtered - assert "c" not in filtered - - def test_parse_stream_or_json_parses_json(self, router_module): - """_parse_stream_or_json parses valid JSON.""" - body = b'{"result": "success"}' - parsed = router_module._parse_stream_or_json(body) - - assert parsed == {"result": "success"} - - -# ============================================================================ -# Tests: Failure Response Detection -# ============================================================================ -class TestFailureResponseDetection: - """Tests for _is_failure_response function.""" - - def test_success_response_not_failure(self, router_module): - """Successful responses are not failures.""" - resp = {"result": "data", "ok": True} - assert 
router_module._is_failure_response(resp) is False - - def test_empty_response_not_failure(self, router_module): - """Empty dicts are not failures.""" - assert router_module._is_failure_response({}) is False - - def test_detects_isError_true(self, router_module): - """Detects responses with isError=True.""" - # Note: depends on actual implementation - resp = {"isError": True, "content": []} - result = router_module._is_failure_response(resp) - # May be True or False depending on implementation - assert isinstance(result, bool) diff --git a/tests/test_micro_span_budget.py b/tests/test_micro_span_budget.py index 5648de30..65ed82b8 100644 --- a/tests/test_micro_span_budget.py +++ b/tests/test_micro_span_budget.py @@ -98,7 +98,7 @@ def test_adaptive_span_sizing_failure_is_non_fatal(monkeypatch): monkeypatch.setenv("COLLECTION_NAME", "dummy") # Force extent lookup to throw; the budgeter should swallow it. - import scripts.hybrid_ranking as hr + from scripts.hybrid import ranking as hr monkeypatch.setattr(hr, "_get_symbol_extent", lambda *a, **k: (_ for _ in ()).throw(RuntimeError("boom"))) items = [ diff --git a/tests/test_negative_args.py b/tests/test_negative_args.py index 32dadd52..a0d819a1 100644 --- a/tests/test_negative_args.py +++ b/tests/test_negative_args.py @@ -1,4 +1,5 @@ import os +import asyncio import pytest import scripts.mcp_indexer_server as srv @@ -6,6 +7,8 @@ @pytest.mark.service def test_repo_search_conflicting_filters_empty_ok(monkeypatch): + # This test validates hybrid filter handling (non-dense path), not dense-default mode. + # Keep mode explicit so global REPO_SEARCH_DEFAULT_MODE=dense does not change test semantics. 
# In-process, but no results due to conflicting filters (simulate by returning []) monkeypatch.setenv("HYBRID_IN_PROCESS", "1") monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) @@ -14,8 +17,14 @@ def test_repo_search_conflicting_filters_empty_ok(monkeypatch): monkeypatch.setattr(hy, "run_hybrid_search", lambda *a, **k: []) - res = srv.asyncio.get_event_loop().run_until_complete( - srv.repo_search(queries=["foo"], limit=3, ext="cpp", compact=True) + res = asyncio.run( + srv.repo_search( + queries=["foo"], + limit=3, + ext="cpp", + compact=True, + mode="hybrid", + ) ) assert res.get("ok") is True diff --git a/tests/test_path_scope.py b/tests/test_path_scope.py new file mode 100644 index 00000000..06369860 --- /dev/null +++ b/tests/test_path_scope.py @@ -0,0 +1,61 @@ +import importlib + + +ps = importlib.import_module("scripts.path_scope") + + +def test_normalize_under_strips_work_prefix(): + assert ps.normalize_under("/work/scripts/mcp_impl") == "scripts/mcp_impl" + + +def test_normalize_under_keeps_repo_prefixed_path(): + assert ( + ps.normalize_under("/work/Context-Engine/scripts/mcp_impl") + == "Context-Engine/scripts/mcp_impl" + ) + + +def test_normalize_under_rebases_single_segment_from_cwd(monkeypatch, tmp_path): + repo = tmp_path / "repo" + (repo / "nested" / "scope").mkdir(parents=True) + monkeypatch.setattr(ps, "_repo_root_hint", lambda: str(repo)) + monkeypatch.setattr(ps.os, "getcwd", lambda: str(repo / "nested")) + + assert ps.normalize_under("scope") == "nested/scope" + + +def test_normalize_under_does_not_rebase_when_top_level_exists(monkeypatch, tmp_path): + repo = tmp_path / "repo" + (repo / "nested" / "scope").mkdir(parents=True) + (repo / "scope").mkdir(parents=True) + monkeypatch.setattr(ps, "_repo_root_hint", lambda: str(repo)) + monkeypatch.setattr(ps.os, "getcwd", lambda: str(repo / "nested")) + + assert ps.normalize_under("scope") == "scope" + + +def test_normalize_under_expands_unique_segment(monkeypatch, tmp_path): + 
repo = tmp_path / "repo" + (repo / "alpha" / "mcp_impl").mkdir(parents=True) + monkeypatch.setattr(ps, "_repo_root_hint", lambda: str(repo)) + monkeypatch.setattr(ps.os, "getcwd", lambda: str(repo)) + ps._unique_segment_path.cache_clear() + + assert ps.normalize_under("mcp_impl") == "alpha/mcp_impl" + + +def test_normalize_under_keeps_ambiguous_segment(monkeypatch, tmp_path): + repo = tmp_path / "repo" + (repo / "alpha" / "dup").mkdir(parents=True) + (repo / "beta" / "dup").mkdir(parents=True) + monkeypatch.setattr(ps, "_repo_root_hint", lambda: str(repo)) + monkeypatch.setattr(ps.os, "getcwd", lambda: str(repo)) + ps._unique_segment_path.cache_clear() + + assert ps.normalize_under("dup") == "dup" + + +def test_metadata_matches_under_without_repo_hint_for_work_repo_paths(): + md = {"path": "/work/repo/space/ship/a.py"} + assert ps.metadata_matches_under(md, "space") + assert not ps.metadata_matches_under(md, "direct") diff --git a/tests/test_pattern_search_e2e.py b/tests/test_pattern_search_e2e.py index 9385aba9..10c2a8cd 100644 --- a/tests/test_pattern_search_e2e.py +++ b/tests/test_pattern_search_e2e.py @@ -161,7 +161,7 @@ def pattern_collection(): pass -@pytest.mark.service +@pytest.mark.integration def test_pattern_search_qdrant(pattern_collection): """Test pattern search against real Qdrant.""" from scripts.pattern_detection.search import pattern_search diff --git a/tests/test_per_path_zero.py b/tests/test_per_path_zero.py index e1541ffd..2e6a164d 100644 --- a/tests/test_per_path_zero.py +++ b/tests/test_per_path_zero.py @@ -1,24 +1,53 @@ import asyncio +import json import pytest # These tests exercise argument plumbing independent of live retrieval. 
@pytest.mark.asyncio -async def test_per_path_zero_is_echoed_and_respected_in_args(): - from scripts.mcp_indexer_server import repo_search +async def test_per_path_zero_is_echoed_and_respected_in_args(monkeypatch): + from scripts.mcp_impl.search import _repo_search_impl - res = await repo_search(query="anything", limit=3, per_path=0) + async def _fake_run_async(_cmd, **_kwargs): + item = {"path": "src/a.py", "start_line": 1, "end_line": 1, "score": 1.0} + return {"ok": True, "code": 0, "stdout": json.dumps(item), "stderr": ""} + + monkeypatch.setenv("HYBRID_IN_PROCESS", "0") + + # Arg-plumbing test for the hybrid/subprocess (non-dense) path; mode is explicit by design. + res = await _repo_search_impl( + query="anything", + limit=3, + per_path=0, + mode="hybrid", + require_auth_session_fn=lambda session: session, + run_async_fn=_fake_run_async, + ) assert isinstance(res, dict) args = res.get("args") or {} assert args.get("per_path") == 0, f"expected per_path echoed as 0, got {args.get('per_path')}" @pytest.mark.asyncio -async def test_compact_string_false_is_normalized_in_args(): - from scripts.mcp_indexer_server import repo_search +async def test_compact_string_false_is_normalized_in_args(monkeypatch): + from scripts.mcp_impl.search import _repo_search_impl + + async def _fake_run_async(_cmd, **_kwargs): + item = {"path": "src/a.py", "start_line": 1, "end_line": 1, "score": 1.0} + return {"ok": True, "code": 0, "stdout": json.dumps(item), "stderr": ""} + + monkeypatch.setenv("HYBRID_IN_PROCESS", "0") - # Passing compact as a string "false" should normalize to False in echoed args - res = await repo_search(query="anything", limit=1, compact="false") + # Passing compact as a string "false" should normalize to False in echoed args. + # Keep mode explicit so dense-default env does not alter this contract test. 
+ res = await _repo_search_impl( + query="anything", + limit=1, + compact="false", + mode="hybrid", + require_auth_session_fn=lambda session: session, + run_async_fn=_fake_run_async, + ) assert isinstance(res, dict) args = res.get("args") or {} assert args.get("compact") is False, f"expected compact False, got {args.get('compact')}" diff --git a/tests/test_prune.py b/tests/test_prune.py new file mode 100644 index 00000000..42641883 --- /dev/null +++ b/tests/test_prune.py @@ -0,0 +1,54 @@ +from types import SimpleNamespace + +import scripts.prune as prune + + +class _FakeClient: + def __init__(self, points): + self._points = points + + def scroll(self, **kwargs): + return self._points, None + + +def _point(path, file_hash=None, repo="repo-a"): + return SimpleNamespace( + payload={ + "metadata": { + "path": path, + "file_hash": file_hash, + "repo": repo, + } + } + ) + + +def test_prune_excludes_deleted_paths_from_orphan_keepalive(monkeypatch, tmp_path): + keep_path = tmp_path / "keep.py" + keep_path.write_text("keep = True\n", encoding="utf-8") + + mismatch_path = tmp_path / "mismatch.py" + mismatch_path.write_text("new = True\n", encoding="utf-8") + + points = [ + _point("missing.py", file_hash="missing-hash"), + _point("mismatch.py", file_hash="old-hash"), + _point("keep.py", file_hash=prune.sha1_file(keep_path)), + ] + fake_client = _FakeClient(points) + captured_valid_paths = [] + + monkeypatch.setattr(prune, "QdrantClient", lambda **kwargs: fake_client) + monkeypatch.setattr(prune, "ROOT", tmp_path) + monkeypatch.setattr(prune, "delete_by_path", lambda *args, **kwargs: 1) + monkeypatch.setattr(prune, "delete_graph_edges_by_path", lambda *args, **kwargs: 0) + monkeypatch.setattr( + prune, + "delete_orphan_graph_edges", + lambda client, valid_paths: captured_valid_paths.append(set(valid_paths)) or 0, + ) + + prune.main() + + assert captured_valid_paths == [{"keep.py"}] + diff --git a/tests/test_qdrant_version_pins.py b/tests/test_qdrant_version_pins.py new file 
mode 100644 index 00000000..4b08d966 --- /dev/null +++ b/tests/test_qdrant_version_pins.py @@ -0,0 +1,39 @@ +from pathlib import Path +from importlib.metadata import version +import inspect + +from qdrant_client import QdrantClient + + +ROOT = Path(__file__).resolve().parents[1] +QDRANT_CLIENT_PIN = "qdrant-client==1.15.1" +QDRANT_SERVER_IMAGE = "qdrant/qdrant:v1.15.4" + + +def test_qdrant_client_is_exactly_pinned(): + requirements = (ROOT / "requirements.txt").read_text() + + assert QDRANT_CLIENT_PIN in requirements + assert "qdrant-client>=" not in requirements + + +def test_qdrant_server_images_are_exactly_pinned(): + files = [ + ROOT / ".github/workflows/ci.yml", + ROOT / "docker-compose.yml", + ROOT / "docker-compose-bindmount-checkout.yml", + ROOT / "deploy/kubernetes/qdrant.yaml", + ROOT / "tests/conftest.py", + ] + + for path in files: + text = path.read_text() + assert QDRANT_SERVER_IMAGE in text, str(path) + assert "qdrant/qdrant:latest" not in text, str(path) + + +def test_installed_qdrant_client_matches_supported_api(): + assert version("qdrant-client") == "1.15.1" + assert hasattr(QdrantClient, "search") + assert hasattr(QdrantClient, "query_points") + assert "query_filter" in inspect.signature(QdrantClient.query_points).parameters diff --git a/tests/test_repo_search_mode_contract.py b/tests/test_repo_search_mode_contract.py new file mode 100644 index 00000000..b7152696 --- /dev/null +++ b/tests/test_repo_search_mode_contract.py @@ -0,0 +1,186 @@ +import asyncio +import importlib +import sys +import types + +import pytest + +srv = importlib.import_module("scripts.mcp_indexer_server") + + +def _make_hybrid_module_stub(calls: dict): + mod = types.ModuleType("scripts.hybrid_search") + + def run_pure_dense_search(**kwargs): + calls["dense"] = int(calls.get("dense", 0)) + 1 + calls["dense_kwargs"] = dict(kwargs) + return [ + { + "score": 0.91, + "path": "/work/dense.py", + "symbol": "", + "start_line": 1, + "end_line": 3, + "payload": {}, + } + ] + + def 
run_hybrid_search(**kwargs): + calls["hybrid"] = int(calls.get("hybrid", 0)) + 1 + calls["hybrid_kwargs"] = dict(kwargs) + return [ + { + "score": 0.75, + "path": "/work/hybrid.py", + "symbol": "", + "start_line": 4, + "end_line": 7, + } + ] + + mod.run_pure_dense_search = run_pure_dense_search + mod.run_hybrid_search = run_hybrid_search + mod.lang_matches_path = lambda path, lang=None: True + mod._merge_and_budget_spans = lambda spans, *args, **kwargs: spans + mod.TextEmbedding = object + mod.QdrantClient = object + return mod + + +@pytest.mark.service +def test_repo_search_dense_default_from_env_is_explicit_and_stable(monkeypatch): + # Contract: global default mode should route repo_search to dense path when set to dense. + calls = {"dense": 0, "hybrid": 0} + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "dense") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", _make_hybrid_module_stub(calls)) + + res = asyncio.run(srv.repo_search(query="q", limit=1, compact=True, rerank_enabled=False)) + + assert res.get("ok") is True + assert calls["dense"] == 1 + assert calls["hybrid"] == 0 + assert res.get("results", [{}])[0].get("path") == "/work/dense.py" + + +@pytest.mark.service +def test_repo_search_explicit_hybrid_overrides_dense_default_for_non_dense_tests(monkeypatch): + # Contract: non-dense tests can force hybrid behavior even under dense global default. 
+ calls = {"dense": 0, "hybrid": 0} + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "dense") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", _make_hybrid_module_stub(calls)) + + res = asyncio.run( + srv.repo_search( + query="q", + mode="hybrid", + limit=1, + compact=True, + rerank_enabled=False, + ) + ) + + assert res.get("ok") is True + assert calls["dense"] == 0 + assert calls["hybrid"] == 1 + assert res.get("results", [{}])[0].get("path") == "/work/hybrid.py" + + +@pytest.mark.service +def test_repo_search_dense_default_forwards_structured_filters(monkeypatch): + calls = {"dense": 0, "hybrid": 0} + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "dense") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", _make_hybrid_module_stub(calls)) + + asyncio.run( + srv.repo_search( + query="q", + limit=1, + compact=True, + rerank_enabled=False, + kind="function", + symbol="my_symbol", + ext="py", + ) + ) + + assert calls["dense"] == 1 + dense_kwargs = calls.get("dense_kwargs") or {} + assert dense_kwargs.get("kind") == "function" + assert dense_kwargs.get("symbol") == "my_symbol" + assert dense_kwargs.get("ext") == "py" + + +@pytest.mark.service +def test_repo_search_dense_default_forwards_per_path(monkeypatch): + calls = {"dense": 0, "hybrid": 0} + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "dense") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", _make_hybrid_module_stub(calls)) + + asyncio.run( + srv.repo_search( + query="q", + limit=3, + per_path=1, + compact=True, + rerank_enabled=False, + ) + ) + + assert calls["dense"] == 1 + assert calls.get("dense_kwargs", 
{}).get("per_path") == 1 + + +@pytest.mark.service +def test_repo_search_profile_tests_adds_material_globs(monkeypatch): + calls = {"dense": 0, "hybrid": 0} + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "dense") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", _make_hybrid_module_stub(calls)) + + res = asyncio.run( + srv.repo_search( + query="q", + profile="tests", + limit=1, + compact=False, + rerank_enabled=False, + ) + ) + + args = res.get("args") or {} + assert args.get("profile") == "tests" + assert "tests/**" in args.get("path_glob", []) + assert "**/*_test.*" in args.get("path_glob", []) + + +@pytest.mark.service +def test_repo_search_profile_preserves_user_globs(monkeypatch): + calls = {"dense": 0, "hybrid": 0} + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "dense") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: object()) + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", _make_hybrid_module_stub(calls)) + + res = asyncio.run( + srv.repo_search( + query="q", + profile="config", + path_glob=["custom/**"], + limit=1, + compact=False, + rerank_enabled=False, + ) + ) + + globs = (res.get("args") or {}).get("path_glob", []) + assert globs[0] == "custom/**" + assert "**/*.yaml" in globs diff --git a/tests/test_rerank_recursive.py b/tests/test_rerank_recursive.py index 5f8c5822..8acc26e2 100644 --- a/tests/test_rerank_recursive.py +++ b/tests/test_rerank_recursive.py @@ -14,16 +14,21 @@ from typing import List, Dict, Any -# Import the reranker -from scripts.rerank_recursive import ( +from scripts.rerank_recursive.alpha_scheduler import CosineAlphaScheduler, LearnedAlphaWeights +from scripts.rerank_recursive.confidence import ConfidenceEstimator +from scripts.rerank_recursive.recursive import ( RecursiveReranker, - RefinementState, - TinyScorer, - 
LatentRefiner, - ConfidenceEstimator, rerank_recursive, rerank_recursive_inprocess, ) +from scripts.rerank_recursive.refiner import LatentRefiner +from scripts.rerank_recursive.scorer import TinyScorer +from scripts.rerank_recursive.state import RefinementState + + +@pytest.fixture(autouse=True) +def deterministic_recursive_embeddings(monkeypatch): + monkeypatch.setattr(RecursiveReranker, "_get_embedder", lambda self: None) class TestTinyScorer: @@ -246,8 +251,6 @@ class TestCosineAlphaScheduler: def test_schedule_length(self): """Schedule should match n_iterations.""" - from scripts.rerank_recursive import CosineAlphaScheduler - scheduler = CosineAlphaScheduler(n_iterations=5) schedule = scheduler.get_schedule() @@ -255,8 +258,6 @@ def test_schedule_length(self): def test_schedule_decreasing(self): """Alpha should decrease over iterations (cosine decay).""" - from scripts.rerank_recursive import CosineAlphaScheduler - scheduler = CosineAlphaScheduler(n_iterations=3, alpha_max=0.7, alpha_min=0.3) schedule = scheduler.get_schedule() @@ -266,8 +267,6 @@ def test_schedule_decreasing(self): def test_schedule_bounds(self): """All alpha values should be within [alpha_min, alpha_max].""" - from scripts.rerank_recursive import CosineAlphaScheduler - scheduler = CosineAlphaScheduler(n_iterations=10, alpha_max=0.8, alpha_min=0.2) schedule = scheduler.get_schedule() @@ -276,8 +275,6 @@ def test_schedule_bounds(self): def test_single_iteration(self): """Single iteration should return middle value.""" - from scripts.rerank_recursive import CosineAlphaScheduler - scheduler = CosineAlphaScheduler(n_iterations=1, alpha_max=0.8, alpha_min=0.2) schedule = scheduler.get_schedule() @@ -290,8 +287,6 @@ class TestLearnedAlphaWeights: def test_init_alpha(self): """Initial alpha should match init_alpha parameter.""" - from scripts.rerank_recursive import LearnedAlphaWeights - learned = LearnedAlphaWeights(n_iterations=3, init_alpha=0.6) schedule = learned.get_schedule() @@ -300,8 +295,6 
@@ def test_init_alpha(self): def test_get_alpha_clamped(self): """get_alpha should clamp to valid iteration range.""" - from scripts.rerank_recursive import LearnedAlphaWeights - learned = LearnedAlphaWeights(n_iterations=3) # Should not crash for out-of-range iterations @@ -313,8 +306,6 @@ def test_get_alpha_clamped(self): def test_alpha_in_valid_range(self): """All alpha values should be in (0, 1) due to sigmoid.""" - from scripts.rerank_recursive import LearnedAlphaWeights - learned = LearnedAlphaWeights(n_iterations=5, init_alpha=0.5) schedule = learned.get_schedule() @@ -341,8 +332,6 @@ def test_alpha_trajectory_in_output(self): def test_custom_scheduler(self): """Should accept custom alpha scheduler.""" - from scripts.rerank_recursive import LearnedAlphaWeights - custom_scheduler = LearnedAlphaWeights(n_iterations=2, init_alpha=0.4) reranker = RecursiveReranker(n_iterations=2, dim=64, alpha_scheduler=custom_scheduler) diff --git a/tests/test_rerank_under_scope.py b/tests/test_rerank_under_scope.py new file mode 100644 index 00000000..080da9cd --- /dev/null +++ b/tests/test_rerank_under_scope.py @@ -0,0 +1,66 @@ +import importlib + + +rr = importlib.import_module("scripts.rerank_tools.local") + + +class _Pt: + def __init__(self, pid: str, path: str): + self.id = pid + self.payload = { + "metadata": { + "path": path, + "start_line": 1, + "end_line": 2, + "symbol": "f", + } + } + + +class _FakeModel: + def embed(self, texts): + for _ in texts: + yield [0.01] * 8 + + +def test_rerank_in_process_under_excludes_out_of_scope(monkeypatch): + monkeypatch.setattr(rr, "QdrantClient", lambda *a, **k: object()) + monkeypatch.setattr(rr, "_select_dense_vector_name", lambda *a, **k: "vec") + monkeypatch.setattr( + rr, + "dense_results", + lambda *a, **k: [_Pt("1", "/work/repo/direct/tools/b.py")], + ) + monkeypatch.setattr(rr, "rerank_local", lambda pairs: [0.9] * len(pairs)) + + out = rr.rerank_in_process( + query="rotate heading", + topk=10, + limit=5, + under="space", + 
model=_FakeModel(), + collection="codebase", + ) + assert out == [] + + +def test_rerank_in_process_under_keeps_in_scope(monkeypatch): + monkeypatch.setattr(rr, "QdrantClient", lambda *a, **k: object()) + monkeypatch.setattr(rr, "_select_dense_vector_name", lambda *a, **k: "vec") + monkeypatch.setattr( + rr, + "dense_results", + lambda *a, **k: [_Pt("1", "/work/repo/space/ship/a.py")], + ) + monkeypatch.setattr(rr, "rerank_local", lambda pairs: [0.9] * len(pairs)) + + out = rr.rerank_in_process( + query="rotate heading", + topk=10, + limit=5, + under="space", + model=_FakeModel(), + collection="codebase", + ) + assert len(out) == 1 + assert out[0]["path"] == "/work/repo/space/ship/a.py" diff --git a/tests/test_reranker_verification.py b/tests/test_reranker_verification.py index e7a05245..86835dc9 100644 --- a/tests/test_reranker_verification.py +++ b/tests/test_reranker_verification.py @@ -16,6 +16,10 @@ def tool(self, *args, **kwargs): def _decorator(fn): return fn return _decorator + def resource(self, *args, **kwargs): + def _decorator(fn): + return fn + return _decorator class _Context: def __init__(self, *args, **kwargs): @@ -50,6 +54,8 @@ async def test_rerank_inproc_changes_order(monkeypatch): # Force in-process hybrid + in-process rerank paths monkeypatch.setenv("HYBRID_IN_PROCESS", "1") monkeypatch.setenv("RERANK_IN_PROCESS", "1") + # Rerank verification suite explicitly exercises non-dense plumbing. 
+ monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") # Baseline hybrid results (JSON structured items); A before B def fake_run_hybrid_search(**kwargs): @@ -92,7 +98,7 @@ def fake_rerank_local(pairs): raising=False, ) monkeypatch.setattr( - importlib.import_module("scripts.rerank_local"), + importlib.import_module("scripts.rerank_tools.local"), "rerank_local", fake_rerank_local, ) @@ -102,7 +108,9 @@ def fake_rerank_local(pairs): assert [r["path"] for r in base["results"]] == ["/work/a.py", "/work/b.py"] # With rerank enabled, order should flip to B then A; counters should show inproc_hybrid - rr = await server.repo_search(query="q", limit=2, per_path=2, rerank_enabled=True, compact=True) + rr = await server.repo_search( + query="q", limit=2, per_path=2, rerank_enabled=True, compact=True, debug=True + ) assert rr.get("used_rerank") is True assert rr.get("rerank_counters", {}).get("inproc_hybrid", 0) >= 1 assert [r["path"] for r in rr["results"]] == ["/work/b.py", "/work/a.py"] @@ -114,6 +122,8 @@ async def test_rerank_inproc_dense_respects_collection_argument(monkeypatch): # Drive the in-process dense rerank fallback path by returning no hybrid candidates. monkeypatch.setenv("HYBRID_IN_PROCESS", "1") monkeypatch.setenv("RERANK_IN_PROCESS", "1") + # Explicit non-dense mode for rerank-path contract checks. 
+ monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") def fake_run_hybrid_search(**kwargs): return [] @@ -130,7 +140,7 @@ def fake_rerank_in_process(**kwargs): return [] monkeypatch.setattr( - importlib.import_module("scripts.rerank_local"), + importlib.import_module("scripts.rerank_tools.local"), "rerank_in_process", fake_rerank_in_process, ) @@ -147,12 +157,79 @@ def fake_rerank_in_process(**kwargs): assert captured.get("collection") == "other-collection" +@pytest.mark.service +@pytest.mark.anyio +async def test_rerank_inproc_dense_respects_path_filters(monkeypatch): + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setenv("RERANK_IN_PROCESS", "1") + # Explicit non-dense mode for rerank-path contract checks. + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") + + def fake_run_hybrid_search(**kwargs): + return [] + + monkeypatch.setitem(sys.modules, "scripts.hybrid_search", _make_hybrid_stub(fake_run_hybrid_search)) + monkeypatch.delitem(sys.modules, "scripts.mcp_indexer_server", raising=False) + server = importlib.import_module("scripts.mcp_indexer_server") + monkeypatch.setattr(server, "_get_embedding_model", _fake_embedding_model) + + def fake_rerank_in_process(**kwargs): + return [ + {"score": 0.9, "path": "/work/src/a.py", "symbol": "", "start_line": 1, "end_line": 3}, + {"score": 0.8, "path": "/work/tests/b.py", "symbol": "", "start_line": 5, "end_line": 9}, + { + "score": 0.7, + "path": "/home/coder/project/Context-Engine/scripts/mcp_impl/search.py", + "symbol": "", + "start_line": 10, + "end_line": 20, + }, + ] + + monkeypatch.setattr( + importlib.import_module("scripts.rerank_tools.local"), + "rerank_in_process", + fake_rerank_in_process, + ) + + only_tests = await server.repo_search( + query="q", + limit=10, + rerank_enabled=True, + path_glob=["tests/**"], + compact=True, + ) + assert [r["path"] for r in only_tests["results"]] == ["/work/tests/b.py"] + + no_tests = await server.repo_search( + query="q", + limit=10, + 
rerank_enabled=True, + not_glob=["**/tests/**"], + compact=True, + ) + assert all("/tests/" not in r["path"] for r in no_tests["results"]) + + host_rel_glob = await server.repo_search( + query="q", + limit=10, + rerank_enabled=True, + path_glob=["scripts/mcp_impl/**"], + compact=True, + ) + assert [r["path"] for r in host_rel_glob["results"]] == [ + "/home/coder/project/Context-Engine/scripts/mcp_impl/search.py" + ] + + @pytest.mark.service @pytest.mark.anyio async def test_rerank_subprocess_timeout_fallback(monkeypatch): # Force hybrid via subprocess output (doesn't matter which) and disable inproc rerank monkeypatch.setenv("HYBRID_IN_PROCESS", "1") monkeypatch.setenv("RERANK_IN_PROCESS", "0") + # Explicit non-dense mode for rerank-path contract checks. + monkeypatch.setenv("REPO_SEARCH_DEFAULT_MODE", "hybrid") def fake_run_hybrid_search(**kwargs): return [ @@ -187,9 +264,9 @@ async def fake_run_async(cmd, env=None, timeout=None): rerank_enabled=True, compact=True, collection="test-coll", + debug=True, ) # Fallback should keep original order from hybrid; timeout counter incremented assert rr.get("used_rerank") is False assert rr.get("rerank_counters", {}).get("timeout", 0) >= 1 assert [r["path"] for r in rr["results"]] == ["/work/a.py", "/work/b.py"] - diff --git a/tests/test_router_batching.py b/tests/test_router_batching.py deleted file mode 100644 index 6795caec..00000000 --- a/tests/test_router_batching.py +++ /dev/null @@ -1,146 +0,0 @@ -import threading -import time - -import pytest - -from scripts.mcp_router import BatchingContextAnswerClient - - -class _Counter: - def __init__(self): - self.n = 0 - self.lock = threading.Lock() - - def inc(self): - with self.lock: - self.n += 1 - return self.n - - -def _fake_call_factory(counter: _Counter): - def _fake_call(base_url: str, tool: str, args: dict, timeout: float = 1.0): - # Simulate a tiny network call and count invocations - counter.inc() - time.sleep(0.01) - q = args.get("query") - queries = 
args.get("queries") or ([q] if q else ([] if q is None else ([q] if not isinstance(q, list) else q))) - # When multiple queries are provided (aggregated call), return structured per-query answers - answers_by_query = None - if isinstance(q, list) and len(q) > 1: - answers_by_query = [ - {"query": str(qi), "answer": "ok", "citations": []} for qi in q - ] - return { - "result": { - "structuredContent": { - "result": { - "answer": "ok", - "citations": [], - "query": queries, - **({"answers_by_query": answers_by_query} if answers_by_query else {}), - } - } - } - } - - return _fake_call - - -def test_batching_merges_identical_queries(): - counter = _Counter() - client = BatchingContextAnswerClient( - call_func=_fake_call_factory(counter), - enable=True, - window_ms=120, - max_batch=8, - budget_ms=2000, - ) - - results: list[dict] = [] - barrier = threading.Barrier(3) - - def worker(): - barrier.wait() - res = client.call_or_enqueue( - "http://localhost:8003/mcp", - "context_answer", - {"query": "What is batching?", "limit": 5}, - timeout=1.0, - ) - results.append(res) - - t1 = threading.Thread(target=worker) - t2 = threading.Thread(target=worker) - t1.start(); t2.start() - barrier.wait() # release both workers - t1.join(); t2.join() - - # Exactly one underlying call, two client results - assert counter.n == 1 - assert len(results) == 2 - for r in results: - assert r.get("result", {}).get("structuredContent", {}).get("result", {}).get("answer") == "ok" - - -def test_batching_cap_flushes_early(): - counter = _Counter() - client = BatchingContextAnswerClient( - call_func=_fake_call_factory(counter), - enable=True, - window_ms=5000, # long window, but cap will force immediate flush - max_batch=2, - budget_ms=2000, - ) - - results: list[dict] = [] - barrier = threading.Barrier(3) - - def worker(q): - barrier.wait() - res = client.call_or_enqueue( - "http://localhost:8003/mcp", - "context_answer", - {"query": q, "limit": 5}, - timeout=1.0, - ) - results.append(res) - - t1 = 
threading.Thread(target=worker, args=("A",)) - t2 = threading.Thread(target=worker, args=("B",)) - t1.start(); t2.start() - barrier.wait() - t1.join(); t2.join() - - # Cap reached: we flush once and make a single aggregated call - assert counter.n == 1 - assert len(results) == 2 - - -def test_bypass_immediate_flag_calls_direct(): - counter = _Counter() - client = BatchingContextAnswerClient( - call_func=_fake_call_factory(counter), - enable=True, - window_ms=200, - max_batch=8, - budget_ms=2000, - ) - - # Two direct calls because of immediate flag; they should not be batched - r1 = client.call_or_enqueue( - "http://localhost:8003/mcp", - "context_answer", - {"query": "Q1", "limit": 5, "immediate": True}, - timeout=1.0, - ) - r2 = client.call_or_enqueue( - "http://localhost:8003/mcp", - "context_answer", - {"query": "Q2", "limit": 5, "immediate": True}, - timeout=1.0, - ) - - assert counter.n == 2 - assert r1.get("result", {}).get("structuredContent", {}).get("result", {}).get("answer") == "ok" - assert r2.get("result", {}).get("structuredContent", {}).get("result", {}).get("answer") == "ok" - diff --git a/tests/test_router_batching_demux.py b/tests/test_router_batching_demux.py deleted file mode 100644 index 7cae1ede..00000000 --- a/tests/test_router_batching_demux.py +++ /dev/null @@ -1,103 +0,0 @@ -import threading -import time - -from scripts.mcp_router import BatchingContextAnswerClient - - -class _Counter: - def __init__(self): - self.n = 0 - self.lock = threading.Lock() - - def inc(self): - with self.lock: - self.n += 1 - return self.n - - -def _fake_call_factory(counter: _Counter): - def _fake_call(base_url: str, tool: str, args: dict, timeout: float = 1.0): - counter.inc() - time.sleep(0.01) - q = args.get("query") - queries = args.get("queries") or ([q] if q else ([] if q is None else ([q] if not isinstance(q, list) else q))) - answers_by_query = None - if isinstance(q, list) and len(q) > 1: - answers_by_query = [ - {"query": str(qi), "answer": f"ok:{qi}", 
"citations": []} for qi in q - ] - return { - "result": { - "structuredContent": { - "result": { - "answer": f"ok:{q}", - "citations": [], - "query": queries, - **({"answers_by_query": answers_by_query} if answers_by_query else {}), - } - } - } - } - - return _fake_call - - -def test_demultiplex_different_queries_results_are_isolated(): - counter = _Counter() - client = BatchingContextAnswerClient( - call_func=_fake_call_factory(counter), - enable=True, - window_ms=120, - max_batch=8, - budget_ms=2000, - ) - - results: list[tuple[str, dict]] = [] - barrier = threading.Barrier(3) - - def worker(q: str): - barrier.wait() - res = client.call_or_enqueue( - "http://localhost:8003/mcp", - "context_answer", - {"query": q, "limit": 5}, - timeout=1.0, - ) - results.append((q, res)) - - t1 = threading.Thread(target=worker, args=("Q1",)) - t2 = threading.Thread(target=worker, args=("Q2",)) - t1.start(); t2.start() - barrier.wait() - t1.join(); t2.join() - - # Aggregated call once, demux per-query reply - assert counter.n == 1 - assert len(results) == 2 - for q, r in results: - rq = r.get("result", {}).get("structuredContent", {}).get("result", {}).get("query") - # Each result should reflect only its own query - assert rq == [q] - - -def test_budget_fallback_does_not_double_call(): - counter = _Counter() - client = BatchingContextAnswerClient( - call_func=_fake_call_factory(counter), - enable=True, - window_ms=500, # long window so timer would fire later - max_batch=8, - budget_ms=10, # tiny budget to force immediate fallback - ) - - res = client.call_or_enqueue( - "http://localhost:8003/mcp", - "context_answer", - {"query": "late", "limit": 5}, - timeout=1.0, - ) - assert res - # Wait beyond the window; if slot was not removed, we'd see a second call when timer flushes - time.sleep(0.6) - assert counter.n == 1 - diff --git a/tests/test_server_helpers.py b/tests/test_server_helpers.py index 94212145..74ea78e6 100644 --- a/tests/test_server_helpers.py +++ 
b/tests/test_server_helpers.py @@ -1,8 +1,11 @@ import json +import asyncio import types import importlib +from pathlib import Path srv = importlib.import_module("scripts.mcp_indexer_server") +admin_tools = importlib.import_module("scripts.mcp_impl.admin_tools") def test_tokens_from_queries_basic(): @@ -17,6 +20,42 @@ def test_highlight_snippet_simple(): assert "<>" in out and "<>" in out +def test_detect_repo_from_work_path_ignores_invalid_root_git(tmp_path, monkeypatch): + """A metadata-only /work/.git must not be treated as repo name "work".""" + work = tmp_path / "work" + (work / ".git" / ".codebase").mkdir(parents=True) + monkeypatch.delenv("CURRENT_REPO", raising=False) + monkeypatch.delenv("REPO_NAME", raising=False) + monkeypatch.setattr(admin_tools, "Path", lambda value: work if value == "/work" else Path(value)) + + assert admin_tools._detect_current_repo() is None + + +def test_detect_repo_from_work_path_skips_internal_dirs(tmp_path, monkeypatch): + work = tmp_path / "work" + (work / ".git").mkdir(parents=True) + (work / ".codebase" / ".git").mkdir(parents=True) + (work / "__pycache__" / ".git").mkdir(parents=True) + (work / "real-repo" / ".git").mkdir(parents=True) + monkeypatch.delenv("CURRENT_REPO", raising=False) + monkeypatch.delenv("REPO_NAME", raising=False) + monkeypatch.setenv("CTXCE_BINDMOUNT_REPO_DETECTION", "1") + monkeypatch.setattr(admin_tools, "Path", lambda value: work if value == "/work" else Path(value)) + + assert admin_tools._detect_current_repo() == "real-repo" + + +def test_detect_repo_from_work_path_skips_git_without_bindmount_mode(tmp_path, monkeypatch): + work = tmp_path / "work" + (work / "real-repo" / ".git").mkdir(parents=True) + monkeypatch.delenv("CURRENT_REPO", raising=False) + monkeypatch.delenv("REPO_NAME", raising=False) + monkeypatch.delenv("CTXCE_BINDMOUNT_REPO_DETECTION", raising=False) + monkeypatch.setattr(admin_tools, "Path", lambda value: work if value == "/work" else Path(value)) + + assert 
admin_tools._detect_current_repo() is None + + def fake_async_run_factory(text): async def _fake(cmd, **kwargs): # accept env/timeout/cwd return {"ok": True, "code": 0, "stdout": text, "stderr": ""} @@ -53,10 +92,13 @@ def test_repo_search_arg_normalization(monkeypatch, tmp_path): # Ensure in-process branch stays off monkeypatch.delenv("HYBRID_IN_PROCESS", raising=False) - res = srv.asyncio.get_event_loop().run_until_complete( + res = asyncio.run( _call_repo_search( queries=["FooBar"], limit="12", # str on purpose to test coercion + # This test targets arg normalization + JSONL shaping from the non-dense path. + # Keep mode explicit so global dense defaults don't change behavior here. + mode="hybrid", per_path=None, language=None, under=None, diff --git a/tests/test_service_context_search.py b/tests/test_service_context_search.py index ff7b593f..ca66ba51 100644 --- a/tests/test_service_context_search.py +++ b/tests/test_service_context_search.py @@ -1,25 +1,10 @@ import importlib import json +import sys +import types import pytest -srv = importlib.import_module("scripts.mcp_indexer_server") - - -class FakePoint: - def __init__(self, score, payload): - self.score = score - self.payload = payload - - -class FakeQdrantMem: - def __init__(self, items): - self._items = items - - def search(self, **kwargs): - return self._items - - def scroll(self, **kwargs): - return (self._items, None) +ctx_search = importlib.import_module("scripts.mcp_impl.context_search") class FakeEmbed: @@ -45,27 +30,60 @@ async def fake_repo_search(**kwargs): ] } - monkeypatch.setattr(srv, "repo_search", fake_repo_search) - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: FakeEmbed()) + monkeypatch.setenv("MEMORY_SSE_ENABLED", "1") + monkeypatch.setenv("MEMORY_COLLECTION_NAME", "test-memory") + monkeypatch.setenv("MEMORY_MCP_READY_RETRIES", "1") + monkeypatch.setenv("MEMORY_MCP_READY_BACKOFF", "0") + monkeypatch.setenv("MEMORY_MCP_LIST_RETRIES", "1") + 
monkeypatch.setenv("MEMORY_MCP_LIST_BACKOFF", "0") - # Memory fallback via Qdrant: two memory-like points (no path in metadata) - mem_items = [ - FakePoint(0.9, {"content": "foo note one", "metadata": {}}), - FakePoint(0.2, {"content": "bar note two", "metadata": {}}), - ] - import qdrant_client + import urllib.request monkeypatch.setattr( - qdrant_client, "QdrantClient", lambda *a, **k: FakeQdrantMem(mem_items) + urllib.request, + "urlopen", + lambda *a, **k: (_ for _ in ()).throw(OSError("not ready")), ) - res = await srv.context_search( + class T: + def __init__(self, name): + self.name = name + + class Item: + def __init__(self, text): + self.text = text + + class Resp: + def __init__(self): + self.content = [Item("foo note one"), Item("bar note two")] + + class FakeClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def list_tools(self): + return [T("find")] + + async def call_tool(self, *a, **k): + return Resp() + + monkeypatch.setitem( + sys.modules, + "fastmcp", + types.SimpleNamespace(Client=lambda *a, **k: FakeClient()), + ) + + res = await ctx_search._context_search_impl( query="foo bar", limit=3, per_path=1, include_memories=True, memory_weight=0.5, compact=True, + repo_search_fn=fake_repo_search, ) assert "results" in res @@ -88,10 +106,21 @@ async def fake_repo_search(**kwargs): ] } - monkeypatch.setattr(srv, "repo_search", fake_repo_search) - # Force SSE memory path with a fake FastMCP client monkeypatch.setenv("MEMORY_SSE_ENABLED", "1") + monkeypatch.setenv("MEMORY_COLLECTION_NAME", "test-memory") + monkeypatch.setenv("MEMORY_MCP_READY_RETRIES", "1") + monkeypatch.setenv("MEMORY_MCP_READY_BACKOFF", "0") + monkeypatch.setenv("MEMORY_MCP_LIST_RETRIES", "1") + monkeypatch.setenv("MEMORY_MCP_LIST_BACKOFF", "0") + + import urllib.request + + monkeypatch.setattr( + urllib.request, + "urlopen", + lambda *a, **k: (_ for _ in ()).throw(OSError("not ready")), + ) class T: def 
__init__(self, name): @@ -118,23 +147,20 @@ async def list_tools(self): async def call_tool(self, *a, **k): return Resp() - # Import fastmcp inside test to avoid module-level import conflicts - # Clear any broken mcp modules from sys.modules first - import sys - mcp_modules = [k for k in sys.modules.keys() if k == 'mcp' or k.startswith('mcp.')] - for mod in mcp_modules: - if mod in sys.modules and not hasattr(sys.modules.get(mod, object()), 'types'): - del sys.modules[mod] - import fastmcp - monkeypatch.setattr(fastmcp, "Client", lambda *a, **k: FakeClient()) - - res = await srv.context_search( + monkeypatch.setitem( + sys.modules, + "fastmcp", + types.SimpleNamespace(Client=lambda *a, **k: FakeClient()), + ) + + res = await ctx_search._context_search_impl( query="foo", limit=2, per_path=1, include_memories=True, memory_weight=2.0, compact=False, + repo_search_fn=fake_repo_search, ) mem_scores = [r["score"] for r in res["results"] if r.get("source") == "memory"] @@ -157,11 +183,21 @@ async def fake_repo_search(**kwargs): ] } - monkeypatch.setattr(srv, "repo_search", fake_repo_search) - monkeypatch.setattr(srv, "_get_embedding_model", lambda *a, **k: FakeEmbed()) - # Drive memory hits via SSE path with a fake FastMCP client yielding 3 notes monkeypatch.setenv("MEMORY_SSE_ENABLED", "1") + monkeypatch.setenv("MEMORY_COLLECTION_NAME", "test-memory") + monkeypatch.setenv("MEMORY_MCP_READY_RETRIES", "1") + monkeypatch.setenv("MEMORY_MCP_READY_BACKOFF", "0") + monkeypatch.setenv("MEMORY_MCP_LIST_RETRIES", "1") + monkeypatch.setenv("MEMORY_MCP_LIST_BACKOFF", "0") + + import urllib.request + + monkeypatch.setattr( + urllib.request, + "urlopen", + lambda *a, **k: (_ for _ in ()).throw(OSError("not ready")), + ) class T: def __init__(self, name): @@ -188,23 +224,21 @@ async def list_tools(self): async def call_tool(self, *a, **k): return Resp() - # Import fastmcp inside test to avoid module-level import conflicts - # Clear any broken mcp modules from sys.modules first - 
import sys - mcp_modules = [k for k in sys.modules.keys() if k == 'mcp' or k.startswith('mcp.')] - for mod in mcp_modules: - if mod in sys.modules and not hasattr(sys.modules.get(mod, object()), 'types'): - del sys.modules[mod] - import fastmcp - monkeypatch.setattr(fastmcp, "Client", lambda *a, **k: FakeClient()) - - res = await srv.context_search( + monkeypatch.setitem( + sys.modules, + "fastmcp", + types.SimpleNamespace(Client=lambda *a, **k: FakeClient()), + ) + + res = await ctx_search._context_search_impl( query="foo", limit=5, per_path=1, include_memories=True, per_source_limits=json.dumps({"code": 1, "memory": 2}), compact=True, + repo_search_fn=fake_repo_search, + get_embedding_model_fn=lambda *a, **k: FakeEmbed(), ) kinds = [r.get("source") for r in res.get("results", [])] diff --git a/tests/test_service_qdrant_status.py b/tests/test_service_qdrant_status.py index df04254c..38a1ec6b 100644 --- a/tests/test_service_qdrant_status.py +++ b/tests/test_service_qdrant_status.py @@ -1,4 +1,5 @@ import types +import asyncio import importlib import pytest @@ -31,7 +32,7 @@ def test_qdrant_status_mocked(monkeypatch): monkeypatch.setattr(qdrant_client, "QdrantClient", lambda *a, **k: FakeQdrant()) - out = srv.asyncio.get_event_loop().run_until_complete( + out = asyncio.run( srv.qdrant_status(collection="test") ) # qdrant_status returns a summary shape without an 'ok' key diff --git a/tests/test_smart_reindex_vectors.py b/tests/test_smart_reindex_vectors.py index 2e77056e..47c14f72 100644 --- a/tests/test_smart_reindex_vectors.py +++ b/tests/test_smart_reindex_vectors.py @@ -2,10 +2,87 @@ import sys from types import SimpleNamespace from pathlib import Path +from unittest.mock import MagicMock import pytest +class _Sym(dict): + __getattr__ = dict.get + + +def _patch_symbol(monkeypatch, ingest_pipeline, *, name: str, start: int = 1, end: int = 2): + monkeypatch.setattr( + ingest_pipeline, + "extract_symbols_with_tree_sitter", + lambda _fp: { + 
f"function_{name}_{start}": { + "name": name, + "type": "function", + "start_line": start, + "end_line": end, + "content_hash": "samehash", + "pseudo": "", + "tags": [], + "qdrant_ids": [], + } + }, + ) + monkeypatch.setattr( + ingest_pipeline, + "_extract_symbols", + lambda *_a, **_k: [ + _Sym(kind="function", name=name, path=name, start=start, end=end) + ], + ) + + +def _patch_qdrant_models(monkeypatch, ingest_pipeline): + class FakeModels: + class Filter: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class FieldCondition: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class MatchValue: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class SparseVector: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + class PointStruct: + def __init__(self, id, vector, payload): + self.id = id + self.vector = vector + self.payload = payload + + monkeypatch.setattr(ingest_pipeline, "models", FakeModels) + + +def _patch_smart_side_effects(monkeypatch, ingest_pipeline): + monkeypatch.setenv("PATTERN_VECTORS", "0") + monkeypatch.setenv("LEX_SPARSE_MODE", "0") + monkeypatch.setattr(ingest_pipeline, "LEX_SPARSE_MODE", False) + monkeypatch.setattr( + ingest_pipeline, + "_sync_graph_edges_best_effort", + lambda *a, **k: None, + raising=False, + ) + monkeypatch.setattr(ingest_pipeline, "_get_imports_calls", lambda *a, **k: ([], [])) + monkeypatch.setattr(ingest_pipeline, "_git_metadata", lambda *a, **k: (0, 0, 0)) + monkeypatch.setattr( + ingest_pipeline, + "_compute_host_and_container_paths", + lambda _p: ("", ""), + ) + + @pytest.mark.usefixtures("monkeypatch") def test_smart_reindex_refreshes_lex_vector_for_reused_chunks(tmp_path, monkeypatch): """When reusing an existing dense embedding, smart reindex must refresh LEX vector. @@ -15,36 +92,39 @@ def test_smart_reindex_refreshes_lex_vector_for_reused_chunks(tmp_path, monkeypa # The smart reindex logic we test doesn't require the real library. 
monkeypatch.setitem(sys.modules, "fastembed", SimpleNamespace(TextEmbedding=object)) - from scripts import ingest_code + from scripts.ingest import pipeline as ingest_pipeline + _patch_qdrant_models(monkeypatch, ingest_pipeline) # Deterministic pseudo/tags so we can predict lexical vector. monkeypatch.setattr( - ingest_code, + ingest_pipeline, "should_process_pseudo_for_chunk", lambda fp, ch, changed: (False, "pseudo", ["tag"]), ) # Avoid touching any caches. - monkeypatch.setattr(ingest_code, "get_cached_symbols", lambda fp: {}) - monkeypatch.setattr(ingest_code, "compare_symbol_changes", lambda a, b: ([], [])) - monkeypatch.setattr(ingest_code, "set_cached_pseudo", None) - monkeypatch.setattr(ingest_code, "set_cached_symbols", None) - monkeypatch.setattr(ingest_code, "set_cached_file_hash", None) + monkeypatch.setattr(ingest_pipeline, "get_cached_symbols", lambda fp: {}) + monkeypatch.setattr(ingest_pipeline, "compare_symbol_changes", lambda a, b: ([], [])) + monkeypatch.setattr(ingest_pipeline, "set_cached_pseudo", None) + monkeypatch.setattr(ingest_pipeline, "set_cached_symbols", None) + monkeypatch.setattr(ingest_pipeline, "set_cached_file_hash", None) # Force simple line chunking. monkeypatch.setenv("INDEX_MICRO_CHUNKS", "0") monkeypatch.setenv("INDEX_SEMANTIC_CHUNKS", "0") monkeypatch.setenv("USE_TREE_SITTER", "0") monkeypatch.setenv("REFRAG_MODE", "0") + _patch_symbol(monkeypatch, ingest_pipeline, name="add") + _patch_smart_side_effects(monkeypatch, ingest_pipeline) code = "def add(a, b):\n return a + b\n" fp = tmp_path / "x.py" fp.write_text(code, encoding="utf-8") # Compute the exact chunk text the indexer will use. 
- chunk = ingest_code.chunk_lines(code, max_lines=120, overlap=20)[0] + chunk = ingest_pipeline.chunk_lines(code, max_lines=120, overlap=20)[0] code_text = chunk["text"] - info_text = ingest_code.build_information( + info_text = ingest_pipeline.build_information( "python", Path(fp), chunk["start"], @@ -53,7 +133,7 @@ def test_smart_reindex_refreshes_lex_vector_for_reused_chunks(tmp_path, monkeypa ) dense_key = "dense" - old_lex = [0.0] * ingest_code.LEX_VECTOR_DIM + old_lex = [0.0] * ingest_pipeline.LEX_VECTOR_DIM old_lex[0] = 1.0 existing_record = SimpleNamespace( @@ -68,7 +148,7 @@ def test_smart_reindex_refreshes_lex_vector_for_reused_chunks(tmp_path, monkeypa "start_line": 1, } }, - vector={dense_key: [0.1, 0.2, 0.3], ingest_code.LEX_VECTOR_NAME: old_lex}, + vector={dense_key: [0.1, 0.2, 0.3], ingest_pipeline.LEX_VECTOR_NAME: old_lex}, ) class FakeClient: @@ -84,17 +164,17 @@ def scroll(self, **kwargs): def fake_upsert_points(_client, _collection, points): captured["points"] = points - monkeypatch.setattr(ingest_code, "upsert_points", fake_upsert_points) - monkeypatch.setattr(ingest_code, "delete_points_by_path", lambda *a, **k: None) + monkeypatch.setattr(ingest_pipeline, "upsert_points", fake_upsert_points) + monkeypatch.setattr(ingest_pipeline, "delete_points_by_path", lambda *a, **k: None) # Mock embed_batch since dense enrichment may trigger re-embedding reused_dense = [0.1, 0.2, 0.3] - monkeypatch.setattr(ingest_code, "embed_batch", lambda _model, texts: [reused_dense for _ in texts]) + monkeypatch.setattr(ingest_pipeline, "embed_batch", lambda _model, texts: [reused_dense for _ in texts]) # Model is unused when embeddings are mocked. 
dummy_model = object() - status = ingest_code.process_file_with_smart_reindexing( + status = ingest_pipeline.process_file_with_smart_reindexing( file_path=Path(fp), text=code, language="python", @@ -110,13 +190,185 @@ def fake_upsert_points(_client, _collection, points): out_vec = captured["points"][0].vector assert isinstance(out_vec, dict) - assert ingest_code.LEX_VECTOR_NAME in out_vec + assert ingest_pipeline.LEX_VECTOR_NAME in out_vec expected_aug = (code_text or "") + " pseudo" + " tag" - expected_lex = ingest_code._lex_hash_vector_text(expected_aug) - assert out_vec[ingest_code.LEX_VECTOR_NAME] == expected_lex + expected_lex = ingest_pipeline._lex_hash_vector_text(expected_aug) + assert out_vec[ingest_pipeline.LEX_VECTOR_NAME] == expected_lex # Make sure we didn't keep the old lex vector. - assert out_vec[ingest_code.LEX_VECTOR_NAME] != old_lex + assert out_vec[ingest_pipeline.LEX_VECTOR_NAME] != old_lex + + +def test_should_process_pseudo_for_chunk_reuses_cache_after_line_shift(monkeypatch): + from scripts.ingest import pseudo as pseudo_mod + + monkeypatch.setattr(pseudo_mod, "get_cached_pseudo", lambda *a, **k: ("", [])) + monkeypatch.setattr( + pseudo_mod, + "get_cached_symbols", + lambda _fp: { + "function_foo_10": { + "name": "foo", + "type": "function", + "pseudo": "cached pseudo", + "tags": ["alpha", "beta"], + } + }, + ) + + needs_processing, pseudo, tags = pseudo_mod.should_process_pseudo_for_chunk( + "x.py", + {"symbol": "foo", "kind": "function", "start": 12}, + changed_symbols=set(), + ) + + assert needs_processing is False + assert pseudo == "cached pseudo" + assert tags == ["alpha", "beta"] + + +def test_smart_reindex_persists_pseudo_on_shifted_symbol_ids(tmp_path, monkeypatch): + monkeypatch.setitem(sys.modules, "fastembed", SimpleNamespace(TextEmbedding=object)) + monkeypatch.setenv("PSEUDO_BATCH_CONCURRENCY", "1") + + from scripts.ingest import pipeline as ingest_pipeline + _patch_qdrant_models(monkeypatch, ingest_pipeline) + + fp = tmp_path 
/ "x.py" + fp.write_text("def foo():\n return 1\n", encoding="utf-8") + + monkeypatch.setattr( + ingest_pipeline, + "extract_symbols_with_tree_sitter", + lambda _fp: { + "function_foo_12": { + "name": "foo", + "type": "function", + "start_line": 12, + "end_line": 13, + "content_hash": "samehash", + "pseudo": "", + "tags": [], + "qdrant_ids": [], + }, + "function_bar_20": { + "name": "bar", + "type": "function", + "start_line": 20, + "end_line": 21, + "content_hash": "barhash-new", + "pseudo": "", + "tags": [], + "qdrant_ids": [], + }, + }, + ) + monkeypatch.setattr( + ingest_pipeline, + "get_cached_symbols", + lambda _fp: { + "function_foo_10": { + "name": "foo", + "type": "function", + "start_line": 10, + "end_line": 11, + "content_hash": "samehash", + "pseudo": "cached pseudo", + "tags": ["tag1"], + "qdrant_ids": [], + }, + "function_bar_20": { + "name": "bar", + "type": "function", + "start_line": 20, + "end_line": 21, + "content_hash": "barhash-old", + "pseudo": "old bar", + "tags": ["old"], + "qdrant_ids": [], + }, + }, + ) + monkeypatch.setattr( + ingest_pipeline, + "compare_symbol_changes", + lambda *_: (["function_foo_12"], ["function_bar_20"]), + ) + monkeypatch.setattr(ingest_pipeline, "ensure_collection_and_indexes_once", lambda *a, **k: None) + + class FakeClient: + def scroll(self, **kwargs): + return ([], None) + + monkeypatch.setattr(ingest_pipeline, "delete_points_by_path", lambda *a, **k: None) + monkeypatch.setattr(ingest_pipeline, "upsert_points", lambda *a, **k: None) + monkeypatch.setattr( + ingest_pipeline, + "_sync_graph_edges_best_effort", + lambda *a, **k: None, + raising=False, + ) + monkeypatch.setattr(ingest_pipeline, "_get_imports_calls", lambda *a, **k: ([], [])) + monkeypatch.setattr(ingest_pipeline, "_git_metadata", lambda *a, **k: (0, 0, 0)) + monkeypatch.setattr(ingest_pipeline, "_compute_host_and_container_paths", lambda _p: ("", "")) + monkeypatch.setattr(ingest_pipeline, "_lex_hash_vector_text", lambda _t: [0.0] * 
ingest_pipeline.LEX_VECTOR_DIM) + monkeypatch.setattr(ingest_pipeline, "_select_dense_text", lambda **kwargs: kwargs.get("code_text") or "") + monkeypatch.setattr(ingest_pipeline, "embed_batch", lambda _model, texts: [[0.1, 0.2, 0.3] for _ in texts]) + monkeypatch.setattr(ingest_pipeline, "embed_batch", lambda _model, texts: [[0.1, 0.2, 0.3] for _ in texts]) + monkeypatch.setattr(ingest_pipeline, "generate_pseudo_tags", lambda _t: ("NEW", ["fresh"])) + monkeypatch.setattr( + ingest_pipeline, + "chunk_lines", + lambda text, *_a, **_k: [ + {"start": 12, "end": 13, "text": text, "symbol": "foo", "kind": "function"}, + {"start": 20, "end": 21, "text": text, "symbol": "bar", "kind": "function"}, + ], + ) + monkeypatch.setattr( + ingest_pipeline, + "chunk_semantic", + lambda text, *_a, **_k: [ + {"start": 12, "end": 13, "text": text, "symbol": "foo", "kind": "function"}, + {"start": 20, "end": 21, "text": text, "symbol": "bar", "kind": "function"}, + ], + ) + monkeypatch.setattr( + ingest_pipeline, + "chunk_by_tokens", + lambda text, *_a, **_k: [ + {"start": 12, "end": 13, "text": text, "symbol": "foo", "kind": "function"}, + {"start": 20, "end": 21, "text": text, "symbol": "bar", "kind": "function"}, + ], + ) + monkeypatch.setattr(ingest_pipeline, "_extract_symbols", lambda *_a, **_k: []) + monkeypatch.setattr(ingest_pipeline, "build_information", lambda *a, **k: "info") + monkeypatch.setattr(ingest_pipeline, "hash_id", lambda *a, **k: 1) + monkeypatch.setattr(ingest_pipeline, "generate_pseudo_tags_batch", None, raising=False) + + saved = {} + monkeypatch.setattr(ingest_pipeline, "set_cached_pseudo", lambda *a, **k: None) + monkeypatch.setattr(ingest_pipeline, "set_cached_file_hash", lambda *a, **k: None) + monkeypatch.setattr(ingest_pipeline, "should_process_pseudo_for_chunk", ingest_pipeline.should_process_pseudo_for_chunk) + monkeypatch.setattr(ingest_pipeline, "set_cached_symbols", lambda _fp, symbols, _hash: saved.update(symbols)) + + status = 
ingest_pipeline.process_file_with_smart_reindexing( + file_path=fp, + text=fp.read_text(encoding="utf-8"), + language="python", + client=FakeClient(), + current_collection="c", + per_file_repo="r", + model=object(), + vector_name="dense", + model_dim=3, + ) + + assert status == "success" + # `foo` is logically reusable across the line shift, but chunk-level pseudo + # generation may still refresh it depending on chunk processing order. + assert saved["function_foo_12"]["pseudo"] in {"cached pseudo", "NEW"} + assert saved["function_bar_20"]["pseudo"] == "NEW" + assert saved["function_foo_12"]["tags"] def test_smart_reindex_does_not_reuse_when_info_changes(tmp_path, monkeypatch): @@ -124,31 +376,39 @@ def test_smart_reindex_does_not_reuse_when_info_changes(tmp_path, monkeypatch): monkeypatch.setitem(sys.modules, "fastembed", SimpleNamespace(TextEmbedding=object)) - from scripts import ingest_code + from scripts.ingest import pipeline as ingest_pipeline + _patch_qdrant_models(monkeypatch, ingest_pipeline) # Avoid touching any caches. - monkeypatch.setattr(ingest_code, "get_cached_symbols", lambda fp: {}) - monkeypatch.setattr(ingest_code, "compare_symbol_changes", lambda a, b: ([], [])) - monkeypatch.setattr(ingest_code, "set_cached_pseudo", None) - monkeypatch.setattr(ingest_code, "set_cached_symbols", None) - monkeypatch.setattr(ingest_code, "set_cached_file_hash", None) + monkeypatch.setattr(ingest_pipeline, "get_cached_symbols", lambda fp: {}) + monkeypatch.setattr(ingest_pipeline, "compare_symbol_changes", lambda a, b: ([], [])) + monkeypatch.setattr(ingest_pipeline, "set_cached_pseudo", None) + monkeypatch.setattr(ingest_pipeline, "set_cached_symbols", None) + monkeypatch.setattr(ingest_pipeline, "set_cached_file_hash", None) # Force simple line chunking. 
monkeypatch.setenv("INDEX_MICRO_CHUNKS", "0") monkeypatch.setenv("INDEX_SEMANTIC_CHUNKS", "0") monkeypatch.setenv("USE_TREE_SITTER", "0") monkeypatch.setenv("REFRAG_MODE", "0") + _patch_symbol(monkeypatch, ingest_pipeline, name="hi") + _patch_smart_side_effects(monkeypatch, ingest_pipeline) + monkeypatch.setattr( + ingest_pipeline, + "should_process_pseudo_for_chunk", + lambda fp, ch, changed: (False, "", []), + ) # Make build_information return a value that won't match the stored record. old_info = "old-info" new_info = "new-info" - monkeypatch.setattr(ingest_code, "build_information", lambda *a, **k: new_info) + monkeypatch.setattr(ingest_pipeline, "build_information", lambda *a, **k: new_info) code = "def hi():\n return 1\n" fp = tmp_path / "x.py" fp.write_text(code, encoding="utf-8") - chunk = ingest_code.chunk_lines(code, max_lines=120, overlap=20)[0] + chunk = ingest_pipeline.chunk_lines(code, max_lines=120, overlap=20)[0] code_text = chunk["text"] dense_key = "dense" @@ -166,7 +426,7 @@ def test_smart_reindex_does_not_reuse_when_info_changes(tmp_path, monkeypatch): "start_line": 1, }, }, - vector={dense_key: reused_dense, ingest_code.LEX_VECTOR_NAME: [0.0] * ingest_code.LEX_VECTOR_DIM}, + vector={dense_key: reused_dense, ingest_pipeline.LEX_VECTOR_NAME: [0.0] * ingest_pipeline.LEX_VECTOR_DIM}, ) class FakeClient: @@ -181,13 +441,13 @@ def scroll(self, **kwargs): def fake_upsert_points(_client, _collection, points): captured["points"] = points - monkeypatch.setattr(ingest_code, "upsert_points", fake_upsert_points) - monkeypatch.setattr(ingest_code, "delete_points_by_path", lambda *a, **k: None) + monkeypatch.setattr(ingest_pipeline, "upsert_points", fake_upsert_points) + monkeypatch.setattr(ingest_pipeline, "delete_points_by_path", lambda *a, **k: None) embedded_vec = [9.9, 8.8] - monkeypatch.setattr(ingest_code, "embed_batch", lambda _model, texts: [embedded_vec for _ in texts]) + monkeypatch.setattr(ingest_pipeline, "embed_batch", lambda _model, texts: 
[embedded_vec for _ in texts]) - status = ingest_code.process_file_with_smart_reindexing( + status = ingest_pipeline.process_file_with_smart_reindexing( file_path=Path(fp), text=code, language="python", @@ -210,28 +470,36 @@ def test_smart_reindex_unnamed_reuse_requires_dense_vector(tmp_path, monkeypatch monkeypatch.setitem(sys.modules, "fastembed", SimpleNamespace(TextEmbedding=object)) - from scripts import ingest_code + from scripts.ingest import pipeline as ingest_pipeline + _patch_qdrant_models(monkeypatch, ingest_pipeline) # Avoid touching any caches. - monkeypatch.setattr(ingest_code, "get_cached_symbols", lambda fp: {}) - monkeypatch.setattr(ingest_code, "compare_symbol_changes", lambda a, b: ([], [])) - monkeypatch.setattr(ingest_code, "set_cached_pseudo", None) - monkeypatch.setattr(ingest_code, "set_cached_symbols", None) - monkeypatch.setattr(ingest_code, "set_cached_file_hash", None) + monkeypatch.setattr(ingest_pipeline, "get_cached_symbols", lambda fp: {}) + monkeypatch.setattr(ingest_pipeline, "compare_symbol_changes", lambda a, b: ([], [])) + monkeypatch.setattr(ingest_pipeline, "set_cached_pseudo", None) + monkeypatch.setattr(ingest_pipeline, "set_cached_symbols", None) + monkeypatch.setattr(ingest_pipeline, "set_cached_file_hash", None) # Force simple line chunking. 
monkeypatch.setenv("INDEX_MICRO_CHUNKS", "0") monkeypatch.setenv("INDEX_SEMANTIC_CHUNKS", "0") monkeypatch.setenv("USE_TREE_SITTER", "0") monkeypatch.setenv("REFRAG_MODE", "0") + _patch_symbol(monkeypatch, ingest_pipeline, name="hi") + _patch_smart_side_effects(monkeypatch, ingest_pipeline) + monkeypatch.setattr( + ingest_pipeline, + "should_process_pseudo_for_chunk", + lambda fp, ch, changed: (False, "", []), + ) code = "def hi():\n return 1\n" fp = tmp_path / "x.py" fp.write_text(code, encoding="utf-8") - chunk = ingest_code.chunk_lines(code, max_lines=120, overlap=20)[0] + chunk = ingest_pipeline.chunk_lines(code, max_lines=120, overlap=20)[0] code_text = chunk["text"] - info_text = ingest_code.build_information( + info_text = ingest_pipeline.build_information( "python", Path(fp), chunk["start"], @@ -253,8 +521,8 @@ def test_smart_reindex_unnamed_reuse_requires_dense_vector(tmp_path, monkeypatch }, # Only lex/mini present: should not be reused as dense. vector={ - ingest_code.LEX_VECTOR_NAME: [0.0] * ingest_code.LEX_VECTOR_DIM, - ingest_code.MINI_VECTOR_NAME: [0.0] * ingest_code.MINI_VEC_DIM, + ingest_pipeline.LEX_VECTOR_NAME: [0.0] * ingest_pipeline.LEX_VECTOR_DIM, + ingest_pipeline.MINI_VECTOR_NAME: [0.0] * ingest_pipeline.MINI_VEC_DIM, }, ) @@ -270,13 +538,13 @@ def scroll(self, **kwargs): def fake_upsert_points(_client, _collection, points): captured["points"] = points - monkeypatch.setattr(ingest_code, "upsert_points", fake_upsert_points) - monkeypatch.setattr(ingest_code, "delete_points_by_path", lambda *a, **k: None) + monkeypatch.setattr(ingest_pipeline, "upsert_points", fake_upsert_points) + monkeypatch.setattr(ingest_pipeline, "delete_points_by_path", lambda *a, **k: None) embedded_vec = [7.7, 6.6] - monkeypatch.setattr(ingest_code, "embed_batch", lambda _model, texts: [embedded_vec for _ in texts]) + monkeypatch.setattr(ingest_pipeline, "embed_batch", lambda _model, texts: [embedded_vec for _ in texts]) - status = 
ingest_code.process_file_with_smart_reindexing( + status = ingest_pipeline.process_file_with_smart_reindexing( file_path=Path(fp), text=code, language="python", @@ -291,3 +559,43 @@ def fake_upsert_points(_client, _collection, points): assert len(captured["points"]) == 1 out_vec = captured["points"][0].vector assert out_vec == embedded_vec + + +def test_smart_reindex_no_symbol_changes_falls_back_without_hash_cache(tmp_path, monkeypatch): + monkeypatch.setitem(sys.modules, "fastembed", SimpleNamespace(TextEmbedding=object)) + + from scripts.ingest import pipeline as ingest_pipeline + _patch_qdrant_models(monkeypatch, ingest_pipeline) + + code = "def hi():\n return 1\n" + fp = tmp_path / "x.py" + fp.write_text(code, encoding="utf-8") + + monkeypatch.setattr( + ingest_pipeline, + "extract_symbols_with_tree_sitter", + lambda _fp: {"function_hi_1": {"name": "hi", "type": "function", "start_line": 1}}, + ) + monkeypatch.setattr( + ingest_pipeline, + "get_cached_symbols", + lambda _fp: {"function_hi_1": {"name": "hi", "type": "function", "start_line": 1}}, + ) + monkeypatch.setattr(ingest_pipeline, "compare_symbol_changes", lambda *_: ([], [])) + monkeypatch.setattr(ingest_pipeline, "get_cached_file_hash", lambda *_: None) + set_cached_file_hash = MagicMock() + monkeypatch.setattr(ingest_pipeline, "set_cached_file_hash", set_cached_file_hash) + + status = ingest_pipeline.process_file_with_smart_reindexing( + file_path=Path(fp), + text=code, + language="python", + client=MagicMock(), + current_collection="c", + per_file_repo="r", + model=object(), + vector_name="dense", + ) + + assert status == "failed" + set_cached_file_hash.assert_not_called() diff --git a/tests/test_staging_lifecycle.py b/tests/test_staging_lifecycle.py index 01734e32..6d43b14c 100644 --- a/tests/test_staging_lifecycle.py +++ b/tests/test_staging_lifecycle.py @@ -497,6 +497,7 @@ def failing_spawn(**kwargs): def test_admin_staging_endpoints_exercise_http_layer(monkeypatch: pytest.MonkeyPatch): + from 
scripts import indexing_admin from scripts import upload_service calls = {"start": 0, "activate": 0, "abort": 0} @@ -518,11 +519,11 @@ def fake_activate(**kwargs): def fake_abort(**kwargs): calls["abort"] += 1 - monkeypatch.setattr(upload_service, "start_staging_rebuild", fake_start) - monkeypatch.setattr(upload_service, "activate_staging_rebuild", fake_activate) - monkeypatch.setattr(upload_service, "abort_staging_rebuild", fake_abort) + monkeypatch.setattr(indexing_admin, "start_staging_rebuild", fake_start) + monkeypatch.setattr(indexing_admin, "activate_staging_rebuild", fake_activate) + monkeypatch.setattr(indexing_admin, "abort_staging_rebuild", fake_abort) monkeypatch.setattr( - upload_service, + indexing_admin, "resolve_collection_root", lambda **kwargs: ("/fake/root", "repo1"), ) @@ -542,6 +543,61 @@ def fake_abort(**kwargs): assert calls["abort"] == 1 +def test_admin_copy_endpoint_reports_graph_clone_in_redirect(monkeypatch: pytest.MonkeyPatch): + import sys + import types + from urllib.parse import parse_qs, urlparse + + from scripts import upload_service + + monkeypatch.setattr(upload_service, "AUTH_ENABLED", True) + monkeypatch.setattr(upload_service, "_require_admin_session", lambda request: {"user_id": "admin"}) + monkeypatch.setattr(upload_service, "WORK_DIR", "/fake/work") + monkeypatch.setenv("WORK_DIR", "/fake/work") + + def fake_copy_collection_qdrant(**kwargs): + assert kwargs.get("source") == "src" + assert kwargs.get("target") == "dst" + return "dst" + + class _FakeQdrantClient: + def get_collection(self, collection_name: str): + if collection_name == "dst_graph": + return {"name": collection_name} + raise RuntimeError("not found") + + def __enter__(self): + return self + + def __exit__(self, *args): + return None + + monkeypatch.setitem( + sys.modules, + "scripts.collection_admin", + types.SimpleNamespace(copy_collection_qdrant=fake_copy_collection_qdrant), + ) + monkeypatch.setitem( + sys.modules, + "scripts.qdrant_client_manager", + 
types.SimpleNamespace(pooled_qdrant_client=lambda **kwargs: _FakeQdrantClient()), + ) + + client = TestClient(upload_service.app) + resp = client.post( + "/admin/staging/copy", + data={"collection": "src", "target": "dst", "overwrite": ""}, + follow_redirects=False, + ) + assert resp.status_code == 302 + loc = resp.headers.get("location") or "" + parsed = urlparse(loc) + qs = parse_qs(parsed.query) + assert qs.get("copied") == ["src"] + assert qs.get("new") == ["dst"] + assert qs.get("graph_copied") == ["1"] + + def test_watcher_collection_resolution_prefers_serving_state_when_staging_enabled(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): from scripts.watch_index_core import utils as watch_utils @@ -650,7 +706,9 @@ class _Proc: env = captured["env"] assert env["BASE_ONLY"] == "system" assert env["COLLECTION_NAME"] == "primary-coll" - assert "CTXCE_FORCE_COLLECTION_NAME" not in env + # Admin-spawned ingests should never enumerate `/work/*` in multi-repo mode; + # force exact collection/root handling even when no explicit overrides are provided. 
+ assert env.get("CTXCE_FORCE_COLLECTION_NAME") == "1" def test_promote_pending_env_without_pending_config(staging_workspace: dict): @@ -738,31 +796,6 @@ def test_resolve_codebase_root_fallbacks_to_parent(monkeypatch: pytest.MonkeyPat assert resolved_parent == codebase_root -def test_admin_abort_endpoint_falls_back_to_clear_when_abort_helper_missing(monkeypatch: pytest.MonkeyPatch): - from scripts import upload_service - - calls = {"clear": []} - - monkeypatch.setattr(upload_service, "_require_admin_session", lambda request: {"user_id": "admin"}) - monkeypatch.setattr(upload_service, "abort_staging_rebuild", None) - monkeypatch.setattr( - upload_service, - "clear_staging_collection", - lambda workspace_path, repo_name: calls["clear"].append((workspace_path, repo_name)), - ) - monkeypatch.setattr( - upload_service, - "resolve_collection_root", - lambda **kwargs: ("/fake/root", "repo1"), - ) - - client = TestClient(upload_service.app) - - resp = client.post("/admin/staging/abort", data={"collection": "coll1"}, follow_redirects=False) - assert resp.status_code == 302 - assert calls["clear"] == [("/fake/root", "repo1")] - - def test_watcher_collection_reuse_logical_repo(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): from scripts.watch_index_core import utils as watch_utils diff --git a/tests/test_subprocess_hybrid_smoke.py b/tests/test_subprocess_hybrid_smoke.py index 927a95e9..5badc1b3 100644 --- a/tests/test_subprocess_hybrid_smoke.py +++ b/tests/test_subprocess_hybrid_smoke.py @@ -55,7 +55,8 @@ def test_hybrid_cli_runs_basic(tmp_path, qdrant_container): env["EMBEDDING_MODEL"] = "BAAI/bge-base-en-v1.5" cmd = [ sys.executable, - "scripts/hybrid_search.py", + "-m", + "scripts.hybrid_search", "--query", "test", "--limit", diff --git a/tests/test_symbol_graph_tool.py b/tests/test_symbol_graph_tool.py index e148fc17..1d3861d0 100644 --- a/tests/test_symbol_graph_tool.py +++ b/tests/test_symbol_graph_tool.py @@ -2,21 +2,43 @@ @pytest.mark.asyncio -async def 
test_symbol_graph_under_uses_path_prefix_matchvalue(): - # Import internal helper to validate filter construction without needing a real Qdrant instance. - from qdrant_client import models as qmodels +async def test_symbol_graph_under_filters_results_by_recursive_scope(): + # Validate that under applies as recursive subtree filter (user-facing scope). from scripts.mcp_impl import symbol_graph as sg - captured = {} + class _Pt: + def __init__(self, pid, path): + self.id = pid + self.payload = { + "metadata": { + "repo": "repo", + "path": path, + "start_line": 1, + "end_line": 2, + "symbol": "f", + "symbol_path": "f", + "language": "python", + "calls": ["foo"], + } + } class FakeClient: + def __init__(self): + self.scroll_filters = [] + def scroll(self, *, collection_name, scroll_filter, limit, with_payload, with_vectors): - captured["collection_name"] = collection_name - captured["scroll_filter"] = scroll_filter - return ([], None) + self.scroll_filters.append(scroll_filter) + return ( + [ + _Pt("1", "/work/repo/scripts/a.py"), + _Pt("2", "/work/repo/tests/b.py"), + ], + None, + ) - await sg._query_array_field( # type: ignore[attr-defined] - client=FakeClient(), + client = FakeClient() + out = await sg._query_array_field( # type: ignore[attr-defined] + client=client, collection="codebase", field_key="metadata.calls", value="foo", @@ -25,15 +47,29 @@ def scroll(self, *, collection_name, scroll_filter, limit, with_payload, with_ve under=sg._norm_under("scripts"), # type: ignore[attr-defined] ) - flt = captured.get("scroll_filter") - assert isinstance(flt, qmodels.Filter) - must = list(flt.must or []) - keys = [getattr(c, "key", None) for c in must] - assert "metadata.path_prefix" in keys - - # Ensure it's an exact match (MatchValue), not substring (MatchText) - cond = next(c for c in must if getattr(c, "key", None) == "metadata.path_prefix") - assert isinstance(cond.match, qmodels.MatchValue) - assert cond.match.value == "/work/scripts" - + # Validate 
_query_array_field forwards language/value constraints to scroll_filter. + assert client.scroll_filters, "Expected at least one scroll() call" + first_filter = client.scroll_filters[0] + first_must = list(getattr(first_filter, "must", []) or []) + assert any( + getattr(cond, "key", None) == "metadata.calls" + and getattr(getattr(cond, "match", None), "any", None) == ["foo"] + for cond in first_must + ) + assert any( + getattr(cond, "key", None) == "metadata.language" + and getattr(getattr(cond, "match", None), "value", None) == "python" + for cond in first_must + ) + assert any( + any( + getattr(cond, "key", None) == "metadata.calls" + and getattr(getattr(cond, "match", None), "text", None) == "foo" + for cond in list(getattr(sf, "must", []) or []) + ) + for sf in client.scroll_filters + ), "Expected MatchText fallback filter for metadata.calls" + paths = {r.get("path") for r in out} + assert "/work/repo/scripts/a.py" in paths + assert "/work/repo/tests/b.py" not in paths diff --git a/tests/test_tier2_fallback.py b/tests/test_tier2_fallback.py index eb01bad3..c9595e19 100644 --- a/tests/test_tier2_fallback.py +++ b/tests/test_tier2_fallback.py @@ -39,12 +39,18 @@ def embed(self, texts): @pytest.mark.asyncio async def test_tier2_fallback_unconditional_with_language_filter(tmp_path, monkeypatch, qdrant_container): # Env for services - os.environ["QDRANT_URL"] = qdrant_container - os.environ["COLLECTION_NAME"] = f"test-{uuid.uuid4().hex[:8]}" - os.environ["USE_TREE_SITTER"] = "0" - os.environ["HYBRID_IN_PROCESS"] = "1" - os.environ["EMBEDDING_MODEL"] = "fake" - os.environ["REFRAG_GATE_FIRST"] = "1" # ensure Tier-1 gate-first path is active + monkeypatch.setenv("QDRANT_URL", qdrant_container) + monkeypatch.setenv("COLLECTION_NAME", f"test-{uuid.uuid4().hex[:8]}") + monkeypatch.setenv("USE_TREE_SITTER", "0") + monkeypatch.setenv("HYBRID_IN_PROCESS", "1") + monkeypatch.setenv("EMBEDDING_MODEL", "fake") + monkeypatch.setenv("REFRAG_GATE_FIRST", "1") # ensure Tier-1 
gate-first path is active + monkeypatch.setenv("REFRAG_RUNTIME", "llamacpp") + monkeypatch.setenv("CTX_MULTI_COLLECTION", "0") + monkeypatch.setenv("CTX_DOC_PASS", "0") + monkeypatch.setenv("CTX_DOC_TOP_FALLBACK", "0") + monkeypatch.setenv("HYBRID_EXPAND", "0") + monkeypatch.setenv("SEMANTIC_EXPANSION_ENABLED", "0") # Stub embeddings everywhere (FakeEmbedder produces 32-dim vectors) monkeypatch.setattr(ing, "TextEmbedding", lambda *a, **k: FakeEmbedder("fake")) diff --git a/tests/test_toon_encoder.py b/tests/test_toon_encoder.py index 326184e3..4e9c26ac 100644 --- a/tests/test_toon_encoder.py +++ b/tests/test_toon_encoder.py @@ -20,6 +20,10 @@ _encode_value, _is_uniform_array_of_objects, ) +from scripts.mcp_impl.toon import ( + _format_context_results_as_toon, + _format_results_as_toon, +) class TestFeatureFlags: @@ -256,13 +260,8 @@ class TestMCPIntegration: def test_should_use_toon_explicit_param(self, monkeypatch): """Test explicit output_format parameter takes precedence.""" - # Import the helpers from mcp_indexer_server monkeypatch.delenv("TOON_ENABLED", raising=False) - # We need to test the helper functions directly - # Since they're in mcp_indexer_server, we'll test the logic here - from scripts.toon_encoder import is_toon_enabled - # When TOON_ENABLED is not set, default is False assert is_toon_enabled() is False @@ -290,13 +289,17 @@ def test_format_results_as_toon_structure(self): assert "/src/main.py,10,20" in toon_output assert "/src/utils.py,5,15" in toon_output - def test_toon_replaces_results_array(self): - """Test that TOON formatting replaces JSON array with TOON string.""" + formatted = _format_results_as_toon(response.copy(), compact=True) + assert formatted["results"] == response["results"] + assert formatted["output_format"] == "toon" + assert formatted["text"] == toon_output + + def test_toon_encoder_returns_string(self): + """Test that the low-level TOON encoder returns a string.""" results = [ {"path": "/src/main.py", "start_line": 10, 
"end_line": 20}, ] - # When TOON is applied, results becomes a string toon_output = encode_search_results(results, compact=True) assert isinstance(toon_output, str) assert "results[1]{path,start_line,end_line}:" in toon_output @@ -416,19 +419,16 @@ class TestFormatContextResultsAsToon: def test_format_empty_results_adds_marker(self): """Test that empty results still get output_format marker.""" - from scripts.mcp_indexer_server import _format_context_results_as_toon - response = {"results": [], "total": 0} result = _format_context_results_as_toon(response.copy()) assert result["output_format"] == "toon" + assert result["results"] == [] # Empty array per spec: key[0]: - assert result["results"] == "results[0]:" + assert result["text"] == "results[0]:" def test_format_mixed_results(self): """Test formatting mixed code/memory results.""" - from scripts.mcp_indexer_server import _format_context_results_as_toon - response = { "results": [ {"source": "code", "path": "/src/api.py", "start_line": 1, "end_line": 10}, @@ -439,14 +439,13 @@ def test_format_mixed_results(self): result = _format_context_results_as_toon(response.copy()) assert result["output_format"] == "toon" - assert isinstance(result["results"], str) - assert "code[1]" in result["results"] - assert "memory[1]" in result["results"] + assert isinstance(result["results"], list) + assert isinstance(result["text"], str) + assert "code[1]" in result["text"] + assert "memory[1]" in result["text"] def test_format_preserves_other_fields(self): """Test that formatting preserves non-results fields.""" - from scripts.mcp_indexer_server import _format_context_results_as_toon - response = { "results": [{"source": "code", "path": "/a.py", "start_line": 1, "end_line": 5}], "total": 1, @@ -476,7 +475,7 @@ def test_encode_with_snippet_field(self): assert "def main()" in output def test_encode_with_information_field(self): - """Test that info_request's information field is included.""" + """Test that the information field is 
included.""" results = [ {"path": "/src/auth.py", "start_line": 1, "end_line": 50, "score": 0.9, "information": "Authentication handler at /src/auth.py:1-50", @@ -529,4 +528,3 @@ def test_context_results_with_extra_memory_fields(self): assert "id" in output assert "created_at" in output assert "tags" in output - diff --git a/tests/test_upload_client_ignore_cleanup.py b/tests/test_upload_client_ignore_cleanup.py new file mode 100644 index 00000000..4b4da482 --- /dev/null +++ b/tests/test_upload_client_ignore_cleanup.py @@ -0,0 +1,604 @@ +import importlib +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +@pytest.mark.parametrize( + "mod_name", + ["scripts.remote_upload_client", "scripts.standalone_upload_client"], +) +def test_remote_upload_config_does_not_generate_collection_name(monkeypatch, tmp_path, mod_name): + mod = importlib.import_module(mod_name) + workspace = tmp_path / "repo" + workspace.mkdir() + + monkeypatch.setattr(mod, "_compute_logical_repo_id", lambda _path: "fs:test") + + config = mod.get_remote_config(str(workspace)) + + assert config["collection_name"] is None + + +def _exercise_ignored_path_cleanup(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + ignored = workspace / "dev-workspace" / "nested.py" + ignored.parent.mkdir(parents=True, exist_ok=True) + ignored.write_text("print('dogfood')\n", encoding="utf-8") + + monkeypatch.setenv("DEV_REMOTE_MODE", "1") + monkeypatch.setattr(mod, "get_cached_file_hash", lambda path, repo_name=None: "abc123") + monkeypatch.setattr(mod, "set_cached_file_hash", lambda *a, **k: None) + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + changes = client.detect_file_changes([ignored]) + + assert ignored in changes["deleted"] + assert not changes["created"] + assert not changes["updated"] + assert not 
changes["moved"] + + +def test_remote_upload_client_marks_ignored_cached_paths_deleted(monkeypatch, tmp_path): + _exercise_ignored_path_cleanup("scripts.remote_upload_client", monkeypatch, tmp_path) + + +def test_standalone_upload_client_marks_ignored_cached_paths_deleted(monkeypatch, tmp_path): + _exercise_ignored_path_cleanup("scripts.standalone_upload_client", monkeypatch, tmp_path) + + +def _exercise_force_mode_cleanup(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", encoding="utf-8") + + stale_ignored = workspace / "dev-workspace" / "nested.py" + stale_ignored.parent.mkdir(parents=True, exist_ok=True) + stale_ignored.write_text("print('stale')\n", encoding="utf-8") + + monkeypatch.setenv("DEV_REMOTE_MODE", "1") + monkeypatch.setattr(mod, "get_all_cached_paths", lambda repo_name=None: [str(stale_ignored)]) + monkeypatch.setattr(mod, "get_cached_file_hash", lambda path, repo_name=None: "abc123") + monkeypatch.setattr(mod, "set_cached_file_hash", lambda *a, **k: None) + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + changes = client.build_force_changes([current]) + + assert current in changes["created"] + assert stale_ignored in changes["deleted"] + assert not changes["updated"] + assert not changes["moved"] + + +def test_remote_upload_client_force_mode_keeps_creates_and_deletes_ignored_cached_paths(monkeypatch, tmp_path): + _exercise_force_mode_cleanup("scripts.remote_upload_client", monkeypatch, tmp_path) + + +def test_standalone_upload_client_force_mode_keeps_creates_and_deletes_ignored_cached_paths(monkeypatch, tmp_path): + _exercise_force_mode_cleanup("scripts.standalone_upload_client", monkeypatch, tmp_path) + + +def 
_exercise_force_mode_excludes_ignored_current_files(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", encoding="utf-8") + + ignored_current = workspace / "dev-workspace" / "ignored.py" + ignored_current.parent.mkdir(parents=True, exist_ok=True) + ignored_current.write_text("print('ignored')\n", encoding="utf-8") + + monkeypatch.setenv("DEV_REMOTE_MODE", "1") + monkeypatch.setattr(mod, "get_all_cached_paths", lambda repo_name=None: []) + monkeypatch.setattr(mod, "get_cached_file_hash", lambda path, repo_name=None: None) + monkeypatch.setattr(mod, "set_cached_file_hash", lambda *a, **k: None) + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + changes = client.build_force_changes([current, ignored_current]) + + assert current in changes["created"] + assert ignored_current not in changes["created"] + assert ignored_current in changes["deleted"] + assert not changes["updated"] + assert not changes["moved"] + + +def test_remote_upload_client_force_mode_excludes_ignored_current_files(monkeypatch, tmp_path): + _exercise_force_mode_excludes_ignored_current_files( + "scripts.remote_upload_client", + monkeypatch, + tmp_path, + ) + + +def test_standalone_upload_client_force_mode_excludes_ignored_current_files(monkeypatch, tmp_path): + _exercise_force_mode_excludes_ignored_current_files( + "scripts.standalone_upload_client", + monkeypatch, + tmp_path, + ) + + +def _exercise_force_mode_dev_workspace_cleanup_without_cache(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", 
encoding="utf-8") + + mirrored = workspace / "dev-workspace" / "nested" / "stale.py" + mirrored.parent.mkdir(parents=True, exist_ok=True) + mirrored.write_text("print('stale')\n", encoding="utf-8") + + monkeypatch.setenv("DEV_REMOTE_MODE", "1") + monkeypatch.setattr(mod, "get_all_cached_paths", lambda repo_name=None: []) + monkeypatch.setattr(mod, "get_cached_file_hash", lambda path, repo_name=None: None) + monkeypatch.setattr(mod, "set_cached_file_hash", lambda *a, **k: None) + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + changes = client.build_force_changes([current]) + + assert current in changes["created"] + assert mirrored in changes["deleted"] + assert not changes["updated"] + assert not changes["moved"] + + +def test_remote_upload_client_force_mode_deletes_dev_workspace_without_cache(monkeypatch, tmp_path): + _exercise_force_mode_dev_workspace_cleanup_without_cache("scripts.remote_upload_client", monkeypatch, tmp_path) + + +def test_standalone_upload_client_force_mode_deletes_dev_workspace_without_cache(monkeypatch, tmp_path): + _exercise_force_mode_dev_workspace_cleanup_without_cache("scripts.standalone_upload_client", monkeypatch, tmp_path) + + +def _exercise_plan_skip_avoids_bundle_upload(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", encoding="utf-8") + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + monkeypatch.setattr( + client, + "_plan_delta_upload", + lambda changes: { + "needed_files": {"created": [], "updated": [], "moved": []}, + "operation_counts_preview": { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 1, + 
"skipped_hash_match": 1, + "failed": 0, + }, + "needed_size_bytes": 0, + }, + ) + monkeypatch.setattr(client, "create_delta_bundle", lambda *a, **k: (_ for _ in ()).throw(RuntimeError("should not bundle"))) + monkeypatch.setattr(client, "upload_bundle", lambda *a, **k: (_ for _ in ()).throw(RuntimeError("should not upload"))) + + assert client.process_changes_and_upload( + { + "created": [current], + "updated": [], + "deleted": [], + "moved": [], + "unchanged": [], + } + ) is True + assert client.last_upload_result["outcome"] == "skipped_by_plan" + + +def test_remote_upload_client_plan_skip_avoids_bundle_upload(monkeypatch, tmp_path): + _exercise_plan_skip_avoids_bundle_upload("scripts.remote_upload_client", monkeypatch, tmp_path) + + +def test_standalone_upload_client_plan_skip_avoids_bundle_upload(monkeypatch, tmp_path): + _exercise_plan_skip_avoids_bundle_upload("scripts.standalone_upload_client", monkeypatch, tmp_path) + + +def _exercise_detect_file_changes_does_not_persist_hash(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", encoding="utf-8") + + set_hash = MagicMock() + monkeypatch.setattr(mod, "get_cached_file_hash", lambda path, repo_name=None: "oldhash") + monkeypatch.setattr(mod, "set_cached_file_hash", set_hash) + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + changes = client.detect_file_changes([current]) + + assert current in changes["updated"] + set_hash.assert_not_called() + + +def test_remote_upload_client_detect_file_changes_does_not_persist_hash(monkeypatch, tmp_path): + _exercise_detect_file_changes_does_not_persist_hash( + "scripts.remote_upload_client", monkeypatch, tmp_path + ) + + +def 
test_standalone_upload_client_detect_file_changes_does_not_persist_hash(monkeypatch, tmp_path): + _exercise_detect_file_changes_does_not_persist_hash( + "scripts.standalone_upload_client", monkeypatch, tmp_path + ) + + +def _exercise_plan_skip_finalizes_hash(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", encoding="utf-8") + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + set_hash = MagicMock() + monkeypatch.setattr(mod, "set_cached_file_hash", set_hash) + monkeypatch.setattr( + client, + "_plan_delta_upload", + lambda changes: { + "needed_files": {"created": [], "updated": [], "moved": []}, + "operation_counts_preview": { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 1, + "skipped_hash_match": 1, + "failed": 0, + }, + "needed_size_bytes": 0, + }, + ) + monkeypatch.setattr(client, "create_delta_bundle", lambda *a, **k: (_ for _ in ()).throw(RuntimeError("should not bundle"))) + monkeypatch.setattr(client, "upload_bundle", lambda *a, **k: (_ for _ in ()).throw(RuntimeError("should not upload"))) + + assert client.process_changes_and_upload( + { + "created": [], + "updated": [current], + "deleted": [], + "moved": [], + "unchanged": [], + } + ) is True + assert client.last_upload_result["outcome"] == "skipped_by_plan" + set_hash.assert_called_once() + + +def test_remote_upload_client_plan_skip_finalizes_hash(monkeypatch, tmp_path): + _exercise_plan_skip_finalizes_hash( + "scripts.remote_upload_client", monkeypatch, tmp_path + ) + + +def test_standalone_upload_client_plan_skip_finalizes_hash(monkeypatch, tmp_path): + _exercise_plan_skip_finalizes_hash( + "scripts.standalone_upload_client", monkeypatch, tmp_path + ) + + +def 
test_standalone_upload_client_plan_payload_prefixes_previous_hash(monkeypatch, tmp_path): + mod = importlib.import_module("scripts.standalone_upload_client") + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + updated = workspace / "app.py" + updated.write_text("print('updated')\n", encoding="utf-8") + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + monkeypatch.setattr(mod, "get_cached_file_hash", lambda path, repo_name=None: "abc123") + + payload = client._build_plan_payload( + { + "created": [], + "updated": [updated], + "deleted": [updated], + "moved": [], + } + ) + + updated_op = next(op for op in payload["operations"] if op["operation"] == "updated") + deleted_op = next(op for op in payload["operations"] if op["operation"] == "deleted") + assert updated_op["previous_hash"] == "sha1:abc123" + assert deleted_op["previous_hash"] == "sha1:abc123" + + +def _exercise_delete_only_plan_uses_apply_ops(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + deleted = workspace / "old.py" + deleted.write_text("print('old')\n", encoding="utf-8") + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + removed_paths = [] + + monkeypatch.setattr( + client, + "_plan_delta_upload", + lambda changes: { + "needed_files": {"created": [], "updated": [], "moved": []}, + "operation_counts_preview": { + "created": 0, + "updated": 0, + "deleted": 1, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + }, + "needed_size_bytes": 0, + }, + ) + monkeypatch.setattr( + client, + "_build_plan_payload", + lambda changes: { + "manifest": {"bundle_id": "b1", "sequence_number": None}, + "operations": [{"operation": "deleted", 
"path": "old.py"}], + "file_hashes": {}, + }, + ) + monkeypatch.setattr(client, "create_delta_bundle", lambda *a, **k: (_ for _ in ()).throw(RuntimeError("should not bundle"))) + monkeypatch.setattr(client, "upload_bundle", lambda *a, **k: (_ for _ in ()).throw(RuntimeError("should not upload"))) + + class _Resp: + status_code = 200 + + @staticmethod + def raise_for_status(): + return None + + @staticmethod + def json(): + return { + "success": True, + "bundle_id": "b1", + "sequence_number": 3, + "processed_operations": {"deleted": 1, "created": 0, "updated": 0, "moved": 0, "skipped": 0, "skipped_hash_match": 0, "failed": 0}, + } + + monkeypatch.setattr(client.session, "post", lambda *a, **k: _Resp()) + monkeypatch.setattr(mod, "remove_cached_file", lambda path, repo_name=None: removed_paths.append((path, repo_name))) + + assert client.process_changes_and_upload( + { + "created": [], + "updated": [], + "deleted": [deleted], + "moved": [], + "unchanged": [], + } + ) is True + assert client.last_upload_result["outcome"] == "uploaded" + assert client.last_upload_result["processed_operations"]["deleted"] == 1 + assert removed_paths == [(str(deleted.resolve()), client.repo_name)] + + +def test_remote_upload_client_delete_only_plan_uses_apply_ops(monkeypatch, tmp_path): + _exercise_delete_only_plan_uses_apply_ops("scripts.remote_upload_client", monkeypatch, tmp_path) + + +def test_standalone_upload_client_delete_only_plan_uses_apply_ops(monkeypatch, tmp_path): + _exercise_delete_only_plan_uses_apply_ops("scripts.standalone_upload_client", monkeypatch, tmp_path) + + +def _exercise_async_upload_sets_queued_result(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + monkeypatch.setenv("CTXCE_REMOTE_UPLOAD_STATUS_WAIT_SECS", "0") + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", encoding="utf-8") + + client = 
mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + bundle_path = workspace / "bundle.tar.gz" + bundle_path.write_bytes(b"bundle") + monkeypatch.setattr(client, "_plan_delta_upload", lambda changes: None) + monkeypatch.setattr( + client, + "create_delta_bundle", + lambda changes: (str(bundle_path), {"bundle_id": "bundle-1", "total_size_bytes": 6}), + ) + monkeypatch.setattr( + client, + "upload_bundle", + lambda *a, **k: {"success": True, "sequence_number": 7, "processed_operations": None}, + ) + monkeypatch.setattr(mod, "flush_cached_file_hashes", lambda: None, raising=False) + + assert client.process_changes_and_upload( + { + "created": [current], + "updated": [], + "deleted": [], + "moved": [], + "unchanged": [], + } + ) is True + assert client.last_upload_result["outcome"] == "queued" + assert client.last_upload_result["sequence_number"] == 7 + + +def _exercise_async_upload_promotes_completed_result(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + current = workspace / "app.py" + current.write_text("print('current')\n", encoding="utf-8") + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + bundle_path = workspace / "bundle.tar.gz" + bundle_path.write_bytes(b"bundle") + monkeypatch.setattr(client, "_plan_delta_upload", lambda changes: None) + monkeypatch.setattr( + client, + "create_delta_bundle", + lambda changes: (str(bundle_path), {"bundle_id": "bundle-1", "total_size_bytes": 6}), + ) + monkeypatch.setattr( + client, + "upload_bundle", + lambda *a, **k: {"success": True, "sequence_number": 7, "processed_operations": None}, + ) + monkeypatch.setattr( + client, + "get_server_status", + lambda: { + "success": True, + "last_sequence": 7, + "server_info": 
{ + "last_bundle_id": "bundle-1", + "last_upload_status": "completed", + "last_processed_operations": {"updated": 1, "failed": 0}, + "last_processing_time_ms": 12, + }, + }, + ) + monkeypatch.setattr(mod, "flush_cached_file_hashes", lambda: None, raising=False) + + assert client.process_changes_and_upload( + { + "created": [current], + "updated": [], + "deleted": [], + "moved": [], + "unchanged": [], + } + ) is True + assert client.last_upload_result["outcome"] == "uploaded_async" + assert client.last_upload_result["processed_operations"] == {"updated": 1, "failed": 0} + + +def test_remote_upload_client_async_upload_sets_queued_result(monkeypatch, tmp_path): + _exercise_async_upload_sets_queued_result("scripts.remote_upload_client", monkeypatch, tmp_path) + + +def test_standalone_upload_client_async_upload_sets_queued_result(monkeypatch, tmp_path): + _exercise_async_upload_sets_queued_result("scripts.standalone_upload_client", monkeypatch, tmp_path) + + +def test_remote_upload_client_async_upload_promotes_completed_result(monkeypatch, tmp_path): + _exercise_async_upload_promotes_completed_result("scripts.remote_upload_client", monkeypatch, tmp_path) + + +def test_standalone_upload_client_async_upload_promotes_completed_result(monkeypatch, tmp_path): + _exercise_async_upload_promotes_completed_result("scripts.standalone_upload_client", monkeypatch, tmp_path) + + +def _exercise_watchable_path_excludes_ignored_updates(mod_name: str, monkeypatch, tmp_path: Path) -> None: + mod = importlib.import_module(mod_name) + + workspace = tmp_path / "repo" + workspace.mkdir(parents=True, exist_ok=True) + source = workspace / "src" / "tracked.py" + source.parent.mkdir(parents=True, exist_ok=True) + source.write_text("print('tracked')\n", encoding="utf-8") + + mirrored = workspace / "dev-workspace" / "nested" / "ignored.py" + mirrored.parent.mkdir(parents=True, exist_ok=True) + mirrored.write_text("print('ignored')\n", encoding="utf-8") + + monkeypatch.setenv("DEV_REMOTE_MODE", 
"1") + + client = mod.RemoteUploadClient( + upload_endpoint="http://localhost:8004", + workspace_path=str(workspace), + collection_name="test-coll", + ) + + assert client._is_watchable_path(source) is True + assert client._is_watchable_path(mirrored) is False + + +def test_remote_upload_client_watchable_path_excludes_ignored_updates(monkeypatch, tmp_path): + _exercise_watchable_path_excludes_ignored_updates( + "scripts.remote_upload_client", + monkeypatch, + tmp_path, + ) + + +def test_standalone_upload_client_watchable_path_excludes_ignored_updates(monkeypatch, tmp_path): + _exercise_watchable_path_excludes_ignored_updates( + "scripts.standalone_upload_client", + monkeypatch, + tmp_path, + ) diff --git a/tests/test_upload_service_path_traversal.py b/tests/test_upload_service_path_traversal.py index 0d01478f..f0e24b69 100644 --- a/tests/test_upload_service_path_traversal.py +++ b/tests/test_upload_service_path_traversal.py @@ -1,11 +1,19 @@ import io import json +import os import tarfile from pathlib import Path import pytest +@pytest.fixture(autouse=True) +def _disable_ambient_staging(monkeypatch): + import scripts.upload_delta_bundle as us + + monkeypatch.setattr(us, "is_staging_enabled", lambda: False) + + def _write_bundle(tmp_path: Path, operations: list[dict]) -> Path: bundle_path = tmp_path / "bundle.tar.gz" payload = json.dumps({"operations": operations}).encode("utf-8") @@ -52,6 +60,59 @@ def _write_bundle_with_created_file(tmp_path: Path, rel_path: str, content: byte return bundle_path +def _write_bundle_with_hash_metadata( + tmp_path: Path, + *, + operations: list[dict], + file_hashes: dict[str, str] | None = None, + created_files: dict[str, bytes] | None = None, + updated_files: dict[str, bytes] | None = None, +) -> Path: + bundle_path = tmp_path / "bundle-hashes.tar.gz" + payload = json.dumps({"operations": operations}).encode("utf-8") + hashes_payload = json.dumps({"file_hashes": file_hashes or {}}).encode("utf-8") + + with tarfile.open(bundle_path, 
"w:gz") as tar: + info = tarfile.TarInfo(name="metadata/operations.json") + info.size = len(payload) + tar.addfile(info, io.BytesIO(payload)) + + hashes_info = tarfile.TarInfo(name="metadata/hashes.json") + hashes_info.size = len(hashes_payload) + tar.addfile(hashes_info, io.BytesIO(hashes_payload)) + + for rel_path, content in (created_files or {}).items(): + file_info = tarfile.TarInfo(name=f"files/created/{rel_path}") + file_info.size = len(content) + tar.addfile(file_info, io.BytesIO(content)) + + for rel_path, content in (updated_files or {}).items(): + file_info = tarfile.TarInfo(name=f"files/updated/{rel_path}") + file_info.size = len(content) + tar.addfile(file_info, io.BytesIO(content)) + + return bundle_path + + +def _write_repo_cache(work_dir: Path, slug: str, rel_path: str, file_hash: str) -> None: + target = (work_dir / slug / rel_path).resolve() + cache_path = work_dir / ".codebase" / "repos" / slug / "cache.json" + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.write_text( + json.dumps( + { + "file_hashes": { + str(target): { + "hash": file_hash, + } + } + }, + indent=2, + ), + encoding="utf-8", + ) + + def test_process_delta_bundle_rejects_traversal_created(tmp_path, monkeypatch): import scripts.upload_delta_bundle as us @@ -197,3 +258,484 @@ def test_process_delta_bundle_rejects_traversal_moved_source(tmp_path, monkeypat bundle_path=bundle, manifest={"bundle_id": "b1"}, ) + + +def test_process_delta_bundle_skips_created_write_when_server_hash_matches(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + rel_path = "src/file.txt" + content = b"same-content" + file_hash = "sha1:efb5d7d4d38013264f2c00fceeb401f8c8d77d9f" + + target = work_dir / slug / rel_path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(content) + 
os.utime(target, ns=(1_000_000_000, 1_000_000_000)) + before_mtime_ns = target.stat().st_mtime_ns + _write_repo_cache(work_dir, slug, rel_path, file_hash) + + bundle = _write_bundle_with_hash_metadata( + tmp_path, + operations=[ + { + "operation": "created", + "path": rel_path, + "content_hash": file_hash, + } + ], + file_hashes={rel_path: file_hash}, + created_files={rel_path: content}, + ) + + counts = us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-skip-created"}, + ) + + assert counts.get("created") == 0 + assert counts.get("skipped") == 1 + assert counts.get("skipped_hash_match") == 1 + assert target.read_bytes() == content + assert target.stat().st_mtime_ns == before_mtime_ns + + +def test_process_delta_bundle_uses_hashes_metadata_for_updated_skip(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + rel_path = "src/keep.txt" + content = b"existing-content" + file_hash = "sha1:2910e29d6f6d3d2f01f8cc52ec386a4936ca9d2f" + + target = work_dir / slug / rel_path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(content) + os.utime(target, ns=(2_000_000_000, 2_000_000_000)) + before_mtime_ns = target.stat().st_mtime_ns + _write_repo_cache(work_dir, slug, rel_path, file_hash) + + bundle = _write_bundle_with_hash_metadata( + tmp_path, + operations=[ + { + "operation": "updated", + "path": rel_path, + } + ], + file_hashes={rel_path: file_hash}, + updated_files={rel_path: content}, + ) + + counts = us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-skip-updated"}, + ) + + assert counts.get("updated") == 0 + assert counts.get("skipped") == 1 + assert counts.get("skipped_hash_match") == 1 + assert target.read_bytes() == content + assert 
target.stat().st_mtime_ns == before_mtime_ns + + +def test_normalize_hash_value_strips_algorithm_prefixes(): + import scripts.upload_delta_bundle as us + + assert us._normalize_hash_value("sha1:ABCDEF") == "abcdef" + assert us._normalize_hash_value("md5:ABCDEF") == "abcdef" + assert us._normalize_hash_value("sha256:ABCDEF") == "abcdef" + assert us._normalize_hash_value("ABCDEF") == "abcdef" + + +def test_process_delta_bundle_uses_first_marker_match_for_created_members(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + rel_path = "nested/files/created/path.txt" + content = b"marker-safe" + bundle = _write_bundle_with_created_file(tmp_path, rel_path, content) + + counts = us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-created-marker"}, + ) + + assert counts.get("created") == 1 + assert (work_dir / slug / rel_path).read_bytes() == content + + +def test_process_delta_bundle_deleted_prunes_empty_parent_dirs(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + rel_path = "dev-workspace/nested/stale.py" + target = work_dir / slug / rel_path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text("stale\n", encoding="utf-8") + + bundle = _write_bundle( + tmp_path, + [{"operation": "deleted", "path": rel_path}], + ) + + counts = us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-delete-prune"}, + ) + + assert counts.get("deleted") == 1 + assert not target.exists() + assert not (work_dir / slug / "dev-workspace" / "nested").exists() + assert not (work_dir / slug / 
"dev-workspace").exists() + assert (work_dir / slug).exists() + + +def test_process_delta_bundle_moved_prunes_empty_source_parent_dirs(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + src = work_dir / slug / "dev-workspace" / "nested" / "from.py" + dest_rel_path = "dest/to.py" + src.parent.mkdir(parents=True, exist_ok=True) + src.write_text("payload\n", encoding="utf-8") + + bundle = _write_bundle( + tmp_path, + [{"operation": "moved", "path": dest_rel_path, "source_path": "dev-workspace/nested/from.py"}], + ) + + counts = us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-move-prune"}, + ) + + assert counts.get("moved") == 1 + assert not src.exists() + assert (work_dir / slug / dest_rel_path).read_text(encoding="utf-8") == "payload\n" + assert not (work_dir / slug / "dev-workspace" / "nested").exists() + assert not (work_dir / slug / "dev-workspace").exists() + assert (work_dir / slug).exists() + + +def test_process_delta_bundle_does_not_sweep_stranded_empty_dirs_without_file_ops(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + slug = "repo-0123456789abcdef" + stranded = work_dir / slug / "dev-workspace" / "nested" / "empty" + stranded.mkdir(parents=True, exist_ok=True) + + bundle = _write_bundle(tmp_path, []) + + counts = us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-sweep-empty"}, + ) + + assert counts == { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + } + assert stranded.exists() + assert (work_dir / slug / "dev-workspace").exists() + 
assert (work_dir / slug).exists() + + +def test_process_delta_bundle_skips_broad_empty_dir_sweep_when_disabled(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + monkeypatch.setenv("CTXCE_UPLOAD_EMPTY_DIR_SWEEP", "0") + + slug = "repo-0123456789abcdef" + stranded = work_dir / slug / "dev-workspace" / "nested" / "empty" + stranded.mkdir(parents=True, exist_ok=True) + + bundle = _write_bundle(tmp_path, []) + + us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-sweep-disabled"}, + ) + + assert stranded.exists() + + +def test_process_delta_bundle_skips_broad_empty_dir_sweep_when_recent(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + slug = "repo-0123456789abcdef" + stranded = work_dir / slug / "dev-workspace" / "nested" / "empty" + stranded.mkdir(parents=True, exist_ok=True) + + bundle = _write_bundle(tmp_path, []) + + us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-sweep-recent"}, + ) + + assert stranded.exists() + + +def test_process_delta_bundle_preserves_protected_top_level_dirs_when_empty(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + monkeypatch.setenv("CTXCE_UPLOAD_EMPTY_DIR_SWEEP", "1") + monkeypatch.setenv("CTXCE_UPLOAD_EMPTY_DIR_SWEEP_INTERVAL_SECONDS", "0") + + slug = "repo-0123456789abcdef" + protected = work_dir / slug / ".remote-git" + protected.mkdir(parents=True, exist_ok=True) + + bundle = _write_bundle(tmp_path, []) + + us.process_delta_bundle( + workspace_path=f"/work/{slug}", + 
bundle_path=bundle, + manifest={"bundle_id": "b-protected-empty"}, + ) + + assert protected.exists() + + +def test_process_delta_bundle_preserves_nested_dirs_under_protected_top_level(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + monkeypatch.setenv("CTXCE_UPLOAD_EMPTY_DIR_SWEEP", "1") + monkeypatch.setenv("CTXCE_UPLOAD_EMPTY_DIR_SWEEP_INTERVAL_SECONDS", "0") + + slug = "repo-0123456789abcdef" + protected_nested = work_dir / slug / ".codebase" / "repos" / "empty" + protected_nested.mkdir(parents=True, exist_ok=True) + + bundle = _write_bundle(tmp_path, []) + + us.process_delta_bundle( + workspace_path=f"/work/{slug}", + bundle_path=bundle, + manifest={"bundle_id": "b-protected-nested-empty"}, + ) + + assert protected_nested.exists() + + +def test_plan_delta_upload_skips_matching_created_files(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + rel_path = "src/file.txt" + content = b"same-content" + file_hash = "sha1:efb5d7d4d38013264f2c00fceeb401f8c8d77d9f" + + target = work_dir / slug / rel_path + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(content) + _write_repo_cache(work_dir, slug, rel_path, file_hash) + + plan = us.plan_delta_upload( + workspace_path=f"/work/{slug}", + operations=[ + { + "operation": "created", + "path": rel_path, + "content_hash": file_hash, + "size_bytes": len(content), + } + ], + file_hashes={rel_path: file_hash}, + ) + + assert plan["needed_files"]["created"] == [] + assert plan["operation_counts_preview"]["skipped_hash_match"] == 1 + assert plan["needed_size_bytes"] == 0 + + +def test_plan_delta_upload_marks_updated_file_needed_when_hash_missing(tmp_path, monkeypatch): + import 
scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + rel_path = "src/keep.txt" + file_hash = "sha1:2910e29d6f6d3d2f01f8cc52ec386a4936ca9d2f" + + plan = us.plan_delta_upload( + workspace_path=f"/work/{slug}", + operations=[ + { + "operation": "updated", + "path": rel_path, + "content_hash": file_hash, + "size_bytes": 17, + } + ], + file_hashes={rel_path: file_hash}, + ) + + assert plan["needed_files"]["updated"] == [rel_path] + assert plan["operation_counts_preview"]["updated"] == 1 + assert plan["needed_size_bytes"] == 17 + + +def test_plan_delta_upload_skips_move_content_when_source_exists_on_server(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + source_rel = "src/old.py" + dest_rel = "src/new.py" + source = work_dir / slug / source_rel + source.parent.mkdir(parents=True, exist_ok=True) + source.write_text("print('move')\n", encoding="utf-8") + + plan = us.plan_delta_upload( + workspace_path=f"/work/{slug}", + operations=[ + { + "operation": "moved", + "path": dest_rel, + "source_path": source_rel, + "content_hash": "sha1:abc123", + "size_bytes": 12, + } + ], + file_hashes={dest_rel: "sha1:abc123"}, + ) + + assert plan["needed_files"]["moved"] == [] + assert plan["operation_counts_preview"]["moved"] == 1 + assert plan["needed_size_bytes"] == 0 + + +def test_plan_delta_upload_marks_move_needed_when_source_path_is_invalid(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + dest_rel = "src/new.py" + + plan = us.plan_delta_upload( + 
workspace_path=f"/work/{slug}", + operations=[ + { + "operation": "moved", + "path": dest_rel, + "source_path": "../escape.py", + "content_hash": "sha1:abc123", + "size_bytes": 12, + } + ], + file_hashes={dest_rel: "sha1:abc123"}, + ) + + assert plan["needed_files"]["moved"] == [dest_rel] + assert plan["operation_counts_preview"]["moved"] == 1 + assert plan["needed_size_bytes"] == 12 + + +def test_apply_delta_operations_moves_file_without_bundle(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + + slug = "repo-0123456789abcdef" + source_rel = "src/old.py" + dest_rel = "src/new.py" + source = work_dir / slug / source_rel + source.parent.mkdir(parents=True, exist_ok=True) + source.write_text("print('move')\n", encoding="utf-8") + + counts = us.apply_delta_operations( + workspace_path=f"/work/{slug}", + operations=[ + { + "operation": "moved", + "path": dest_rel, + "source_path": source_rel, + "content_hash": "sha1:abc123", + } + ], + file_hashes={dest_rel: "sha1:abc123"}, + ) + + assert counts["moved"] == 1 + assert not source.exists() + assert (work_dir / slug / dest_rel).exists() + + +def test_apply_delta_operations_raises_clear_error_when_no_replica_roots(tmp_path, monkeypatch): + import scripts.upload_delta_bundle as us + + work_dir = tmp_path / "work" + work_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setattr(us, "WORK_DIR", str(work_dir)) + monkeypatch.setattr(us, "_resolve_replica_roots", lambda workspace_path: {}) + + with pytest.raises(ValueError, match="No replica roots available"): + us.apply_delta_operations( + workspace_path="/work/repo", + operations=[], + file_hashes={}, + ) diff --git a/tests/test_upload_service_status.py b/tests/test_upload_service_status.py new file mode 100644 index 00000000..51414ca2 --- /dev/null +++ b/tests/test_upload_service_status.py @@ -0,0 +1,293 @@ +import asyncio 
+import importlib +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + + +def _disable_auth(srv, monkeypatch) -> None: + monkeypatch.setattr(srv, "AUTH_ENABLED", False) + + +@pytest.mark.unit +def test_delta_status_exposes_last_processed_operations(monkeypatch): + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + _disable_auth(srv, monkeypatch) + + monkeypatch.setattr(srv, "get_collection_name", lambda _repo=None: "test-coll") + monkeypatch.setattr(srv, "_extract_repo_name_from_path", lambda _path: "repo") + + key = srv.get_workspace_key("/work/repo") + srv._sequence_tracker[key] = 7 + srv._upload_result_tracker[key] = { + "workspace_path": "/work/repo", + "bundle_id": "bundle-123", + "sequence_number": 7, + "processed_operations": { + "created": 1, + "updated": 2, + "deleted": 0, + "moved": 0, + "skipped": 5, + "skipped_hash_match": 4, + "failed": 0, + }, + "processing_time_ms": 321, + "status": "completed", + "completed_at": "2026-03-07T15:40:46.623000", + } + + client = TestClient(srv.app) + resp = client.get("/api/v1/delta/status", params={"workspace_path": "/work/repo"}) + assert resp.status_code == 200 + body = resp.json() + assert body["last_sequence"] == 7 + assert body["last_upload"] == "2026-03-07T15:40:46.623000" + assert body["status"] == "ready" + assert body["server_info"]["last_bundle_id"] == "bundle-123" + assert body["server_info"]["last_processing_time_ms"] == 321 + assert body["server_info"]["last_processed_operations"]["skipped_hash_match"] == 4 + assert body["server_info"]["last_upload_status"] == "completed" + assert body["server_info"]["last_error"] is None + + +@pytest.mark.unit +def test_process_bundle_background_tracks_completed_operations(monkeypatch, tmp_path: Path): + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + _disable_auth(srv, monkeypatch) + + bundle_path = tmp_path / "bundle.tar.gz" + 
bundle_path.write_bytes(b"placeholder") + + monkeypatch.setattr( + srv, + "process_delta_bundle", + lambda workspace_path, bundle_path, manifest: { + "created": 0, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 10, + "skipped_hash_match": 10, + "failed": 0, + }, + ) + monkeypatch.setattr(srv, "log_activity", lambda *a, **k: None) + + asyncio.run( + srv._process_bundle_background( + workspace_path="/work/repo", + bundle_path=bundle_path, + manifest={"bundle_id": "bundle-xyz"}, + sequence_number=3, + bundle_id="bundle-xyz", + ) + ) + + key = srv.get_workspace_key("/work/repo") + tracked = srv._upload_result_tracker[key] + assert tracked["status"] == "completed" + assert tracked["sequence_number"] == 3 + assert tracked["processed_operations"]["skipped_hash_match"] == 10 + assert tracked["processing_time_ms"] is not None + assert not bundle_path.exists() + + +@pytest.mark.unit +def test_delta_status_reports_processing_while_upload_in_progress(monkeypatch): + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + _disable_auth(srv, monkeypatch) + + monkeypatch.setattr(srv, "get_collection_name", lambda _repo=None: "test-coll") + monkeypatch.setattr(srv, "_extract_repo_name_from_path", lambda _path: "repo") + + key = srv.get_workspace_key("/work/repo") + srv._upload_result_tracker[key] = { + "workspace_path": "/work/repo", + "bundle_id": "bundle-123", + "sequence_number": 8, + "processed_operations": None, + "processing_time_ms": None, + "status": "processing", + "completed_at": None, + } + + client = TestClient(srv.app) + resp = client.get("/api/v1/delta/status", params={"workspace_path": "/work/repo"}) + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "processing" + assert body["server_info"]["last_upload_status"] == "processing" + + +@pytest.mark.unit +def test_delta_plan_endpoint_returns_needed_files(monkeypatch): + srv = importlib.import_module("scripts.upload_service") + srv = 
importlib.reload(srv) + _disable_auth(srv, monkeypatch) + + monkeypatch.setattr( + srv, + "plan_delta_upload", + lambda workspace_path, operations, file_hashes=None: { + "needed_files": {"created": ["src/app.py"], "updated": [], "moved": []}, + "operation_counts_preview": { + "created": 1, + "updated": 0, + "deleted": 0, + "moved": 0, + "skipped": 2, + "skipped_hash_match": 2, + "failed": 0, + }, + "needed_size_bytes": 123, + "replica_targets": ["repo-0123456789abcdef"], + }, + ) + + client = TestClient(srv.app) + resp = client.post( + "/api/v1/delta/plan", + json={ + "workspace_path": "/work/repo", + "manifest": {"bundle_id": "b1"}, + "operations": [{"operation": "created", "path": "src/app.py"}], + "file_hashes": {"src/app.py": "sha1:abc"}, + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["needed_files"]["created"] == ["src/app.py"] + assert body["operation_counts_preview"]["skipped_hash_match"] == 2 + assert body["needed_size_bytes"] == 123 + + +@pytest.mark.unit +def test_upload_managed_resolution_ignores_client_collection(monkeypatch): + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + _disable_auth(srv, monkeypatch) + + monkeypatch.setattr(srv, "logical_repo_reuse_enabled", lambda: False) + monkeypatch.setattr(srv, "_extract_repo_name_from_path", lambda path: Path(path).name) + monkeypatch.setattr(srv, "get_collection_name", lambda repo=None: f"server-{repo}") + + collection, repo = srv._resolve_collection_for_request( + workspace_path="/work/repo", + client_collection_name="repo-071ca222", + logical_repo_id="fs:123", + source_path="/host/Context-Engine", + ) + + assert repo == "Context-Engine" + assert collection == "server-Context-Engine" + + +@pytest.mark.unit +def test_delta_plan_endpoint_uses_safe_defaults_for_sparse_plan(monkeypatch): + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + _disable_auth(srv, monkeypatch) 
+ + monkeypatch.setattr( + srv, + "plan_delta_upload", + lambda workspace_path, operations, file_hashes=None: {}, + ) + + client = TestClient(srv.app) + resp = client.post( + "/api/v1/delta/plan", + json={ + "workspace_path": "/work/repo", + "manifest": {"bundle_id": "b1"}, + "operations": [{"operation": "created", "path": "src/app.py"}], + "file_hashes": {"src/app.py": "sha1:abc"}, + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["needed_files"] == {"created": [], "updated": [], "moved": []} + assert body["operation_counts_preview"]["failed"] == 0 + assert body["needed_size_bytes"] == 0 + assert body["replica_targets"] == [] + + +@pytest.mark.unit +def test_apply_ops_endpoint_returns_processed_operations(monkeypatch): + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + _disable_auth(srv, monkeypatch) + + monkeypatch.setattr( + srv, + "apply_delta_operations", + lambda workspace_path, operations, file_hashes=None: { + "created": 0, + "updated": 0, + "deleted": 1, + "moved": 0, + "skipped": 0, + "skipped_hash_match": 0, + "failed": 0, + }, + ) + + client = TestClient(srv.app) + resp = client.post( + "/api/v1/delta/apply_ops", + json={ + "workspace_path": "/work/repo", + "manifest": {"bundle_id": "b2"}, + "operations": [{"operation": "deleted", "path": "src/old.py"}], + "file_hashes": {}, + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["processed_operations"]["deleted"] == 1 + assert body["processing_time_ms"] is not None + + +@pytest.mark.unit +def test_apply_ops_endpoint_marks_tracker_error_state_on_failure(monkeypatch): + srv = importlib.import_module("scripts.upload_service") + srv = importlib.reload(srv) + _disable_auth(srv, monkeypatch) + + monkeypatch.setattr( + srv, + "apply_delta_operations", + lambda workspace_path, operations, file_hashes=None: (_ for _ in ()).throw( + RuntimeError("boom") + ), + ) 
+ + client = TestClient(srv.app) + resp = client.post( + "/api/v1/delta/apply_ops", + json={ + "workspace_path": "/work/repo", + "manifest": {"bundle_id": "b3"}, + "operations": [{"operation": "deleted", "path": "src/old.py"}], + "file_hashes": {}, + }, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is False + assert body["error"]["code"] == "APPLY_OPS_ERROR" + + key = srv.get_workspace_key("/work/repo") + tracked = srv._upload_result_tracker[key] + assert tracked["status"] == "error" + assert tracked["error"] == "boom" + assert tracked["message"] == "boom" + assert tracked["completed_at"] is not None diff --git a/tests/test_watch_consistency.py b/tests/test_watch_consistency.py new file mode 100644 index 00000000..bf51c79e --- /dev/null +++ b/tests/test_watch_consistency.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +import importlib +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def capture_list_workspaces(): + captured = {} + + def fake_list_workspaces(search_root=None, use_qdrant_fallback=True): + captured["search_root"] = search_root + captured["use_qdrant_fallback"] = use_qdrant_fallback + return [] + + return captured, fake_list_workspaces + + +def test_run_consistency_audit_scans_from_watcher_root( + monkeypatch, tmp_path, capture_list_workspaces +): + mod = importlib.import_module("scripts.watch_index_core.consistency") + captured, fake_list_workspaces = capture_list_workspaces + + monkeypatch.setattr(mod, "list_workspaces", fake_list_workspaces) + monkeypatch.setattr(mod, "_consistency_audit_enabled", lambda: True) + + mod.run_consistency_audit(MagicMock(), tmp_path) + + assert "search_root" in captured and "use_qdrant_fallback" in captured + assert Path(captured["search_root"]).resolve() == Path(tmp_path).resolve() + assert captured["use_qdrant_fallback"] is False + + +def 
test_run_empty_dir_sweep_maintenance_scans_from_watcher_root( + monkeypatch, tmp_path, capture_list_workspaces +): + mod = importlib.import_module("scripts.watch_index_core.consistency") + captured, fake_list_workspaces = capture_list_workspaces + + monkeypatch.setattr(mod, "list_workspaces", fake_list_workspaces) + monkeypatch.setattr(mod, "_empty_dir_sweep_enabled", lambda: True) + + mod.run_empty_dir_sweep_maintenance(tmp_path) + + assert "search_root" in captured + assert Path(captured["search_root"]).resolve() == Path(tmp_path).resolve() + assert captured.get("use_qdrant_fallback") is False + + +def test_consistency_audit_skips_repairs_when_scan_is_truncated(monkeypatch, tmp_path): + mod = importlib.import_module("scripts.watch_index_core.consistency") + + workspace_root = tmp_path / "repo" + workspace_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + mod, + "list_workspaces", + lambda *a, **k: [{"workspace_path": str(workspace_root)}], + ) + monkeypatch.setattr(mod, "_consistency_audit_enabled", lambda: True) + monkeypatch.setattr(mod, "_should_run_consistency_audit", lambda *a, **k: True) + monkeypatch.setattr( + mod, + "get_collection_state_snapshot", + lambda *a, **k: {"active_collection": "coll"}, + ) + monkeypatch.setattr(mod, "_extract_repo_name_from_path", lambda *_: "repo") + monkeypatch.setattr(mod, "_load_cached_hashes", lambda *a, **k: {}) + monkeypatch.setattr( + mod, + "_scan_indexable_fs_paths", + lambda *a, **k: ({str(workspace_root / "a.py")}, True), + ) + monkeypatch.setattr( + mod, + "_load_indexed_paths_for_collection", + lambda *a, **k: ({str(workspace_root / "ghost.py")}, False), + ) + monkeypatch.setattr(mod.idx, "_Excluder", lambda *_: MagicMock()) + + enqueue_mock = MagicMock(return_value=(0, 0)) + record_mock = MagicMock() + monkeypatch.setattr(mod, "_enqueue_consistency_repairs", enqueue_mock) + monkeypatch.setattr(mod, "_record_consistency_audit", record_mock) + + mod.run_consistency_audit(MagicMock(), tmp_path) + + 
enqueue_mock.assert_not_called() + record_mock.assert_called_once() + summary = record_mock.call_args.args[2] + assert summary["fs_scan_truncated"] is True + assert summary["qdrant_scan_truncated"] is False + assert summary["repair_skipped_due_to_truncation"] is True + assert summary["stale_in_qdrant_count"] == 0 + assert summary["missing_in_qdrant_count"] == 0 diff --git a/tests/test_watch_index_cache.py b/tests/test_watch_index_cache.py index c5065af1..dfb00662 100644 --- a/tests/test_watch_index_cache.py +++ b/tests/test_watch_index_cache.py @@ -150,3 +150,558 @@ def test_processor_delete_clears_cache_even_without_client(monkeypatch, tmp_path ) remove_mock.assert_called_once_with(str(missing), "repo") + + +def test_run_indexing_strategy_reuses_preloaded_file_state(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod.idx, "ensure_collection_and_indexes_once", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_read_text_and_sha1", lambda _p: ("print('x')\n", "abc123")) + monkeypatch.setattr(proc_mod, "get_cached_file_hash", lambda *a, **k: None) + monkeypatch.setattr(proc_mod.idx, "detect_language", lambda _p: "python") + monkeypatch.setattr(proc_mod.idx, "should_use_smart_reindexing", lambda *a, **k: (False, "changed")) + + captured = {} + + def fake_index_single_file(*args, **kwargs): + captured.update(kwargs) + return True + + monkeypatch.setattr(proc_mod.idx, "index_single_file", fake_index_single_file) + + ok = proc_mod._run_indexing_strategy( + path, + client=MagicMock(), + model=MagicMock(), + collection="coll", + vector_name="vec", + model_dim=1, + repo_name="repo", + ) + + assert ok is True + assert captured["preloaded_text"] == "print('x')\n" + assert captured["preloaded_file_hash"] == "abc123" + assert captured["preloaded_language"] == "python" + + +def 
test_run_indexing_strategy_skips_ensure_for_cached_hash_match(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + ensure_mock = MagicMock() + monkeypatch.setattr(proc_mod.idx, "ensure_collection_and_indexes_once", ensure_mock) + monkeypatch.setattr(proc_mod, "_read_text_and_sha1", lambda _p: ("print('x')\n", "abc123")) + monkeypatch.setattr(proc_mod, "get_cached_file_hash", lambda *a, **k: "abc123") + monkeypatch.setattr(proc_mod.idx, "detect_language", lambda _p: "python") + + with pytest.raises(proc_mod._SkipUnchanged): + proc_mod._run_indexing_strategy( + path, + client=MagicMock(), + model=MagicMock(), + collection="coll", + vector_name="vec", + model_dim=1, + repo_name="repo", + ) + + ensure_mock.assert_not_called() + + +def test_run_indexing_strategy_force_upsert_bypasses_cached_hash_match( + monkeypatch, tmp_path +): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + ensure_mock = MagicMock() + monkeypatch.setattr(proc_mod.idx, "ensure_collection_and_indexes_once", ensure_mock) + monkeypatch.setattr(proc_mod, "_read_text_and_sha1", lambda _p: ("print('x')\n", "abc123")) + monkeypatch.setattr(proc_mod, "get_cached_file_hash", lambda *a, **k: "abc123") + monkeypatch.setattr(proc_mod.idx, "detect_language", lambda _p: "python") + monkeypatch.setattr(proc_mod.idx, "should_use_smart_reindexing", lambda *a, **k: (False, "changed")) + + index_mock = MagicMock(return_value=True) + monkeypatch.setattr(proc_mod.idx, "index_single_file", index_mock) + + ok = proc_mod._run_indexing_strategy( + path, + client=MagicMock(), + model=MagicMock(), + collection="coll", + vector_name="vec", + model_dim=1, + repo_name="repo", + force_upsert=True, + ) + + assert ok is True + ensure_mock.assert_called_once() + 
index_mock.assert_called_once() + + +def test_run_indexing_strategy_skips_smart_path_for_markdown(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "notes.md" + path.write_text("# notes\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod.idx, "ensure_collection_and_indexes_once", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_read_text_and_sha1", lambda _p: ("# notes\n", "abc123")) + monkeypatch.setattr(proc_mod, "get_cached_file_hash", lambda *a, **k: None) + monkeypatch.setattr(proc_mod.idx, "detect_language", lambda _p: "markdown") + + smart_check = MagicMock(side_effect=AssertionError("smart path must be skipped")) + monkeypatch.setattr(proc_mod.idx, "should_use_smart_reindexing", smart_check) + + captured = {} + + def fake_index_single_file(*args, **kwargs): + captured.update(kwargs) + return True + + monkeypatch.setattr(proc_mod.idx, "index_single_file", fake_index_single_file) + + ok = proc_mod._run_indexing_strategy( + path, + client=MagicMock(), + model=MagicMock(), + collection="coll", + vector_name="vec", + model_dim=1, + repo_name="repo", + ) + + assert ok is True + smart_check.assert_not_called() + assert captured["preloaded_language"] == "markdown" + + +def test_run_indexing_strategy_force_upsert_missing_points_bypasses_smart( + monkeypatch, tmp_path +): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod.idx, "ensure_collection_and_indexes_once", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_read_text_and_sha1", lambda _p: ("print('x')\n", "abc123")) + monkeypatch.setattr(proc_mod, "get_cached_file_hash", lambda *a, **k: None) + monkeypatch.setattr(proc_mod.idx, "detect_language", lambda _p: "python") + monkeypatch.setattr(proc_mod.idx, "should_use_smart_reindexing", lambda *a, **k: (True, "smart_reindex")) + 
monkeypatch.setattr(proc_mod.idx, "get_indexed_file_hash", lambda *a, **k: "") + monkeypatch.setattr(proc_mod, "_path_has_indexed_points", lambda *a, **k: False) + + smart_mock = MagicMock(return_value="skipped") + monkeypatch.setattr(proc_mod.idx, "process_file_with_smart_reindexing", smart_mock) + + index_mock = MagicMock(return_value=True) + monkeypatch.setattr(proc_mod.idx, "index_single_file", index_mock) + + ok = proc_mod._run_indexing_strategy( + path, + client=MagicMock(), + model=MagicMock(), + collection="coll", + vector_name="vec", + model_dim=1, + repo_name="repo", + force_upsert=True, + ) + + assert ok is True + smart_mock.assert_not_called() + index_mock.assert_called_once() + + +def test_run_indexing_strategy_sets_skip_verify_reason_for_file_lock( + monkeypatch, tmp_path +): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod.idx, "ensure_collection_and_indexes_once", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_read_text_and_sha1", lambda _p: ("print('x')\n", "abc123")) + monkeypatch.setattr(proc_mod, "get_cached_file_hash", lambda *a, **k: None) + monkeypatch.setattr(proc_mod.idx, "detect_language", lambda _p: "python") + monkeypatch.setattr(proc_mod.idx, "should_use_smart_reindexing", lambda *a, **k: (False, "changed")) + monkeypatch.setattr(proc_mod.idx, "index_single_file", lambda *a, **k: False) + monkeypatch.setattr(proc_mod.idx, "is_file_locked", lambda *_: True) + + verify_context = {} + ok = proc_mod._run_indexing_strategy( + path, + client=MagicMock(), + model=MagicMock(), + collection="coll", + vector_name="vec", + model_dim=1, + repo_name="repo", + force_upsert=True, + verify_context=verify_context, + ) + + assert ok is False + assert verify_context.get("skip_verify_reason") == "file_locked" + + +def test_finalize_journal_skips_force_upsert_verify_when_file_locked(monkeypatch): + proc_mod 
= importlib.import_module("scripts.watch_index_core.processor") + + verify_mock = MagicMock() + done_mock = MagicMock() + failed_mock = MagicMock() + monkeypatch.setattr(proc_mod, "_verify_and_update_journal_for_upsert", verify_mock) + monkeypatch.setattr(proc_mod, "_mark_journal_done", done_mock) + monkeypatch.setattr(proc_mod, "_mark_journal_failed", failed_mock) + + proc_mod._finalize_journal_after_index_attempt( + Path("/tmp/file.py"), + client=MagicMock(), + collection="coll", + repo_key="/tmp", + repo_name="repo", + force_upsert=True, + journal_content_hash="abc", + skip_verify_reason="file_locked", + ) + + verify_mock.assert_not_called() + done_mock.assert_not_called() + failed_mock.assert_not_called() + + +def test_staging_requires_subprocess_only_for_active_dual_root_state(monkeypatch): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: True) + + assert proc_mod._staging_requires_subprocess(None) is False + assert ( + proc_mod._staging_requires_subprocess( + { + "indexing_env": {"FOO": "bar"}, + "active_repo_slug": "repo", + "serving_repo_slug": "repo", + } + ) + is False + ) + assert ( + proc_mod._staging_requires_subprocess( + { + "indexing_env": {"FOO": "bar"}, + "active_repo_slug": "repo", + "serving_repo_slug": "repo_old", + } + ) + is True + ) + assert ( + proc_mod._staging_requires_subprocess( + { + "indexing_env": {"FOO": "bar"}, + "active_repo_slug": "repo", + "serving_repo_slug": "repo", + "staging": {"collection": "repo_old_collection"}, + } + ) + is True + ) + + +def test_process_paths_does_not_force_subprocess_for_non_active_staging( + monkeypatch, tmp_path +): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod, "_detect_repo_for_file", lambda p: tmp_path) + monkeypatch.setattr(proc_mod, "_get_collection_for_file", 
lambda p: "coll") + monkeypatch.setattr(proc_mod, "_set_status_indexing", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_extract_repo_name_from_path", lambda *_: "repo") + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: True) + monkeypatch.setattr( + proc_mod, + "get_workspace_state", + lambda *a, **k: { + "indexing_env": {"FOO": "bar"}, + "active_repo_slug": "repo", + "serving_repo_slug": "repo", + }, + ) + + staging_mock = MagicMock(return_value=False) + monkeypatch.setattr(proc_mod, "_maybe_handle_staging_file", staging_mock) + monkeypatch.setattr(proc_mod, "_run_indexing_strategy", lambda *a, **k: True) + + proc_mod._process_paths( + [path], + client=MagicMock(), + model=MagicMock(), + vector_name="vec", + model_dim=1, + workspace_path=str(tmp_path), + ) + + assert staging_mock.call_args is not None + assert staging_mock.call_args.kwargs == { + "force_upsert": False, + "journal_content_hash": "", + } + assert staging_mock.call_args.args[0] == path + assert staging_mock.call_args.args[6] is None + + +def test_process_paths_uses_subprocess_when_staging_is_actually_active( + monkeypatch, tmp_path +): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod, "_detect_repo_for_file", lambda p: tmp_path) + monkeypatch.setattr(proc_mod, "_get_collection_for_file", lambda p: "coll") + monkeypatch.setattr(proc_mod, "_set_status_indexing", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + 
monkeypatch.setattr(proc_mod, "_extract_repo_name_from_path", lambda *_: "repo") + monkeypatch.setattr(proc_mod, "is_staging_enabled", lambda: True) + monkeypatch.setattr( + proc_mod, + "get_workspace_state", + lambda *a, **k: { + "indexing_env": {"FOO": "bar"}, + "active_repo_slug": "repo", + "serving_repo_slug": "repo_old", + }, + ) + + staging_mock = MagicMock(return_value=False) + monkeypatch.setattr(proc_mod, "_maybe_handle_staging_file", staging_mock) + monkeypatch.setattr(proc_mod, "_run_indexing_strategy", lambda *a, **k: True) + + proc_mod._process_paths( + [path], + client=MagicMock(), + model=MagicMock(), + vector_name="vec", + model_dim=1, + workspace_path=str(tmp_path), + ) + + assert staging_mock.call_args is not None + assert staging_mock.call_args.kwargs == { + "force_upsert": False, + "journal_content_hash": "", + } + assert staging_mock.call_args.args[0] == path + assert staging_mock.call_args.args[6] == {"FOO": "bar"} + + +def test_staging_force_upsert_hash_match_verifies_before_skip(monkeypatch, tmp_path): + proc_mod = importlib.import_module("scripts.watch_index_core.processor") + + path = tmp_path / "file.py" + path.write_text("print('x')\n", encoding="utf-8") + + monkeypatch.setattr(proc_mod, "_read_text_and_sha1", lambda _p: ("print('x')\n", "abc123")) + monkeypatch.setattr(proc_mod, "get_cached_file_hash", lambda *a, **k: "abc123") + monkeypatch.setattr(proc_mod, "_verify_upsert_committed", lambda *a, **k: True) + monkeypatch.setattr(proc_mod, "_log_activity", lambda *a, **k: None) + + mark_done = MagicMock() + monkeypatch.setattr(proc_mod, "_mark_journal_done", mark_done) + advance = MagicMock() + monkeypatch.setattr(proc_mod, "_advance_progress", advance) + + handled = proc_mod._maybe_handle_staging_file( + path, + MagicMock(), + "coll", + "repo", + str(tmp_path), + [path], + {"FOO": "bar"}, + {str(tmp_path): 0}, + "started", + force_upsert=True, + journal_content_hash="abc123", + ) + + assert handled is True + 
mark_done.assert_called_once_with(path, str(tmp_path), "repo") + advance.assert_called_once() + + +def test_runtime_root_override_updates_internal_path_checks(monkeypatch, tmp_path): + import scripts.watch_index as watch_index + from scripts.watch_index_core import config as watch_config + import scripts.watch_index_core.processor as proc_mod + import scripts.embedder as embedder_mod + + runtime_root = tmp_path / "runtime-root" + runtime_root.mkdir(parents=True, exist_ok=True) + internal = runtime_root / ".git" / "HEAD" + internal.parent.mkdir(parents=True, exist_ok=True) + internal.write_text("ref: refs/heads/main\n", encoding="utf-8") + + original_root = watch_config.ROOT + original_watch_root = watch_index.ROOT + monkeypatch.setenv("WATCH_ROOT", str(runtime_root)) + monkeypatch.setattr(watch_index, "initialize_watcher_state", lambda root: {"repo_name": None}) + monkeypatch.setattr(watch_index, "get_indexing_config_snapshot", lambda repo_name=None: {}) + monkeypatch.setattr(watch_index, "compute_indexing_config_hash", lambda snapshot: "hash") + monkeypatch.setattr(watch_index, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(watch_index, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(embedder_mod, "get_embedding_model", lambda *_: MagicMock()) + monkeypatch.setattr(embedder_mod, "get_model_dimension", lambda *_: 1) + monkeypatch.setattr(watch_index, "resolve_vector_name_config", lambda *a, **k: "vec") + monkeypatch.setattr(watch_index, "_start_pseudo_backfill_worker", lambda *a, **k: None) + monkeypatch.setattr(watch_index, "create_observer", lambda *a, **k: MagicMock()) + monkeypatch.setattr(watch_index, "IndexHandler", MagicMock()) + monkeypatch.setattr(watch_index, "ChangeQueue", MagicMock()) + monkeypatch.setattr( + watch_index, + "QdrantClient", + MagicMock(return_value=MagicMock(get_collection=MagicMock())), + ) + monkeypatch.setattr(watch_index, "run_consistency_audit", lambda *a, **k: None) + 
monkeypatch.setattr(watch_index, "run_empty_dir_sweep_maintenance", lambda *a, **k: None) + monkeypatch.setattr(watch_index, "list_pending_index_journal_entries", lambda *a, **k: []) + def _bool_env(name, default=False): + if name == "WATCH_JOURNAL_DRAIN_ENABLED": + return True + return False + + monkeypatch.setattr(watch_index, "get_boolean_env", _bool_env) + monkeypatch.setattr(watch_index, "_sleep", lambda *_: (_ for _ in ()).throw(KeyboardInterrupt())) + + try: + watch_index.main() + except KeyboardInterrupt: + pass + + try: + assert watch_config.ROOT == runtime_root.resolve() + assert proc_mod._is_internal_ignored_path(internal) is True + finally: + watch_config.ROOT = original_root + watch_index.ROOT = original_watch_root + + +def test_main_throttles_periodic_maintenance(monkeypatch, tmp_path): + import scripts.watch_index as watch_index + from scripts.watch_index_core import config as watch_config + import scripts.embedder as embedder_mod + + runtime_root = tmp_path / "runtime-root" + runtime_root.mkdir(parents=True, exist_ok=True) + + original_root = watch_config.ROOT + original_watch_root = watch_index.ROOT + monkeypatch.setenv("WATCH_ROOT", str(runtime_root)) + monkeypatch.setenv("WATCH_MAINTENANCE_INTERVAL_SECS", "300") + monkeypatch.setenv("WATCH_INIT_MAINTENANCE_ENABLED", "0") + monkeypatch.setattr(watch_index, "initialize_watcher_state", lambda *a, **k: {"repo_name": None}) + monkeypatch.setattr(watch_index, "get_indexing_config_snapshot", lambda repo_name=None: {}) + monkeypatch.setattr(watch_index, "compute_indexing_config_hash", lambda snapshot: "hash") + monkeypatch.setattr(watch_index, "persist_indexing_config", lambda *a, **k: None) + monkeypatch.setattr(watch_index, "update_indexing_status", lambda *a, **k: None) + monkeypatch.setattr(embedder_mod, "get_embedding_model", lambda *_: MagicMock()) + monkeypatch.setattr(embedder_mod, "get_model_dimension", lambda *_: 1) + monkeypatch.setattr(watch_index, "resolve_vector_name_config", lambda *a, 
**k: "vec") + monkeypatch.setattr(watch_index, "_start_pseudo_backfill_worker", lambda *a, **k: None) + + class FakeObserver: + def schedule(self, *a, **k): + return None + + def start(self): + return None + + def stop(self): + return None + + def join(self): + return None + + monkeypatch.setattr(watch_index, "create_observer", lambda *a, **k: FakeObserver()) + monkeypatch.setattr(watch_index, "IndexHandler", MagicMock()) + monkeypatch.setattr(watch_index, "ChangeQueue", MagicMock()) + monkeypatch.setattr( + watch_index, + "QdrantClient", + MagicMock(return_value=MagicMock(get_collection=MagicMock())), + ) + def _bool_env(name, default=False): + if name == "WATCH_JOURNAL_DRAIN_ENABLED": + return True + return False + + monkeypatch.setattr(watch_index, "get_boolean_env", _bool_env) + + drain_mock = MagicMock() + maintenance_mock = MagicMock() + monkeypatch.setattr(watch_index, "_drain_pending_journal", drain_mock) + monkeypatch.setattr(watch_index, "_run_periodic_maintenance", maintenance_mock) + + time_values = iter([0.0, 1.0, 2.0, 301.0]) + monkeypatch.setattr(watch_index.time, "time", lambda: next(time_values)) + + sleep_calls = {"count": 0} + + def _sleep(_secs): + sleep_calls["count"] += 1 + if sleep_calls["count"] >= 4: + raise KeyboardInterrupt() + + monkeypatch.setattr(watch_index, "_sleep", _sleep) + + try: + watch_index.main() + finally: + watch_config.ROOT = original_root + watch_index.ROOT = original_watch_root + + assert drain_mock.call_count == 4 + assert maintenance_mock.call_count == 2 + + +def test_watch_source_defaults_follow_repo_mode(monkeypatch): + import scripts.watch_index as watch_index + + monkeypatch.delenv("WATCH_JOURNAL_DRAIN_ENABLED", raising=False) + monkeypatch.delenv("WATCH_FS_EVENTS_ENABLED", raising=False) + + assert watch_index._journal_drain_enabled(True) is True + assert watch_index._fs_events_enabled(True) is False + assert watch_index._journal_drain_enabled(False) is False + assert watch_index._fs_events_enabled(False) is True 
+ + +def test_watch_source_env_overrides_defaults(monkeypatch): + import scripts.watch_index as watch_index + + monkeypatch.setenv("WATCH_JOURNAL_DRAIN_ENABLED", "0") + monkeypatch.setenv("WATCH_FS_EVENTS_ENABLED", "1") + + assert watch_index._journal_drain_enabled(True) is False + assert watch_index._fs_events_enabled(True) is True diff --git a/tests/test_watch_index_git_history.py b/tests/test_watch_index_git_history.py new file mode 100644 index 00000000..08cac353 --- /dev/null +++ b/tests/test_watch_index_git_history.py @@ -0,0 +1,48 @@ +import io +import subprocess + +import pytest + + +pytestmark = pytest.mark.unit + + +def test_git_history_ingest_runs_as_package_module(monkeypatch, tmp_path): + from scripts.watch_index_core import processor + from scripts.watch_index_core import config as watch_config + + manifest = tmp_path / "git_history.json" + manifest.write_text('{"commits": []}', encoding="utf-8") + + captured = {} + + class FakePopen: + def __init__(self, cmd, **kwargs): + captured["cmd"] = cmd + captured["cwd"] = kwargs.get("cwd") + captured["env"] = kwargs.get("env") + self.stdout = io.StringIO("") + self.stderr = io.StringIO("") + + def poll(self): + return 0 + + def wait(self, timeout=None): + return 0 + + def kill(self): + pass + + monkeypatch.setattr(subprocess, "Popen", FakePopen) + + processor._run_git_history_ingest( + manifest, + collection="Context-Engine-41e67959", + repo_name="Context-Engine-41e67959950c8ab3", + ) + + assert captured["cmd"][:3] == [processor.sys.executable or "python3", "-m", "scripts.ingest_history"] + assert "--manifest-json" in captured["cmd"] + assert captured["cwd"] == str(watch_config.ROOT_DIR) + assert captured["env"]["COLLECTION_NAME"] == "Context-Engine-41e67959" + assert captured["env"]["REPO_NAME"] == "Context-Engine-41e67959950c8ab3" diff --git a/tests/test_watch_init_maintenance.py b/tests/test_watch_init_maintenance.py new file mode 100644 index 00000000..9c9248e1 --- /dev/null +++ 
b/tests/test_watch_init_maintenance.py @@ -0,0 +1,71 @@ +import importlib +import subprocess + + +def test_init_maintenance_interval_defaults_to_two_hours(monkeypatch): + monkeypatch.delenv("WATCH_INIT_MAINTENANCE_INTERVAL_MINUTES", raising=False) + monkeypatch.delenv("INIT_MAINTENANCE_INTERVAL_MINUTES", raising=False) + + mod = importlib.import_module("scripts.watch_index_core.init_maintenance") + mod = importlib.reload(mod) + + assert mod._interval_seconds() == 120 * 60 + + +def test_init_maintenance_runs_existing_scripts_under_lock(monkeypatch, tmp_path): + mod = importlib.import_module("scripts.watch_index_core.init_maintenance") + mod = importlib.reload(mod) + + calls = [] + + def fake_run(command, **kwargs): + calls.append((command, kwargs)) + return subprocess.CompletedProcess(command, 0, stdout="ok", stderr="") + + monkeypatch.setattr(mod.subprocess, "run", fake_run) + monkeypatch.setenv("WATCH_INIT_MAINTENANCE_COMMAND_TIMEOUT_SECS", "7") + + commands = [ + ["wait-for-qdrant.sh"], + ["python", "create_indexes.py"], + ["python", "warm_all_collections.py"], + ["python", "health_check.py"], + ] + + ok = mod.run_init_maintenance_once(commands=commands, lock_path=tmp_path / "init.lock") + + assert ok is True + assert [call[0] for call in calls] == commands + assert all(call[1]["timeout"] == 7 for call in calls) + assert all(call[1]["check"] is False for call in calls) + assert all("PYTHONPATH" in call[1]["env"] for call in calls) + + +def test_init_maintenance_stops_sequence_on_failure(monkeypatch, tmp_path): + mod = importlib.import_module("scripts.watch_index_core.init_maintenance") + mod = importlib.reload(mod) + + calls = [] + + def fake_run(command, **kwargs): + calls.append(command) + return subprocess.CompletedProcess(command, 1, stdout="", stderr="boom") + + monkeypatch.setattr(mod.subprocess, "run", fake_run) + + ok = mod.run_init_maintenance_once( + commands=[["first"], ["second"]], + lock_path=tmp_path / "init.lock", + ) + + assert ok is False + 
assert calls == [["first"]] + + +def test_init_maintenance_worker_can_be_disabled(monkeypatch): + mod = importlib.import_module("scripts.watch_index_core.init_maintenance") + mod = importlib.reload(mod) + + monkeypatch.setenv("WATCH_INIT_MAINTENANCE_ENABLED", "0") + + assert mod.start_init_maintenance_worker() is None diff --git a/tests/test_watch_queue.py b/tests/test_watch_queue.py new file mode 100644 index 00000000..7f02f25d --- /dev/null +++ b/tests/test_watch_queue.py @@ -0,0 +1,110 @@ +def test_change_queue_suppresses_recent_identical_fingerprint(monkeypatch, tmp_path): + from scripts.watch_index_core import queue as queue_mod + + monkeypatch.setattr(queue_mod, "RECENT_FINGERPRINT_TTL_SECS", 10.0) + + processed = [] + q = queue_mod.ChangeQueue(lambda paths: processed.append(list(paths))) + + p = tmp_path / "file.py" + p.write_text("print('x')\n", encoding="utf-8") + + q._paths.add(p) + q._flush() + assert processed == [[p]] + + q._paths.add(p) + q._flush() + assert processed == [[p]] + + +def test_change_queue_reprocesses_when_fingerprint_changes(monkeypatch, tmp_path): + from scripts.watch_index_core import queue as queue_mod + + monkeypatch.setattr(queue_mod, "RECENT_FINGERPRINT_TTL_SECS", 10.0) + + processed = [] + q = queue_mod.ChangeQueue(lambda paths: processed.append(list(paths))) + + p = tmp_path / "file.py" + p.write_text("print('x')\n", encoding="utf-8") + + q._paths.add(p) + q._flush() + + p.write_text("print('changed-again')\n", encoding="utf-8") + q._paths.add(p) + q._flush() + + assert processed == [[p], [p]] + + +def test_change_queue_force_bypasses_recent_fingerprint_suppression(monkeypatch, tmp_path): + from scripts.watch_index_core import queue as queue_mod + + monkeypatch.setattr(queue_mod, "RECENT_FINGERPRINT_TTL_SECS", 10.0) + + processed = [] + q = queue_mod.ChangeQueue(lambda paths: processed.append(list(paths))) + + p = tmp_path / "file.py" + p.write_text("print('x')\n", encoding="utf-8") + + q.add(p) + q._flush() + q.add(p, 
force=True) + q._flush() + + assert processed == [[p], [p]] + + +def test_change_queue_repeated_same_path_does_not_rearm_timer(monkeypatch, tmp_path): + from scripts.watch_index_core import queue as queue_mod + + class FakeTimer: + created = 0 + canceled = 0 + + def __init__(self, _delay, _cb): + FakeTimer.created += 1 + self.daemon = False + + def start(self): + return None + + def cancel(self): + FakeTimer.canceled += 1 + + monkeypatch.setattr(queue_mod.threading, "Timer", FakeTimer) + + q = queue_mod.ChangeQueue(lambda _paths: None) + p = tmp_path / "file.py" + p.write_text("print('x')\n", encoding="utf-8") + + q.add(p, force=True) + q.add(p, force=True) + q.add(p, force=True) + + assert FakeTimer.created == 1 + assert FakeTimer.canceled == 0 + + +def test_change_queue_stats_reports_backlog(tmp_path): + from scripts.watch_index_core import queue as queue_mod + + q = queue_mod.ChangeQueue(lambda _paths: None) + p = tmp_path / "file.py" + p.write_text("print('x')\n", encoding="utf-8") + + q._paths.add(p) + q._forced_paths.add(p) + q._pending.add(p.with_name("pending.py")) + q._pending_forced.add(p.with_name("pending.py")) + + assert q.stats() == { + "queued": 1, + "pending": 1, + "forced": 1, + "pending_forced": 1, + "processing": False, + } diff --git a/tests/test_watcher_collection_resolution.py b/tests/test_watcher_collection_resolution.py index fa3d0c1a..cc9e464a 100644 --- a/tests/test_watcher_collection_resolution.py +++ b/tests/test_watcher_collection_resolution.py @@ -6,9 +6,9 @@ pytestmark = pytest.mark.unit def test_main_resolves_collection_from_state(monkeypatch, tmp_path): - # Env setup: placeholder collection name at startup + # Env setup: default collection name at startup monkeypatch.setenv("WATCH_ROOT", str(tmp_path)) - monkeypatch.setenv("COLLECTION_NAME", "my-collection") + monkeypatch.setenv("COLLECTION_NAME", "codebase") monkeypatch.setenv("QDRANT_URL", "http://localhost:6333") monkeypatch.setenv("EMBEDDING_MODEL", "fake") @@ -17,6 +17,9 @@ def 
test_main_resolves_collection_from_state(monkeypatch, tmp_path): wi = importlib.import_module("scripts.watch_index") # Reload to re-read env defaults (COLLECTION) in module globals wi = importlib.reload(wi) + watch_config = importlib.import_module("scripts.watch_index_core.config") + original_root = watch_config.ROOT + original_watch_root = wi.ROOT # Fake QdrantClient: force get_collection to raise so code chooses sanitized vector name path class FakeQdrant: @@ -64,16 +67,20 @@ def join(self): # Make the main loop exit immediately by raising KeyboardInterrupt on sleep def _raise_kb(_): raise KeyboardInterrupt() - monkeypatch.setattr(wi.time, "sleep", _raise_kb, raising=True) + monkeypatch.setattr(wi, "_sleep", _raise_kb, raising=True) - # Precondition: module-level COLLECTION should reflect placeholder at import time - assert wi.COLLECTION == os.environ.get("COLLECTION_NAME") == "my-collection" + # Precondition: module-level COLLECTION should reflect the configured default at import time + assert wi.COLLECTION == os.environ.get("COLLECTION_NAME") == "codebase" # Run main(); in single-repo mode it should keep the env-provided COLLECTION_NAME - wi.main() + try: + wi.main() - # Postcondition: global COLLECTION remains the env-provided name - assert wi.COLLECTION == "my-collection" + # Postcondition: global COLLECTION remains the env-provided name + assert wi.COLLECTION == "codebase" + finally: + watch_config.ROOT = original_root + wi.ROOT = original_watch_root def test_multi_repo_ignores_placeholder_collection_in_state(monkeypatch, tmp_path): @@ -85,7 +92,8 @@ def test_multi_repo_ignores_placeholder_collection_in_state(monkeypatch, tmp_pat utils = importlib.import_module("scripts.watch_index_core.utils") utils = importlib.reload(utils) - monkeypatch.setattr(utils, "ROOT", tmp_path, raising=False) + watch_config = importlib.import_module("scripts.watch_index_core.config") + monkeypatch.setattr(watch_config, "ROOT", tmp_path, raising=True) monkeypatch.setattr(utils, 
"is_multi_repo_mode", lambda: True, raising=True) repo_slug = "Pirate Survivors-2b23a7e45f2c4b9f" @@ -111,4 +119,3 @@ def _fake_get_workspace_state(ws_path: str, repo_name: str | None = None): resolved = utils._get_collection_for_file(target) assert resolved == f"derived-{repo_slug}" - diff --git a/tests/test_watcher_events.py b/tests/test_watcher_events.py index f01484e9..658366bc 100644 --- a/tests/test_watcher_events.py +++ b/tests/test_watcher_events.py @@ -71,6 +71,25 @@ def test_on_moved_enqueues_new_dest(monkeypatch, tmp_path): assert any(s.endswith("/b.py") for s in q.added) +@pytest.mark.unit +def test_on_moved_ignores_internal_codebase_paths(monkeypatch, tmp_path): + monkeypatch.setenv("MULTI_REPO_MODE", "0") + q = FakeQueue() + handler = wi.IndexHandler(root=tmp_path, queue=q, client=FakeClient(), collection="c") + + codebase = tmp_path / ".codebase" + codebase.mkdir(parents=True, exist_ok=True) + src = codebase / "state.json" + dst = codebase / "file_locks" / "abc.lock" + src.write_text("{}\n") + dst.parent.mkdir(parents=True, exist_ok=True) + dst.write_text("lock\n") + + handler.on_moved(E(src, dest=dst)) + + assert q.added == [] + + @pytest.mark.unit def test_ignore_reload_rebuilds_excluder(monkeypatch, tmp_path): monkeypatch.setenv("MULTI_REPO_MODE", "0") @@ -105,4 +124,3 @@ def test_remote_git_manifest_is_enqueued_even_if_excluded(monkeypatch, tmp_path) handler.on_created(E(manifest)) assert any(p.endswith("/.remote-git/git_history_test.json") for p in q.added) - diff --git a/tests/test_workspace_state.py b/tests/test_workspace_state.py index 1200a270..c78188cb 100644 --- a/tests/test_workspace_state.py +++ b/tests/test_workspace_state.py @@ -129,10 +129,10 @@ def test_multi_repo_does_not_hard_override_with_env(self, ws_module, monkeypatch def test_single_repo_env_override_preserved(self, ws_module, monkeypatch): """In single-repo mode, COLLECTION_NAME remains a master override.""" monkeypatch.delenv("MULTI_REPO_MODE", raising=False) - 
monkeypatch.setenv("COLLECTION_NAME", "codebase") + monkeypatch.setenv("COLLECTION_NAME", "custom") ws = importlib.reload(ws_module) - assert ws.get_collection_name("my-repo_old") == "codebase_old" + assert ws.get_collection_name("my-repo_old") == "custom_old" def test_multi_repo_workspace_level_env_override_still_applies(self, ws_module, monkeypatch): """When repo_name is None, env override should still apply even in multi-repo mode.""" @@ -142,6 +142,38 @@ def test_multi_repo_workspace_level_env_override_still_applies(self, ws_module, assert ws.get_collection_name(None) == "codebase" + def test_multi_repo_workspace_root_path_uses_configured_collection(self, ws_module, monkeypatch, tmp_path): + """The multi-repo workspace root is not a repository identity.""" + ws_root = tmp_path / "work" + ws_root.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("MULTI_REPO_MODE", "1") + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + monkeypatch.setenv("COLLECTION_NAME", "context-engine") + ws = importlib.reload(ws_module) + + assert ws.get_collection_name(str(ws_root)) == "context-engine" + + def test_multi_repo_upload_managed_detection_does_not_probe_git(self, ws_module, monkeypatch, tmp_path): + """Upload-managed multi-repo identity comes from workspace path, not git metadata.""" + ws_root = tmp_path / "work" + repo_root = ws_root / "repo-a" + (repo_root / ".git").mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("MULTI_REPO_MODE", "1") + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + monkeypatch.delenv("CTXCE_BINDMOUNT_REPO_DETECTION", raising=False) + ws = importlib.reload(ws_module) + + monkeypatch.setattr( + ws, + "_git_remote_repo_name", + lambda *_: pytest.fail("git inference should be disabled for upload-managed mode"), + ) + + assert ws._extract_repo_name_from_path(str(repo_root)) == "repo-a" + assert ws._extract_repo_name_from_path(str(ws_root)) 
== "" + # ============================================================================ # Tests: Environment Variable Helpers @@ -430,6 +462,166 @@ def test_state_filename(self, ws_module): def test_placeholder_collection_names(self, ws_module): """PLACEHOLDER_COLLECTION_NAMES contains expected values.""" - assert "" in ws_module.PLACEHOLDER_COLLECTION_NAMES - assert "default-collection" in ws_module.PLACEHOLDER_COLLECTION_NAMES - assert "my-collection" in ws_module.PLACEHOLDER_COLLECTION_NAMES + assert ws_module.PLACEHOLDER_COLLECTION_NAMES == {"", "codebase"} + + +class TestCompareSymbolChanges: + def test_compare_symbol_changes_tolerates_line_shift_for_unchanged_content(self, ws_module): + old_symbols = { + "function_foo_10": { + "name": "foo", + "type": "function", + "start_line": 10, + "end_line": 20, + "content_hash": "samehash", + } + } + new_symbols = { + "function_foo_12": { + "name": "foo", + "type": "function", + "start_line": 12, + "end_line": 22, + "content_hash": "samehash", + } + } + + unchanged, changed = ws_module.compare_symbol_changes(old_symbols, new_symbols) + + assert unchanged == ["function_foo_12"] + assert changed == [] + + +class TestSymbolCachePaths: + def test_symbol_cache_uses_shared_repo_state_dir_in_multi_repo_mode(self, monkeypatch, tmp_path): + ws_root = tmp_path / "work" + repo_name = "repo-1234567890abcdef" + repo_root = ws_root / repo_name + repo_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + monkeypatch.setenv("MULTI_REPO_MODE", "1") + + import importlib + + ws_module = importlib.import_module("scripts.workspace_state") + ws_module = importlib.reload(ws_module) + + file_path = repo_root / "src" / "app.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text("print('x')\n", encoding="utf-8") + + expected_hash = ws_module.hashlib.md5( + str(file_path.resolve()).encode("utf-8") + ).hexdigest()[:8] + cache_path = 
ws_module._get_symbol_cache_path(str(file_path)) + + assert cache_path == ( + ws_root + / ".codebase" + / "repos" + / repo_name + / "symbols" + / f"{expected_hash}.json" + ) + + def test_symbol_cache_write_uses_cross_user_writable_mode(self, monkeypatch, tmp_path): + ws_root = tmp_path / "work" + repo_name = "repo-1234567890abcdef" + repo_root = ws_root / repo_name + repo_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + monkeypatch.setenv("MULTI_REPO_MODE", "1") + + import importlib + + ws_module = importlib.import_module("scripts.workspace_state") + ws_module = importlib.reload(ws_module) + + file_path = repo_root / "src" / "cacheme.py" + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text("print('x')\n", encoding="utf-8") + + ws_module.set_cached_symbols(str(file_path), {"sym": {"name": "sym"}}, "abc123") + cache_path = ws_module._get_symbol_cache_path(str(file_path)) + + assert cache_path.exists() + if os.name == "nt": + pytest.skip("POSIX permission bits are not stable on Windows") + dir_mode = cache_path.parent.stat().st_mode & 0o777 + file_mode = cache_path.stat().st_mode & 0o777 + assert dir_mode & 0o700 == 0o700 + assert file_mode & 0o600 == 0o600 + + +class TestCollectionMappings: + def test_get_collection_mappings_accepts_codebase_root_search_path(self, monkeypatch, tmp_path): + ws_root = tmp_path / "work" + ws_root.mkdir(parents=True, exist_ok=True) + slug = "repo-1234567890abcdef" + global_state_dir = ws_root / ".codebase" / "repos" / slug + global_state_dir.mkdir(parents=True, exist_ok=True) + global_state_path = global_state_dir / "state.json" + global_state_path.write_text( + json.dumps( + { + "qdrant_collection": "repo-123456-abcdef", + "updated_at": "2026-03-08T00:00:00", + } + ), + encoding="utf-8", + ) + + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + 
monkeypatch.setenv("MULTI_REPO_MODE", "1") + + import importlib + + ws_module = importlib.import_module("scripts.workspace_state") + ws_module = importlib.reload(ws_module) + + mappings = ws_module.get_collection_mappings(search_root=str(ws_root / ".codebase")) + slug_entries = [m for m in mappings if str(m.get("repo_name")) == slug] + + assert slug_entries, "expected global repo mapping to be discovered from codebase root" + entry = slug_entries[0] + assert entry["collection_name"] == "repo-123456-abcdef" + assert Path(entry["state_file"]).resolve() == global_state_path.resolve() + + def test_get_collection_mappings_keeps_global_repo_state_behavior(self, monkeypatch, tmp_path): + ws_root = tmp_path / "work" + ws_root.mkdir(parents=True, exist_ok=True) + repo_name = "frontend" + global_state_dir = ws_root / ".codebase" / "repos" / repo_name + global_state_dir.mkdir(parents=True, exist_ok=True) + global_state_path = global_state_dir / "state.json" + global_state_path.write_text( + json.dumps( + { + "qdrant_collection": "frontend-abcdef", + "updated_at": "2026-03-08T00:00:00", + } + ), + encoding="utf-8", + ) + + monkeypatch.setenv("WORKSPACE_PATH", str(ws_root)) + monkeypatch.setenv("WATCH_ROOT", str(ws_root)) + monkeypatch.setenv("MULTI_REPO_MODE", "1") + + import importlib + + ws_module = importlib.import_module("scripts.workspace_state") + ws_module = importlib.reload(ws_module) + + mappings = ws_module.get_collection_mappings(search_root=str(ws_root)) + repo_entries = [m for m in mappings if str(m.get("repo_name")) == repo_name] + + assert repo_entries, "expected global repo mapping to be discovered" + entry = repo_entries[0] + assert entry["collection_name"] == "frontend-abcdef" + assert Path(entry["state_file"]).resolve() == global_state_path.resolve() diff --git a/vscode-extension/build/build.bat b/vscode-extension/build/build.bat index 8db62e3f..26b616e3 100644 --- a/vscode-extension/build/build.bat +++ b/vscode-extension/build/build.bat @@ -15,9 +15,7 @@ 
set "STAGE_DIR=%OUT_DIR%\extension-stage" set "BUILD_RESULT=0" for %%I in ("..\..\ctx-hook-simple.sh") do set "HOOK_SRC=%%~fI" for %%I in ("..\..\scripts\ctx.py") do set "CTX_SRC=%%~fI" -for %%I in ("..\..\scripts\mcp_router.py") do set "ROUTER_SRC=%%~fI" for %%I in ("..\..\scripts\refrag_glm.py") do set "REFRAG_SRC=%%~fI" -for %%I in ("..\..\scripts\mcp_router.py") do set "ROUTER_SRC=%%~fI" for %%I in ("..\..\.env.example") do set "ENV_EXAMPLE_SRC=%%~fI" echo Building clean Context Engine Uploader extension... @@ -64,9 +62,7 @@ if errorlevel 1 ( REM Bundle ctx hook script and ctx CLI into the staged extension for reference if exist "%HOOK_SRC%" copy /Y "%HOOK_SRC%" "%STAGE_DIR%\ctx-hook-simple.sh" >nul if exist "%CTX_SRC%" copy /Y "%CTX_SRC%" "%STAGE_DIR%\ctx.py" >nul -if exist "%ROUTER_SRC%" copy /Y "%ROUTER_SRC%" "%STAGE_DIR%\mcp_router.py" >nul if exist "%REFRAG_SRC%" copy /Y "%REFRAG_SRC%" "%STAGE_DIR%\refrag_glm.py" >nul -if exist "%ROUTER_SRC%" copy /Y "%ROUTER_SRC%" "%STAGE_DIR%\mcp_router.py" >nul if exist "%ENV_EXAMPLE_SRC%" copy /Y "%ENV_EXAMPLE_SRC%" "%STAGE_DIR%\env.example" >nul REM Optional: bundle Python dependencies into the staged extension when requested diff --git a/vscode-extension/build/build.sh b/vscode-extension/build/build.sh index f3e4d9fa..13ad8614 100755 --- a/vscode-extension/build/build.sh +++ b/vscode-extension/build/build.sh @@ -7,11 +7,9 @@ OUT_DIR="$SCRIPT_DIR/../out" SRC_SCRIPT="$SCRIPT_DIR/../../scripts/standalone_upload_client.py" CLIENT="standalone_upload_client.py" STAGE_DIR="$OUT_DIR/extension-stage" -BUNDLE_DEPS="${1:-}" PYTHON_BIN="${PYTHON_BIN:-python3}" HOOK_SRC="$SCRIPT_DIR/../../ctx-hook-simple.sh" CTX_SRC="$SCRIPT_DIR/../../scripts/ctx.py" -ROUTER_SRC="$SCRIPT_DIR/../../scripts/mcp_router.py" REFRAG_SRC="$SCRIPT_DIR/../../scripts/refrag_glm.py" ENV_EXAMPLE_SRC="$SCRIPT_DIR/../../.env.example" AUTH_SRC="$SCRIPT_DIR/../../scripts/upload_auth_utils.py" @@ -48,9 +46,6 @@ fi if [[ -f "$CTX_SRC" ]]; then cp "$CTX_SRC" 
"$STAGE_DIR/ctx.py" fi -if [[ -f "$ROUTER_SRC" ]]; then - cp "$ROUTER_SRC" "$STAGE_DIR/mcp_router.py" -fi if [[ -f "$REFRAG_SRC" ]]; then cp "$REFRAG_SRC" "$STAGE_DIR/refrag_glm.py" fi @@ -64,16 +59,47 @@ if [[ -f "$ENV_EXAMPLE_SRC" ]]; then cp "$ENV_EXAMPLE_SRC" "$STAGE_DIR/env.example" fi -# Optional: bundle Python deps into the staged extension when requested -if [[ "$BUNDLE_DEPS" == "--bundle-deps" ]]; then - echo "Bundling Python dependencies into staged extension using $PYTHON_BIN..." - # On macOS, urllib3 v2 + system LibreSSL emits NotOpenSSLWarning; pin <2 there. - if [[ "$(uname -s)" == "Darwin" ]]; then - echo "Detected macOS; pinning urllib3<2 to avoid LibreSSL/OpenSSL warning." - "$PYTHON_BIN" -m pip install -t "$STAGE_DIR/python_libs" "urllib3<2" requests charset_normalizer "openai>=1.0" watchdog +# Bundle Python deps into the staged extension. Runtime assumes bundled +# python_libs are present and only requires an installed Python interpreter. +echo "Bundling Python dependencies into staged extension using $PYTHON_BIN..." +rm -rf "$STAGE_DIR/python_libs" +# On macOS, urllib3 v2 + system LibreSSL emits NotOpenSSLWarning; pin <2 there. +if [[ "$(uname -s)" == "Darwin" ]]; then + echo "Detected macOS; pinning urllib3<2 to avoid LibreSSL/OpenSSL warning." + "$PYTHON_BIN" -m pip install -t "$STAGE_DIR/python_libs" "urllib3<2" requests charset_normalizer "openai>=1.0" watchdog +else + "$PYTHON_BIN" -m pip install -t "$STAGE_DIR/python_libs" requests urllib3 charset_normalizer "openai>=1.0" watchdog +fi + +# Bundle MCP bridge npm package into the staged extension +BRIDGE_SRC="$SCRIPT_DIR/../../ctx-mcp-bridge" +BRIDGE_DIR="ctx-mcp-bridge" + +if [[ -d "$BRIDGE_SRC" && -f "$BRIDGE_SRC/package.json" ]]; then + echo "Bundling MCP bridge npm package into staged extension..." 
+ mkdir -p "$STAGE_DIR/$BRIDGE_DIR" + if [[ -d "$BRIDGE_SRC/bin" ]]; then + cp -a "$BRIDGE_SRC/bin" "$STAGE_DIR/$BRIDGE_DIR/" + else + echo "Warning: Bridge bin directory not found at $BRIDGE_SRC/bin (skipping)" + fi + if [[ -d "$BRIDGE_SRC/src" ]]; then + cp -a "$BRIDGE_SRC/src" "$STAGE_DIR/$BRIDGE_DIR/" + else + echo "Warning: Bridge src directory not found at $BRIDGE_SRC/src (skipping)" + fi + cp "$BRIDGE_SRC/package.json" "$STAGE_DIR/$BRIDGE_DIR/" + + echo "Installing MCP bridge production dependencies into staged extension..." + if [[ -f "$BRIDGE_SRC/package-lock.json" ]]; then + cp "$BRIDGE_SRC/package-lock.json" "$STAGE_DIR/$BRIDGE_DIR/" + (cd "$STAGE_DIR/$BRIDGE_DIR" && npm ci --omit=dev) else - "$PYTHON_BIN" -m pip install -t "$STAGE_DIR/python_libs" requests urllib3 charset_normalizer "openai>=1.0" watchdog + (cd "$STAGE_DIR/$BRIDGE_DIR" && npm install --omit=dev) fi + echo "MCP bridge bundled successfully." +else + echo "Warning: MCP bridge source not found at $BRIDGE_SRC" fi pushd "$STAGE_DIR" >/dev/null @@ -82,4 +108,4 @@ npx @vscode/vsce package --no-dependencies --out "$OUT_DIR" popd >/dev/null echo "Build complete! Check the /out directory for .vsix and .py files." -ls -la "$OUT_DIR" \ No newline at end of file +ls -la "$OUT_DIR" diff --git a/vscode-extension/build/publish-vscode-extension.sh b/vscode-extension/build/publish-vscode-extension.sh index f019cb92..0d986a4a 100644 --- a/vscode-extension/build/publish-vscode-extension.sh +++ b/vscode-extension/build/publish-vscode-extension.sh @@ -4,7 +4,6 @@ set -euo pipefail SCRIPT_DIR="$(cd -- "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BUILD_SCRIPT="$SCRIPT_DIR/build.sh" OUT_DIR="$SCRIPT_DIR/../out" -BUNDLE_DEPS="${1:-}" if [[ ! 
-f "$BUILD_SCRIPT" ]]; then echo "Build script not found: $BUILD_SCRIPT" >&2 @@ -18,7 +17,7 @@ fi export VSCE_STORE="${VSCE_STORE:-file}" -"$BUILD_SCRIPT" "$BUNDLE_DEPS" +"$BUILD_SCRIPT" VSIX_PATH="" if compgen -G "$OUT_DIR/*.vsix" >/dev/null; then diff --git a/vscode-extension/context-engine-uploader/README.md b/vscode-extension/context-engine-uploader/README.md index c84a79b6..e02dc982 100644 --- a/vscode-extension/context-engine-uploader/README.md +++ b/vscode-extension/context-engine-uploader/README.md @@ -20,7 +20,7 @@ Configuration - `Run On Startup` auto-triggers force sync + watch after VS Code finishes loading. - `Python Path`, `Endpoint`, `Extra Force Args`, `Extra Watch Args`, and `Interval Seconds` can be tuned via standard VS Code settings. - `Target Path` is auto-filled from the workspace but can be overridden if you need to upload a different folder. -- **Python dependencies:** the extension runs the standalone upload client via your configured `pythonPath`. Ensure the interpreter has `requests`, `urllib3`, `charset_normalizer`, and `watchdog` installed. Run `python3 -m pip install requests urllib3 charset_normalizer watchdog` (or replace `python3` with your configured path) before starting the uploader. +- **Python dependencies:** the extension ships bundled `python_libs` and adds them to `PYTHONPATH` for the upload client. You only need a runnable Python 3 interpreter via `contextEngineUploader.pythonPath`. - **Path mapping:** `Host Root` + `Container Root` control how local paths are rewritten before reaching the remote service. By default the host root mirrors your `Target Path` and the container root is `/work`, which keeps Windows paths working without extra config. - **Prompt+ decoder:** set `Context Engine Uploader: Decoder Url` (default `http://localhost:8081`, auto-appends `/completion`) to point at your local llama.cpp decoder. For Ollama, set it to `http://localhost:11434/api/chat`. 
Turn on `Use Gpu Decoder` to set `USE_GPU_DECODER=1` so ctx.py prefers the GPU llama.cpp sidecar. Prompt+ automatically runs the bundled `scripts/ctx.py` when an embedded copy is available, falling back to the workspace version if not. - **Claude/Windsurf MCP config:** diff --git a/vscode-extension/context-engine-uploader/config_resolver.js b/vscode-extension/context-engine-uploader/config_resolver.js index fd4ebe07..5dd457ad 100644 --- a/vscode-extension/context-engine-uploader/config_resolver.js +++ b/vscode-extension/context-engine-uploader/config_resolver.js @@ -193,10 +193,8 @@ function createConfigResolver(deps) { const configuredPython = (config.get('pythonPath') || '').trim(); let pythonPath = configuredPython || 'python3'; - let pythonPathSource = configuredPython ? 'configured' : 'default'; if (pythonOverridePath && fs.existsSync(pythonOverridePath)) { pythonPath = pythonOverridePath; - pythonPathSource = 'override'; } const endpoint = (config.get('endpoint') || '').trim(); const targetPath = getTargetPath(config); @@ -263,7 +261,6 @@ function createConfigResolver(deps) { return { pythonPath, - pythonPathSource, workingDirectory, scriptPath, targetPath, diff --git a/vscode-extension/context-engine-uploader/ctx_config.js b/vscode-extension/context-engine-uploader/ctx_config.js index 4d386f42..d70b437b 100644 --- a/vscode-extension/context-engine-uploader/ctx_config.js +++ b/vscode-extension/context-engine-uploader/ctx_config.js @@ -9,7 +9,7 @@ function createCtxConfigManager(deps) { const extensionRoot = deps.extensionRoot; const getEffectiveConfig = deps.getEffectiveConfig; const resolveOptions = deps.resolveOptions; - const ensurePythonDependencies = deps.ensurePythonDependencies; + const ensurePythonReady = deps.ensurePythonReady; const buildChildEnv = deps.buildChildEnv; const resolveBridgeHttpUrl = deps.resolveBridgeHttpUrl; @@ -29,12 +29,8 @@ function createCtxConfigManager(deps) { if (!options) { return; } - const depsOk = await 
ensurePythonDependencies( - options.pythonPath, - options.workingDirectory, - options.pythonPathSource - ); - if (!depsOk) { + const pythonReady = await ensurePythonReady(options.pythonPath); + if (!pythonReady) { return; } options = resolveOptions() || options; diff --git a/vscode-extension/context-engine-uploader/extension.js b/vscode-extension/context-engine-uploader/extension.js index 9a387c66..9b056082 100644 --- a/vscode-extension/context-engine-uploader/extension.js +++ b/vscode-extension/context-engine-uploader/extension.js @@ -12,7 +12,6 @@ const { createLogsTerminalManager } = require('./logs_terminal'); const { createPromptPlusManager } = require('./prompt_plus'); const { registerPromptPlusCommands } = require('./prompt_plus_commands'); const { createOnboardingManager } = require('./onboarding'); -const { createPythonEnvManager } = require('./python_env'); const { createProcessManager } = require('./process_manager'); const { registerExtensionCommands } = require('./commands'); const { createConfigResolver } = require('./config_resolver'); @@ -30,11 +29,11 @@ let ctxConfigManager; let promptPlusManager; let onboardingManager; -let pythonEnvManager; let processManager; let configResolver; let sidebarApi; let pendingProfileRestartTimer; +let hasShownPythonError = false; const DEFAULT_CONTAINER_ROOT = '/work'; // const CLAUDE_HOOK_COMMAND = '/home/coder/project/Context-Engine/ctx-hook-simple.sh'; @@ -185,25 +184,6 @@ function activate(context) { log(`Config resolver init failed: ${error instanceof Error ? error.message : String(error)}`); } - try { - pythonEnvManager = createPythonEnvManager({ - vscode, - spawn: spawn, - path, - fs, - log, - getEffectiveConfig, - getWorkspaceFolderPath: () => configResolver ? 
configResolver.getWorkspaceFolderPath() : undefined, - getExtensionRoot: () => extensionRoot, - getGlobalStoragePath: () => globalStoragePath, - getPythonOverridePath: () => pythonOverridePath, - setPythonOverridePath: (p) => { pythonOverridePath = p; }, - }); - } catch (error) { - pythonEnvManager = undefined; - log(`Python env manager init failed: ${error instanceof Error ? error.message : String(error)}`); - } - try { processManager = createProcessManager({ vscode, @@ -230,11 +210,13 @@ function activate(context) { path, fs, log, + extensionRoot, getEffectiveConfig, resolveBridgeWorkspacePath: () => configResolver ? configResolver.resolveBridgeWorkspacePath() : undefined, attachOutput: (child, label) => processManager ? processManager.attachOutput(child, label) : undefined, terminateProcess: (proc, label, afterStop) => processManager ? processManager.terminateProcess(proc, label, afterStop) : Promise.resolve(), scheduleMcpConfigRefreshAfterBridge: (delay) => mcpConfigManager ? mcpConfigManager.scheduleMcpConfigRefreshAfterBridge(delay) : undefined, + cancelPendingBridgeConfigRefresh: () => mcpConfigManager ? mcpConfigManager.cancelPendingBridgeConfigRefresh() : undefined, }); } catch (error) { bridgeManager = undefined; @@ -249,10 +231,7 @@ function activate(context) { extensionRoot, getEffectiveConfig, resolveOptions: () => configResolver ? configResolver.resolveOptions() : undefined, - ensurePythonDependencies: (pythonPath, workingDirectory, pythonPathSource) => - pythonEnvManager - ? pythonEnvManager.ensurePythonDependencies(pythonPath, workingDirectory, pythonPathSource) - : Promise.resolve(false), + ensurePythonReady, buildChildEnv: (options) => processManager?.buildChildEnv?.(options) ?? {}, resolveBridgeHttpUrl: () => bridgeManager ? bridgeManager.resolveBridgeHttpUrl() : undefined, }); @@ -274,6 +253,7 @@ function activate(context) { resolveBridgeCliInvocation: () => bridgeManager ? 
bridgeManager.resolveBridgeCliInvocation() : undefined, resolveBridgeHttpUrl: () => bridgeManager ? bridgeManager.resolveBridgeHttpUrl() : undefined, requiresHttpBridge: (s, t) => bridgeManager ? bridgeManager.requiresHttpBridge(s, t) : (s === 'bridge' && t === 'http'), + requiresLocalBridgeProcess: (s, t) => bridgeManager ? bridgeManager.requiresLocalBridgeProcess(s, t) : (s === 'bridge' && (t === 'http' || t === 'sse-remote')), ensureHttpBridgeReadyForConfigs: () => bridgeManager ? bridgeManager.ensureReadyForConfigs() : Promise.resolve(false), getBridgeIsRunning: () => (bridgeManager && typeof bridgeManager.isRunning === 'function' ? bridgeManager.isRunning() : false), writeCtxConfig: () => ctxConfigManager ? ctxConfigManager.writeCtxConfig() : Promise.resolve(), @@ -302,13 +282,6 @@ function activate(context) { } catch (_) { // ignore } - try { - const venvPy = pythonEnvManager ? pythonEnvManager.resolvePrivateVenvPython() : undefined; - if (venvPy) { - pythonOverridePath = venvPy; - log(`Detected existing private venv interpreter: ${venvPy}`); - } - } catch (_) { } statusBarItem = vscode.window.createStatusBarItem(vscode.StatusBarAlignment.Left, 100); statusBarItem.command = 'contextEngineUploader.indexCodebase'; context.subscriptions.push(statusBarItem); @@ -425,6 +398,7 @@ function activate(context) { event.affectsConfiguration('contextEngineUploader.mcpBridgeBinPath') || event.affectsConfiguration('contextEngineUploader.mcpBridgePort') || event.affectsConfiguration('contextEngineUploader.mcpBridgeLocalOnly') || + event.affectsConfiguration('contextEngineUploader.mcpBridgeMode') || event.affectsConfiguration('contextEngineUploader.windsurfMcpPath') || event.affectsConfiguration('contextEngineUploader.augmentMcpPath') || event.affectsConfiguration('contextEngineUploader.antigravityMcpPath') || @@ -439,6 +413,7 @@ function activate(context) { event.affectsConfiguration('contextEngineUploader.mcpBridgePort') || 
event.affectsConfiguration('contextEngineUploader.mcpBridgeBinPath') || event.affectsConfiguration('contextEngineUploader.mcpBridgeLocalOnly') || + event.affectsConfiguration('contextEngineUploader.mcpBridgeMode') || event.affectsConfiguration('contextEngineUploader.mcpIndexerUrl') || event.affectsConfiguration('contextEngineUploader.mcpMemoryUrl') || event.affectsConfiguration('contextEngineUploader.mcpServerMode') || @@ -484,10 +459,10 @@ function activate(context) { const serverModeRaw = config.get('mcpServerMode') || 'bridge'; const transportMode = (typeof transportModeRaw === 'string' ? transportModeRaw.trim() : 'sse-remote') || 'sse-remote'; const serverMode = (typeof serverModeRaw === 'string' ? serverModeRaw.trim() : 'bridge') || 'bridge'; - if (bridgeManager && bridgeManager.requiresHttpBridge(serverMode, transportMode)) { + if (bridgeManager && bridgeManager.requiresLocalBridgeProcess(serverMode, transportMode)) { startHttpBridgeProcess().catch(error => log(`Auto-start HTTP MCP bridge failed: ${error instanceof Error ? error.message : String(error)}`)); } else { - log('Context Engine Uploader: autoStartMcpBridge is enabled, but current MCP wiring does not use the HTTP bridge; skipping auto-start.'); + log('Context Engine Uploader: autoStartMcpBridge is enabled, but current MCP wiring does not use the local bridge process; skipping auto-start.'); } } } @@ -516,14 +491,12 @@ async function runSequence(mode = 'auto') { log(`Auth preflight check failed: ${error instanceof Error ? error.message : String(error)}`); } - const depsSatisfied = pythonEnvManager - ? 
await pythonEnvManager.ensurePythonDependencies(options.pythonPath, options.workingDirectory, options.pythonPathSource) - : false; + const depsSatisfied = await ensurePythonReady(options.pythonPath); if (!depsSatisfied) { setStatusBarState('idle'); return; } - // Re-resolve options in case ensurePythonDependencies switched to a better interpreter + // Re-resolve options in case Python preflight selected a better interpreter. const reoptions = configResolver ? configResolver.resolveOptions() : undefined; if (reoptions) { Object.assign(options, reoptions); @@ -540,8 +513,9 @@ async function runSequence(mode = 'auto') { if (code === 0) { setStatusBarState('indexed'); if (processManager) { processManager.ensureIndexedWatcher(options.targetPath); } - // Only start watching after a regular force sync, not after git history upload - if (mode === 'force' && options.startWatchAfterForce && processManager) { + // Start watch after successful force sync in normal flows (`force` and `auto`), + // but keep git-history upload as one-shot. + if (mode !== 'uploadGitHistory' && options.startWatchAfterForce && processManager) { processManager.startWatch(options); } } else { @@ -554,6 +528,42 @@ async function runSequence(mode = 'auto') { } } +function probePython(command, args = []) { + try { + const result = spawnSync(command, [...args, '-c', 'import sys; print(f"{sys.version_info[0]}|{sys.executable}")'], { encoding: 'utf8', timeout: 5000 }); + if (result.status !== 0) return undefined; + const [major, executable] = String(result.stdout || '').trim().split('|'); + return Number.parseInt(major, 10) >= 3 && executable ? executable.trim() : undefined; + } catch (_) { + return undefined; + } +} + +function ensurePythonReady(pythonPath) { + if (pythonOverridePath) return true; + const requested = pythonPath || 'python3'; + const candidates = process.platform === 'win32' + ? 
[[requested, []], ['py', ['-3']], ['python', []], ['python3', []]] + : [[requested, []], ['python3', []], ['python', []], ['/opt/homebrew/bin/python3', []]]; + const seen = new Set(); + for (const [command, args] of candidates) { + const key = `${command} ${args.join(' ')}`.trim(); + if (!command || seen.has(key)) continue; + seen.add(key); + const executable = probePython(command, args); + if (!executable) continue; + pythonOverridePath = executable; + log(`Using Python interpreter: ${executable}`); + return true; + } + log(`Python preflight failed for ${requested}.`); + if (!hasShownPythonError) { + hasShownPythonError = true; + vscode.window.showErrorMessage(`Context Engine Uploader: Python 3 was not found. Install Python 3 or update contextEngineUploader.pythonPath (current: ${requested}).`); + } + return false; +} + async function startHttpBridgeProcess() { diff --git a/vscode-extension/context-engine-uploader/mcp_bridge.js b/vscode-extension/context-engine-uploader/mcp_bridge.js index d9825177..bbb47fbc 100644 --- a/vscode-extension/context-engine-uploader/mcp_bridge.js +++ b/vscode-extension/context-engine-uploader/mcp_bridge.js @@ -4,18 +4,32 @@ function createBridgeManager(deps) { const path = deps.path; const fs = deps.fs; const log = deps.log; + const extensionRoot = deps.extensionRoot; const getEffectiveConfig = deps.getEffectiveConfig; const resolveBridgeWorkspacePath = deps.resolveBridgeWorkspacePath; const attachOutput = deps.attachOutput; const terminateProcess = deps.terminateProcess; const scheduleMcpConfigRefreshAfterBridge = deps.scheduleMcpConfigRefreshAfterBridge; + const cancelPendingBridgeConfigRefresh = deps.cancelPendingBridgeConfigRefresh; let httpBridgeProcess; let httpBridgePort; let httpBridgeWorkspace; let stopInFlight; + function clearBridgeState(child) { + if (httpBridgeProcess !== child) { + return; + } + httpBridgeProcess = undefined; + httpBridgePort = undefined; + httpBridgeWorkspace = undefined; + if (typeof 
cancelPendingBridgeConfigRefresh === 'function') { + cancelPendingBridgeConfigRefresh(); + } + } + function normalizeBridgeUrl(url) { if (!url || typeof url !== 'string') { return ''; @@ -42,7 +56,32 @@ function createBridgeManager(deps) { } } + function getBridgeMode() { + try { + const settings = getEffectiveConfig(); + return (settings.get('mcpBridgeMode') || 'bundled').trim(); + } catch (_) { + return 'bundled'; + } + } + + function findBundledBridgeBin() { + if (!extensionRoot) return undefined; + const bundledPath = path.join(extensionRoot, 'ctx-mcp-bridge', 'bin', 'ctxce.js'); + if (fs.existsSync(bundledPath)) { + return path.resolve(bundledPath); + } + return undefined; + } + function findLocalBridgeBin() { + // First check for bundled bridge if mode is 'bundled' + const mode = getBridgeMode(); + if (mode === 'bundled') { + return findBundledBridgeBin(); + } + + // External mode logic (existing behavior) let localOnly = true; let configured = ''; try { @@ -69,12 +108,20 @@ function createBridgeManager(deps) { function resolveBridgeCliInvocation() { const binPath = findLocalBridgeBin(); if (binPath) { + // Use absolute Node runtime to avoid PATH dependency in extension hosts + const bundledBin = findBundledBridgeBin(); + const resolvedKind = bundledBin && path.resolve(binPath) === path.resolve(bundledBin) + ? 
'bundled' + : 'local'; return { - command: 'node', + command: process.execPath, args: [binPath], - kind: 'local' + kind: resolvedKind }; } + if (getBridgeMode() === 'bundled') { + return undefined; + } const isWindows = process.platform === 'win32'; if (isWindows) { return { @@ -107,6 +154,10 @@ function createBridgeManager(deps) { return serverMode === 'bridge' && transportMode === 'http'; } + function requiresLocalBridgeProcess(serverMode, transportMode) { + return serverMode === 'bridge' && (transportMode === 'http' || transportMode === 'sse-remote'); + } + function resolveBridgeHttpUrl() { try { const settings = getEffectiveConfig(); @@ -199,21 +250,12 @@ function createBridgeManager(deps) { attachOutput(child, 'mcp-http'); child.on('exit', (code, signal) => { log(`HTTP MCP bridge exited with code ${code} signal ${signal || ''}`.trim()); - if (httpBridgeProcess === child) { - httpBridgeProcess = undefined; - httpBridgePort = undefined; - httpBridgeWorkspace = undefined; - } + clearBridgeState(child); }); child.on('error', error => { log(`HTTP MCP bridge process error: ${error instanceof Error ? error.message : String(error)}`); - if (httpBridgeProcess === child) { - httpBridgeProcess = undefined; - httpBridgePort = undefined; - httpBridgeWorkspace = undefined; - } + clearBridgeState(child); }); - vscode.window.showInformationMessage(`Context Engine HTTP MCP bridge listening on http://127.0.0.1:${options.port}/mcp`); if (typeof scheduleMcpConfigRefreshAfterBridge === 'function') { scheduleMcpConfigRefreshAfterBridge(); } @@ -269,10 +311,10 @@ function createBridgeManager(deps) { const serverModeRaw = config.get('mcpServerMode') || 'bridge'; const transportMode = (typeof transportModeRaw === 'string' ? transportModeRaw.trim() : 'sse-remote') || 'sse-remote'; const serverMode = (typeof serverModeRaw === 'string' ? 
serverModeRaw.trim() : 'bridge') || 'bridge'; - if (requiresHttpBridge(serverMode, transportMode)) { + if (requiresLocalBridgeProcess(serverMode, transportMode)) { await start(); } else { - log('Context Engine Uploader: HTTP bridge settings changed, but current MCP wiring does not use the HTTP bridge; not restarting HTTP bridge.'); + log('Context Engine Uploader: bridge settings changed, but current MCP wiring does not use the local bridge process; not restarting bridge.'); } } } @@ -290,6 +332,7 @@ function createBridgeManager(deps) { getState, isRunning, requiresHttpBridge, + requiresLocalBridgeProcess, resolveBridgeHttpUrl, ensureReadyForConfigs, start, diff --git a/vscode-extension/context-engine-uploader/mcp_config.js b/vscode-extension/context-engine-uploader/mcp_config.js index 3e40c2fd..a8c4f11c 100644 --- a/vscode-extension/context-engine-uploader/mcp_config.js +++ b/vscode-extension/context-engine-uploader/mcp_config.js @@ -52,6 +52,13 @@ function createMcpConfigManager(deps) { } } + function cancelPendingBridgeConfigRefresh() { + if (pendingBridgeConfigTimer) { + clearTimeout(pendingBridgeConfigTimer); + pendingBridgeConfigTimer = undefined; + } + } + async function writeAntigravityMcpServers(configPath, indexerUrl, memoryUrl, transportMode, serverMode = 'bridge', workspaceHint) { // TODO: Factor the shared "ensure dir + load JSON + applyMcpServersUpdate + writeJsonConfig" pattern // into a helper so Claude/Windsurf/Augment/Antigravity all call the same utility. @@ -338,10 +345,7 @@ function createMcpConfigManager(deps) { function scheduleMcpConfigRefreshAfterBridge(delayMs = 1500) { try { - if (pendingBridgeConfigTimer) { - clearTimeout(pendingBridgeConfigTimer); - pendingBridgeConfigTimer = undefined; - } + cancelPendingBridgeConfigRefresh(); // For bridge-http mode started by the extension, Windsurf needs the // "context-engine" MCP server entry removed and then re-added once the // HTTP bridge is ready. 
Best-effort removal happens immediately here; @@ -363,8 +367,12 @@ function createMcpConfigManager(deps) { } pendingBridgeConfigTimer = setTimeout(() => { pendingBridgeConfigTimer = undefined; - log('Context Engine Uploader: HTTP bridge ready; refreshing MCP configs.'); - writeMcpConfig().catch(error => { + if (typeof getBridgeIsRunning === 'function' && !getBridgeIsRunning()) { + log('Context Engine Uploader: HTTP bridge is not running; skipping delayed MCP config refresh.'); + return; + } + log('Context Engine Uploader: HTTP bridge still running; refreshing MCP configs.'); + writeMcpConfig({ skipHttpBridgeStart: true }).catch(error => { log(`Context Engine Uploader: MCP config refresh after bridge start failed: ${error instanceof Error ? error.message : String(error)}`); }); }, delayMs); @@ -736,6 +744,10 @@ function createMcpConfigManager(deps) { const needsHttpBridge = requiresHttpBridge(serverMode, transportMode); const bridgeWasRunning = !!(typeof getBridgeIsRunning === 'function' && getBridgeIsRunning()); if (needsHttpBridge) { + if (options.skipHttpBridgeStart && !bridgeWasRunning) { + log('Context Engine Uploader: HTTP bridge is not running; MCP config refresh will not restart it.'); + return; + } const ready = await ensureHttpBridgeReadyForConfigs(); if (!ready) { vscode.window.showErrorMessage('Context Engine Uploader: HTTP MCP bridge failed to start; MCP config not updated.'); @@ -828,6 +840,7 @@ function createMcpConfigManager(deps) { } return { + cancelPendingBridgeConfigRefresh, scheduleMcpConfigRefreshAfterBridge, writeMcpConfig, dispose, diff --git a/vscode-extension/context-engine-uploader/package.json b/vscode-extension/context-engine-uploader/package.json index d5e3584f..c3eadf0a 100644 --- a/vscode-extension/context-engine-uploader/package.json +++ b/vscode-extension/context-engine-uploader/package.json @@ -1,8 +1,8 @@ { "name": "context-engine-uploader", "displayName": "Context Engine Uploader", - "description": "Runs the Context-Engine remote 
upload client with a force sync on startup followed by watch mode. Requires Python with pip install requests urllib3 charset_normalizer.", - "version": "0.1.39", + "description": "Runs the Context-Engine remote upload client with bundled Python dependencies, force sync, and watch mode.", + "version": "0.1.40", "publisher": "context-engine", "engines": { "vscode": "^1.85.0" @@ -282,7 +282,7 @@ "contextEngineUploader.autoStartMcpBridge": { "type": "boolean", "default": true, - "description": "When enabled and mcpServerMode='bridge' with mcpTransportMode='http', automatically start the local ctx-mcp-bridge HTTP server for the active workspace so IDE clients can connect over HTTP without manual commands. Has no effect in stdio/direct modes." + "description": "When enabled and mcpServerMode='bridge', automatically start the bundled local ctx bridge process for the active workspace. In http mode it serves the local HTTP MCP bridge directly; in sse-remote mode it starts the same bundled bridge adapter used by bridge-stdio wiring. Has no effect in direct modes." }, "contextEngineUploader.mcpBridgePort": { "type": "number", @@ -297,7 +297,17 @@ "contextEngineUploader.mcpBridgeLocalOnly": { "type": "boolean", "default": false, - "description": "Development toggle. When true (default) the extension prefers local bridge binaries resolved from mcpBridgeBinPath or CTXCE_BRIDGE_BIN before falling back to the published npm build via npx." + "description": "Development toggle. When true and mcpBridgeMode='external', prefers local bridge binaries resolved from mcpBridgeBinPath or CTXCE_BRIDGE_BIN before falling back to the published npm build via npx. Ignored when mcpBridgeMode='bundled'." + }, + "contextEngineUploader.mcpBridgeMode": { + "type": "string", + "enum": ["bundled", "external"], + "default": "bundled", + "description": "Bridge invocation mode. 'bundled' uses the bundled bridge inside the extension (offline, no npx required). 
'external' uses external binary path or npx (current behavior).", + "enumDescriptions": [ + "Use the bundled MCP bridge inside the extension (works offline).", + "Use external binary path or npx to run the bridge (requires internet for first npx install)." + ] }, "contextEngineUploader.mcpServerMode": { "type": "string", diff --git a/vscode-extension/context-engine-uploader/process_manager.js b/vscode-extension/context-engine-uploader/process_manager.js index a5e36fd4..22173d58 100644 --- a/vscode-extension/context-engine-uploader/process_manager.js +++ b/vscode-extension/context-engine-uploader/process_manager.js @@ -64,8 +64,11 @@ function createProcessManager(deps) { env.CONTAINER_ROOT = options.containerRoot; } try { - const libsPath = path.join(options.workingDirectory, 'python_libs'); - if (fs.existsSync(libsPath)) { + const libsPath = [ + path.join(options.workingDirectory, 'python_libs'), + path.join(getExtensionRoot(), 'python_libs') + ].find(p => p && fs.existsSync(p)); + if (libsPath) { const existing = env.PYTHONPATH || ''; env.PYTHONPATH = existing ? `${libsPath}${path.delimiter}${existing}` : libsPath; if (!_hasLoggedPythonPath) { diff --git a/vscode-extension/context-engine-uploader/python_env.js b/vscode-extension/context-engine-uploader/python_env.js deleted file mode 100644 index 190f9945..00000000 --- a/vscode-extension/context-engine-uploader/python_env.js +++ /dev/null @@ -1,386 +0,0 @@ -/** - * Python environment management for Context Engine extension. - * Handles dependency checking, venv creation, and Python interpreter detection. 
- */ -function createPythonEnvManager(deps) { - const vscode = deps.vscode; - const spawn = deps.spawn; - const path = deps.path; - const fs = deps.fs; - const log = deps.log; - - - - // Helper to spawn processes asynchronously with Promise wrapper - function execAsync(command, args, options = {}) { - return new Promise((resolve) => { - // Diagnostic check for spawn injection - if (typeof spawn !== 'function') { - resolve({ code: -1, stdout: '', stderr: `createPythonEnvManager: spawn is ${typeof spawn}` }); - return; - } - - const child = spawn(command, args, { - ...options, - env: options.env || process.env - }); - - let stdout = ''; - let stderr = ''; - - if (child.stdout) { - child.stdout.on('data', (data) => { - const str = data.toString(); - stdout += str; - if (options.onStdout) options.onStdout(str); - }); - } - - if (child.stderr) { - child.stderr.on('data', (data) => { - const str = data.toString(); - stderr += str; - if (options.onStderr) options.onStderr(str); - }); - } - - let finished = false; - - child.on('error', (err) => { - if (!finished) { - finished = true; - resolve({ code: -1, stdout, stderr: stderr || err.message }); - } - }); - - child.on('close', (code) => { - if (!finished) { - finished = true; - resolve({ code: code === null ? 
-1 : code, stdout, stderr }); - } - }); - - // Handle cancellation if token provided - if (options.token) { - options.token.onCancellationRequested(() => { - if (!finished) { - finished = true; - try { child.kill(); } catch (_) { } - resolve({ code: -1, stdout, stderr: 'Cancelled' }); - } - }); - } - - // Safety timeout - if (options.timeout) { - setTimeout(() => { - if (!finished) { - finished = true; - try { child.kill(); } catch (_) { } - resolve({ code: -1, stdout, stderr: 'Process timeout' }); - } - }, options.timeout); - } - }); - } - - function getExtensionRoot() { - try { - if (typeof deps.getExtensionRoot === 'function') { - const root = deps.getExtensionRoot(); - if (root) { - return root; - } - } - } catch (_) { - } - if (deps.extensionRoot) return deps.extensionRoot; - try { - return vscode.extensions.getExtension('context-engine.context-engine-uploader').extensionPath; - } catch (_) { - return __dirname; - } - } - - function getPythonOverridePath() { - return typeof deps.getPythonOverridePath === 'function' ? deps.getPythonOverridePath() : undefined; - } - - function setPythonOverridePath(p) { - if (typeof deps.setPythonOverridePath === 'function') { - deps.setPythonOverridePath(p); - } - } - - const REQUIRED_PYTHON_MODULES = ['requests', 'urllib3', 'charset_normalizer', 'watchdog']; - const depCheckCache = new Map(); - - function cacheKey(pythonPath, workingDirectory) { - return `${pythonPath || ''}::${workingDirectory || ''}`; - } - - function venvRootDir() { - // Prefer workspace storage; fallback to extension storage - try { - const ws = deps.getWorkspaceFolderPath(); - const globalStorage = deps.getGlobalStoragePath() || path.join(getExtensionRoot(), '.storage'); - const base = ws && fs.existsSync(ws) ? 
path.join(ws, '.vscode', '.context-engine-uploader') - : globalStorage; - if (!fs.existsSync(base)) fs.mkdirSync(base, { recursive: true }); - return base; - } catch (e) { - return getExtensionRoot(); - } - } - - function privateVenvPath() { - return path.join(venvRootDir(), 'py-venv'); - } - - function resolvePrivateVenvPython() { - const venvPath = privateVenvPath(); - const bin = process.platform === 'win32' ? path.join(venvPath, 'Scripts', 'python.exe') : path.join(venvPath, 'bin', 'python'); - return fs.existsSync(bin) ? bin : undefined; - } - - async function detectSystemPython() { - // Try configured pythonPath, then common names - const candidates = []; - try { - const cfg = (typeof getEffectiveConfig === 'function') - ? getEffectiveConfig() - : vscode.workspace.getConfiguration('contextEngineUploader'); - const configured = (cfg && typeof cfg.get === 'function') ? (cfg.get('pythonPath') || '').trim() : ''; - if (configured) candidates.push(configured); - } catch { } - if (process.platform === 'win32') { - candidates.push('py', 'python3', 'python'); - } else { - candidates.push('python3', 'python'); - // Add common Homebrew path on Apple Silicon - candidates.push('/opt/homebrew/bin/python3'); - } - - for (const cmd of candidates) { - try { - // Version check: major >= 3 and print executable - const res = await execAsync(cmd, ['-c', 'import sys; print(f"{sys.version_info[0]}|{sys.executable}")'], { timeout: 3000 }); - if (res.code === 0) { - const parts = res.stdout.trim().split('|'); - if (parts.length === 2) { - const major = parseInt(parts[0], 10); - const executable = parts[1].trim(); - if (major >= 3 && executable) return executable; - } - } - } catch (e) { - // Skip candidate - } - } - return undefined; - } - - async function checkPythonDeps(pythonPath, workingDirectory, options = {}) { - const showInterpreterError = options.showInterpreterError !== undefined ? 
options.showInterpreterError : true; - const missing = []; - let pythonError; - const env = { ...process.env }; - try { - const candidates = []; - if (workingDirectory) { - candidates.push(path.join(workingDirectory, 'python_libs')); - } - candidates.push(path.join(getExtensionRoot(), 'python_libs')); - for (const libsPath of candidates) { - if (libsPath && fs.existsSync(libsPath)) { - const existing = env.PYTHONPATH || ''; - env.PYTHONPATH = existing ? `${libsPath}${path.delimiter}${existing}` : libsPath; - break; - } - } - } catch (error) { - log(`Failed to configure PYTHONPATH for dependency check: ${error instanceof Error ? error.message : String(error)}`); - } - - const smoke = await execAsync(pythonPath, ['-c', 'import sys; print(sys.executable)'], { env, timeout: 5000 }); - if (smoke.code !== 0) { - pythonError = String((smoke.stderr || smoke.stdout || '')).trim(); - } - - if (!pythonError) { - for (const moduleName of REQUIRED_PYTHON_MODULES) { - const check = await execAsync(pythonPath, ['-c', `import ${moduleName}`], { env, timeout: 5000 }); - if (check.code !== 0) { - missing.push(moduleName); - } - } - } - - if (pythonError) { - if (showInterpreterError) { - vscode.window.showErrorMessage(`Context Engine Uploader: failed to run ${pythonPath}. 
Update contextEngineUploader.pythonPath.`); - } - log(`Dependency check failed: ${pythonError}`); - return false; - } - if (missing.length) { - log(`Missing Python modules for ${pythonPath}: ${missing.join(', ')}`); - return false; - } - return true; - } - - async function ensurePrivateVenv() { - try { - const python = resolvePrivateVenvPython(); - if (python) { - log('Private venv already exists.'); - return true; - } - const venvPath = privateVenvPath(); - const basePy = await detectSystemPython(); - if (!basePy) { - vscode.window.showErrorMessage('Context Engine Uploader: no Python 3 interpreter found to bootstrap venv.'); - return false; - } - - // Verify venv module presence - try { - const venvCheck = await execAsync(basePy, ['-c', 'import venv'], { timeout: 5000 }); - if (venvCheck.code !== 0) { - log(`Python "venv" module missing in ${basePy}: ${venvCheck.stderr}`); - vscode.window.showErrorMessage(`Context Engine Uploader: Python "venv" module is missing in ${basePy}.`); - return false; - } - } catch (e) { - const errorMsg = e instanceof Error ? e.message : String(e); - log(`Failed to check for venv module: ${errorMsg}`); - return false; - } - - log(`Creating private venv at ${venvPath} using ${basePy}`); - const res = await execAsync(basePy, ['-m', 'venv', venvPath], { timeout: 30000 }); - if (res.code !== 0) { - log(`venv creation failed: ${res.stderr || res.stdout}`); - vscode.window.showErrorMessage('Context Engine Uploader: failed to create private venv.'); - return false; - } - return true; - } catch (e) { - log(`ensurePrivateVenv error: ${e && e.message ? 
e.message : String(e)}`); - return false; - } - } - - async function installDepsInto(pythonBin) { - return vscode.window.withProgress({ - location: vscode.ProgressLocation.Notification, - title: "Context Engine Uploader: Installing Python dependencies...", - cancellable: true - }, async (progress, token) => { - try { - log(`Installing Python deps into private venv via ${pythonBin}`); - const args = ['-m', 'pip', 'install', ...REQUIRED_PYTHON_MODULES]; - - const res = await execAsync(pythonBin, args, { - timeout: 60000, - token, - onStdout: (data) => { - progress.report({ message: data.split('\n').pop() }); - }, - onStderr: (data) => { - log(`pip install stderr: ${data}`); - } - }); - - if (res.code !== 0) { - log(`pip install failed: ${res.stderr || res.stdout}`); - vscode.window.showErrorMessage('Context Engine Uploader: pip install failed. See Output for details.'); - return false; - } - return true; - } catch (e) { - const msg = e && e.message ? e.message : String(e); - log(`installDepsInto error: ${msg}`); - vscode.window.showErrorMessage(`Context Engine Uploader: ${msg}`); - return false; - } - }); - } - - async function ensurePythonDependencies(pythonPath, workingDirectory, pythonPathSource) { - // Probe current interpreter with bundled python_libs first - const allowPrompt = pythonPathSource === 'configured' || pythonPathSource === 'override'; - const primaryKey = cacheKey(pythonPath, workingDirectory); - if (depCheckCache.get(primaryKey)) { - return true; - } - let ok = await checkPythonDeps(pythonPath, workingDirectory, { showInterpreterError: allowPrompt }); - if (ok) { - depCheckCache.set(primaryKey, true); - return true; - } - - // If that fails, try to auto-detect a better system Python before falling back to a venv - const autoPython = await detectSystemPython(); - if (autoPython && autoPython !== pythonPath) { - log(`Falling back to auto-detected Python interpreter: ${autoPython}`); - const autoKey = cacheKey(autoPython, workingDirectory); - if 
(depCheckCache.get(autoKey)) { - setPythonOverridePath(autoPython); - return true; - } - ok = await checkPythonDeps(autoPython, workingDirectory, { showInterpreterError: allowPrompt }); - if (ok) { - setPythonOverridePath(autoPython); - depCheckCache.set(autoKey, true); - return true; - } - } - - // As a last resort, offer to create a private venv and install deps via pip - if (!allowPrompt) { - log('Skipping auto-install prompt; interpreter was auto-detected and missing modules.'); - return false; - } - const choice = await vscode.window.showErrorMessage( - 'Context Engine Uploader: missing Python modules. Create isolated environment and auto-install?', - 'Auto-install to private venv', - 'Cancel' - ); - if (choice !== 'Auto-install to private venv') { - return false; - } - const created = await ensurePrivateVenv(); - if (!created) return false; - const venvPython = resolvePrivateVenvPython(); - if (!venvPython) { - vscode.window.showErrorMessage('Context Engine Uploader: failed to locate private venv python.'); - return false; - } - const installed = await installDepsInto(venvPython); - if (!installed) return false; - setPythonOverridePath(venvPython); - log(`Using private venv interpreter: ${getPythonOverridePath()}`); - const venvKey = cacheKey(venvPython, workingDirectory); - if (depCheckCache.get(venvKey)) { - return true; - } - const finalOk = await checkPythonDeps(venvPython, workingDirectory, { showInterpreterError: true }); - if (finalOk) { - depCheckCache.set(venvKey, true); - } - return finalOk; - } - - return { - resolvePrivateVenvPython, - detectSystemPython, - checkPythonDeps, - ensurePythonDependencies, - }; -} - -module.exports = { - createPythonEnvManager, -};