diff --git a/src/pi_agent_chain/pipeline.py b/src/pi_agent_chain/pipeline.py index c79e62e..db61be5 100644 --- a/src/pi_agent_chain/pipeline.py +++ b/src/pi_agent_chain/pipeline.py @@ -49,6 +49,7 @@ from pi_agent_chain.nodes.spec_synthesizer import SpecSynthesizerNode from pi_agent_chain.nodes.structural_extractor import StructuralExtractorNode from pi_agent_chain.nodes.verifier import DifferentialVerifierNode +from pi_agent_chain.tenant_context import current_tenant from pi_agent_chain.verification.auth_consistency import AuthConsistencyValidator from pi_agent_chain.verification.entropy_analysis import EntropyAnalysisValidator from pi_agent_chain.verification.provenance_validator import ProvenanceValidator @@ -683,6 +684,7 @@ def _log_trace( raw_output=raw_output, is_valid_type=is_valid, error_message=error, + tenant_id=current_tenant(), ) self.ledger.append(trace) diff --git a/src/pi_agent_chain/tenant_context.py b/src/pi_agent_chain/tenant_context.py new file mode 100644 index 0000000..995d348 --- /dev/null +++ b/src/pi_agent_chain/tenant_context.py @@ -0,0 +1,66 @@ +"""Async/thread-local current tenant for stamping execution-trace writes. + +The console's read surface already isolates traces per tenant +(``ledger_router`` filters by the caller's JWT claim via +``auth_guard.tenant_scope``). But the orchestrator write paths constructed +``ExecutionTrace`` without a ``tenant_id``, so every real audit row defaulted to +``'default'`` and the read filter never isolated real traffic. + +This module carries the *authenticated* tenant from the request boundary +(bound in ``jwt_validation_middleware`` from the JWT claim - NOT the forgeable +``X-Tenant-ID`` header) down to wherever a trace is written. CLI / direct / +background execution falls back to ``DEFAULT_TENANT`` unless a caller opens a +``tenant_scope``. + +Determinism: no wall-clock, no randomness - the value is whatever the caller +bound, so replay under the same scope reproduces the same attribution. +""" + +from __future__ import annotations + +import contextvars +import re +from contextlib import contextmanager +from typing import Any, Dict, Iterator, Optional + +DEFAULT_TENANT = "default" +# Matches the console's X-Tenant-ID validation so a claim can never carry a +# path-traversal / injection payload into an attribution. +_TENANT_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") + +_tenant: contextvars.ContextVar[str] = contextvars.ContextVar("pi_tenant_id", default=DEFAULT_TENANT) + + +def _normalize(tenant_id: Optional[str]) -> str: + return tenant_id if (tenant_id and _TENANT_RE.match(tenant_id)) else DEFAULT_TENANT + + +def current_tenant() -> str: + """The tenant bound to the current async task / thread context.""" + return _tenant.get() + + +def set_tenant(tenant_id: Optional[str]) -> contextvars.Token: + """Bind the current tenant; returns a token for reset_tenant().""" + return _tenant.set(_normalize(tenant_id)) + + +def reset_tenant(token: contextvars.Token) -> None: + _tenant.reset(token) + + +@contextmanager +def tenant_scope(tenant_id: Optional[str]) -> Iterator[str]: + """Bind a tenant for the duration of the block (restores on exit).""" + token = set_tenant(tenant_id) + try: + yield current_tenant() + finally: + reset_tenant(token) + + +def tenant_from_claims(claims: Optional[Dict[str, Any]]) -> str: + """Extract the trusted tenant from authenticated JWT claims (or default).""" + if not claims: + return DEFAULT_TENANT + return _normalize(claims.get("tenant_id")) diff --git a/src/pi_console/auth_guard.py b/src/pi_console/auth_guard.py index 6643c98..96d12ab 100644 --- a/src/pi_console/auth_guard.py +++ b/src/pi_console/auth_guard.py @@ -12,11 +12,14 @@ (fail closed) unless an operator *explicitly* opts out for local development via ``PI_CONSOLE_ALLOW_UNAUTHENTICATED=1``. -NOTE (follow-up, not closed here): the ``execution_trace`` ledger has no -``tenant_id`` column, so reads cannot yet be scoped per-tenant. This gate closes -the unauthenticated-access hole; row-level tenant isolation requires a schema -migration (add ``tenant_id``, populate on write, filter by the caller's claim -unless an admin role) plus RBAC enforcement on these routes. +Row-level tenant isolation is now in place on both sides: + * READ: ``execution_trace`` has a ``tenant_id`` column (with in-place legacy + migration), and the ledger/transparency routes filter every query by the + caller's ``tenant_id`` JWT claim via ``tenant_scope`` below (admins see all). + * WRITE: the orchestrator write paths stamp the authenticated tenant onto each + trace via ``pi_agent_chain.tenant_context`` (bound from the JWT claim in + ``jwt_validation_middleware`` - never from the forgeable ``X-Tenant-ID`` + header), so audit rows carry their real tenant rather than defaulting. """ from __future__ import annotations diff --git a/src/pi_console/main.py b/src/pi_console/main.py index 5b8ceb3..05aedba 100644 --- a/src/pi_console/main.py +++ b/src/pi_console/main.py @@ -16,6 +16,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from pi_agent_chain.tenant_context import reset_tenant, set_tenant, tenant_from_claims from pi_console.auth_guard import require_reader from pi_console.routers import ( audit_router, @@ -205,7 +206,14 @@ async def jwt_validation_middleware(request: Request, call_next): request.state.jwt_claims = jwt_codec.decode(token) except AuthenticationError as e: return JSONResponse(status_code=401, content={"detail": str(e)}) - return await call_next(request) + # Bind the AUTHENTICATED tenant (from the JWT claim, not the forgeable + # X-Tenant-ID header) for the duration of this request, so any + # ExecutionTrace written downstream is attributed to the right tenant. + ctx_token = set_tenant(tenant_from_claims(request.state.jwt_claims)) + try: + return await call_next(request) + finally: + reset_tenant(ctx_token) @app.middleware("http") async def tenant_injection_middleware(request: Request, call_next): diff --git a/src/pi_micro_agents/orchestrator/chain_engine.py b/src/pi_micro_agents/orchestrator/chain_engine.py index d40feed..ba0d2b4 100644 --- a/src/pi_micro_agents/orchestrator/chain_engine.py +++ b/src/pi_micro_agents/orchestrator/chain_engine.py @@ -17,6 +17,7 @@ from pi_agent_chain.ledger import StateLedger from pi_agent_chain.models import ExecutionTrace +from pi_agent_chain.tenant_context import current_tenant from pi_micro_agents.orchestrator.consensus import run_with_consensus from pi_micro_agents.orchestrator.router import AgentRoute, AgentRouter from pi_micro_agents.orchestrator.shield import PiOrchestratorShield @@ -668,6 +669,7 @@ def execute_chain(self, routes: List[AgentRoute], goal: str, context: Dict[str, raw_output=json.dumps(details), is_valid_type=success, error_message=", ".join(alerts) if alerts else None, + tenant_id=current_tenant(), ) self.ledger.append(trace) diff --git a/src/pi_micro_agents/orchestrator/core.py b/src/pi_micro_agents/orchestrator/core.py index 8ed0315..0767d61 100644 --- a/src/pi_micro_agents/orchestrator/core.py +++ b/src/pi_micro_agents/orchestrator/core.py @@ -13,6 +13,7 @@ from pi_agent_chain.ledger import StateLedger from pi_agent_chain.models import ExecutionTrace +from pi_agent_chain.tenant_context import current_tenant from pi_micro_agents.orchestrator.consensus import run_with_consensus from pi_micro_agents.orchestrator.router import AgentRouter @@ -25,26 +26,12 @@ from pi_micro_agents.pi_publisher_dispatch import PiPublisherDispatch, PublisherInput from pi_micro_agents.pi_spend_hunter import PiSpendAnomalyHunter from pi_micro_agents.pi_youtube_transcriber import PiYoutubeTranscriber, TranscriptInput +from pi_micro_agents.strict_mode import resolve_strict_mode def is_strict_mode() -> bool: - """Strict mode resolver checking environment variables and configuration files.""" - env_val = os.getenv("PI_ORCHESTRATOR_STRICT_MODE") - if env_val is not None: - return env_val.lower() == "true" - - config_path = os.path.expanduser("~/.antigravitycli/config.json") - if not os.path.exists(config_path): - config_path = os.path.join(os.path.dirname(__file__), "../../.antigravitycli/config.json") - - if os.path.exists(config_path): - try: - with open(config_path, "r") as f: - data = json.load(f) - return bool(data.get("PI_ORCHESTRATOR_STRICT_MODE", True)) - except Exception: - pass - return True + """Strict-mode resolver — delegates to the central resolver (was a per-file copy).""" + return resolve_strict_mode("PI_ORCHESTRATOR_STRICT_MODE") class OrchestratorInput(BaseModel): @@ -544,6 +531,7 @@ def _compile_and_log_output( raw_output=res_output.model_dump_json(), is_valid_type=success, error_message=", ".join(anomalies) if anomalies else None, + tenant_id=current_tenant(), ) self.ledger.append(trace) except Exception: diff --git a/src/pi_micro_agents/pi_cot_shadow.py b/src/pi_micro_agents/pi_cot_shadow.py index e4b64c9..4cb2eaa 100644 --- a/src/pi_micro_agents/pi_cot_shadow.py +++ b/src/pi_micro_agents/pi_cot_shadow.py @@ -1,31 +1,15 @@ from __future__ import annotations -import json import math -import os import re from typing import Any, Dict, List, Tuple +from pi_micro_agents.strict_mode import resolve_strict_mode -# 1. Load strict-mode configurations + +# 1. Strict-mode resolution (delegates to the central resolver) def is_strict_mode() -> bool: - env_val = os.getenv("PI_COT_STRICT_MODE") - if env_val is not None: - return env_val.lower() == "true" - - # Check config.json file - config_path = os.path.expanduser("~/.antigravitycli/config.json") - if not os.path.exists(config_path): - config_path = os.path.join(os.path.dirname(__file__), "../../.antigravitycli/config.json") - - if os.path.exists(config_path): - try: - with open(config_path, "r") as f: - data = json.load(f) - return bool(data.get("PI_COT_STRICT_MODE", True)) - except Exception: - pass - return True + return resolve_strict_mode("PI_COT_STRICT_MODE") # 2. Heuristic Detection Core for Code/Text Payloads (checking for invisible guardrails or evasions) diff --git a/src/pi_micro_agents/pi_prompt_shield.py b/src/pi_micro_agents/pi_prompt_shield.py index 3dfe573..e1fa963 100644 --- a/src/pi_micro_agents/pi_prompt_shield.py +++ b/src/pi_micro_agents/pi_prompt_shield.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -import os import re from typing import Any, List, Tuple @@ -9,26 +8,12 @@ from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware +from pi_micro_agents.strict_mode import resolve_strict_mode -# 1. Load strict-mode configurations + +# 1. Strict-mode resolution (delegates to the central resolver) def is_strict_mode() -> bool: - env_val = os.getenv("PI_SHIELD_STRICT_MODE") - if env_val is not None: - return env_val.lower() == "true" - - # Check config.json file - config_path = os.path.expanduser("~/.antigravitycli/config.json") - if not os.path.exists(config_path): - config_path = os.path.join(os.path.dirname(__file__), "../../.antigravitycli/config.json") - - if os.path.exists(config_path): - try: - with open(config_path, "r") as f: - data = json.load(f) - return bool(data.get("PI_SHIELD_STRICT_MODE", True)) - except Exception: - pass - return True + return resolve_strict_mode("PI_SHIELD_STRICT_MODE") # 2. Heuristic Detection Core diff --git a/src/pi_micro_agents/pi_schema_ghost.py b/src/pi_micro_agents/pi_schema_ghost.py index 9057c21..6250e6a 100644 --- a/src/pi_micro_agents/pi_schema_ghost.py +++ b/src/pi_micro_agents/pi_schema_ghost.py @@ -1,30 +1,14 @@ from __future__ import annotations -import json -import os import re from typing import Any, Dict, List, Set, Tuple +from pi_micro_agents.strict_mode import resolve_strict_mode -# 1. Load strict-mode configurations + +# 1. Strict-mode resolution (delegates to the central resolver) def is_strict_mode() -> bool: - env_val = os.getenv("PI_GHOST_STRICT_MODE") - if env_val is not None: - return env_val.lower() == "true" - - # Check config.json file - config_path = os.path.expanduser("~/.antigravitycli/config.json") - if not os.path.exists(config_path): - config_path = os.path.join(os.path.dirname(__file__), "../../.antigravitycli/config.json") - - if os.path.exists(config_path): - try: - with open(config_path, "r") as f: - data = json.load(f) - return bool(data.get("PI_GHOST_STRICT_MODE", True)) - except Exception: - pass - return True + return resolve_strict_mode("PI_GHOST_STRICT_MODE") # 2. Heuristic Detection Core for Code/Text Payloads diff --git a/src/pi_micro_agents/pi_surplus_orchestrator.py b/src/pi_micro_agents/pi_surplus_orchestrator.py index ee25a16..724df56 100644 --- a/src/pi_micro_agents/pi_surplus_orchestrator.py +++ b/src/pi_micro_agents/pi_surplus_orchestrator.py @@ -1,32 +1,16 @@ from __future__ import annotations -import json -import os import re import time import uuid from typing import Any, Dict, List, Tuple +from pi_micro_agents.strict_mode import resolve_strict_mode -# 1. Load strict-mode configurations -def is_strict_mode() -> bool: - env_val = os.getenv("PI_SURPLUS_STRICT_MODE") - if env_val is not None: - return env_val.lower() == "true" - - # Check config.json file - config_path = os.path.expanduser("~/.antigravitycli/config.json") - if not os.path.exists(config_path): - config_path = os.path.join(os.path.dirname(__file__), "../../.antigravitycli/config.json") - if os.path.exists(config_path): - try: - with open(config_path, "r") as f: - data = json.load(f) - return bool(data.get("PI_SURPLUS_STRICT_MODE", True)) - except Exception: - pass - return True +# 1. Strict-mode resolution (delegates to the central resolver) +def is_strict_mode() -> bool: + return resolve_strict_mode("PI_SURPLUS_STRICT_MODE") # 2. Heuristic static search for illegal surplus sub-key leakage diff --git a/src/pi_micro_agents/utils.py b/src/pi_micro_agents/utils.py index 983da3a..fb23d54 100644 --- a/src/pi_micro_agents/utils.py +++ b/src/pi_micro_agents/utils.py @@ -7,18 +7,15 @@ from __future__ import annotations -import json -import os +from pi_micro_agents.strict_mode import resolve_strict_mode def is_strict_mode(env_var: str) -> bool: """Return True if the named strict-mode flag is enabled. - Resolution order (first match wins): - 1. Environment variable ``env_var`` (e.g. ``PI_GAS_STRICT_MODE``). - 2. ``~/.antigravitycli/config.json`` key matching ``env_var``. - 3. Project-local ``.antigravitycli/config.json`` (two dirs above this file). - 4. Default: ``True`` (fail-safe — strict by default). + Thin compatibility wrapper that delegates to the single source of truth, + :func:`pi_micro_agents.strict_mode.resolve_strict_mode` (see it for the exact + resolution order; the default is fail-safe ``True``). Args: env_var: The environment variable / config key to check, e.g. @@ -26,31 +23,5 @@ def is_strict_mode(env_var: str) -> bool: Returns: ``True`` when strict mode is active, ``False`` otherwise. - - Example:: - - from pi_micro_agents.utils import is_strict_mode - - if is_strict_mode("PI_GAS_STRICT_MODE"): - status = "REJECTED_GAS_RISK" """ - env_val = os.getenv(env_var) - if env_val is not None: - return env_val.lower() == "true" - - # Try user-level config, then project-level config - config_paths = [ - os.path.expanduser("~/.antigravitycli/config.json"), - os.path.join(os.path.dirname(__file__), "../../.antigravitycli/config.json"), - ] - for config_path in config_paths: - if os.path.exists(config_path): - try: - with open(config_path) as f: - data = json.load(f) - return bool(data.get(env_var, True)) - except Exception: # noqa: BLE001 — best-effort config read - pass - - # Fail-safe: strict mode on by default - return True + return resolve_strict_mode(env_var) diff --git a/tests/console/backend/test_tenant_write_binding.py b/tests/console/backend/test_tenant_write_binding.py new file mode 100644 index 0000000..cfb3c5a --- /dev/null +++ b/tests/console/backend/test_tenant_write_binding.py @@ -0,0 +1,68 @@ +"""The JWT middleware binds the AUTHENTICATED tenant into tenant_context. + +This is the wiring that makes write-path stamping real: while serving an +authenticated request, current_tenant() reflects the caller's JWT tenant_id +claim, so any ExecutionTrace written downstream is attributed correctly. The +attribution source is the JWT claim, NOT the client-supplied X-Tenant-ID header +(which a caller could forge). Uses a sync probe route to exercise the +threadpool context-propagation path that real sync handlers take. +""" + +from __future__ import annotations + +from fastapi.testclient import TestClient + +from pi_agent_chain.tenant_context import current_tenant +from pi_console import main as console_main +from pi_production.security.auth import JWTToken + +_SECRET = "tenant-write-binding-secret" + + +def _app_with_probe(monkeypatch): + monkeypatch.setattr(console_main, "JWT_SECRET", _SECRET) + monkeypatch.delenv("PI_CONSOLE_ALLOW_UNAUTHENTICATED", raising=False) + app = console_main.create_app() + + @app.get("/api/v1/_probe_tenant") + def _probe(): # sync -> runs in the threadpool, exercising context propagation + return {"tenant": current_tenant()} + + return app + + +def _tok(**claims) -> str: + return JWTToken(_SECRET).encode(claims) + + +def test_authenticated_tenant_is_bound(monkeypatch): + client = TestClient(_app_with_probe(monkeypatch)) + tok = _tok(sub="u", tenant_id="tenant-a", role="user") + r = client.get("/api/v1/_probe_tenant", headers={"Authorization": f"Bearer {tok}"}) + assert r.status_code == 200 + assert r.json()["tenant"] == "tenant-a" + + +def test_header_cannot_forge_tenant(monkeypatch): + client = TestClient(_app_with_probe(monkeypatch)) + tok = _tok(sub="u", tenant_id="tenant-a", role="user") + r = client.get( + "/api/v1/_probe_tenant", + headers={"Authorization": f"Bearer {tok}", "X-Tenant-ID": "tenant-evil"}, + ) + assert r.json()["tenant"] == "tenant-a" # the JWT claim wins, not the header + + +def test_no_tenant_claim_defaults(monkeypatch): + client = TestClient(_app_with_probe(monkeypatch)) + tok = _tok(sub="u", role="user") # no tenant_id claim + r = client.get("/api/v1/_probe_tenant", headers={"Authorization": f"Bearer {tok}"}) + assert r.json()["tenant"] == "default" + + +def test_request_does_not_leak_tenant_across_calls(monkeypatch): + client = TestClient(_app_with_probe(monkeypatch)) + client.get("/api/v1/_probe_tenant", headers={"Authorization": f"Bearer {_tok(tenant_id='tenant-a')}"}) + # a fresh request with no tenant claim must not inherit tenant-a + r = client.get("/api/v1/_probe_tenant", headers={"Authorization": f"Bearer {_tok(sub='u')}"}) + assert r.json()["tenant"] == "default" diff --git a/tests/unit/pi-agent-chain/test_tenant_context.py b/tests/unit/pi-agent-chain/test_tenant_context.py new file mode 100644 index 0000000..a761594 --- /dev/null +++ b/tests/unit/pi-agent-chain/test_tenant_context.py @@ -0,0 +1,69 @@ +"""The async/thread-local current-tenant used to stamp execution-trace writes. + +Read-side isolation already existed; this carries the *authenticated* tenant +from the request boundary down to wherever a trace is written, so audit rows +stop defaulting to 'default'. The tenant source is the JWT claim (trusted), not +the client X-Tenant-ID header (forgeable). +""" + +from __future__ import annotations + +from pi_agent_chain.tenant_context import ( + DEFAULT_TENANT, + current_tenant, + set_tenant, + tenant_from_claims, + tenant_scope, +) + + +class TestContextVar: + def test_default_is_default_tenant(self): + assert current_tenant() == DEFAULT_TENANT + + def test_scope_sets_and_restores(self): + assert current_tenant() == DEFAULT_TENANT + with tenant_scope("acme"): + assert current_tenant() == "acme" + assert current_tenant() == DEFAULT_TENANT # restored on exit + + def test_scope_restores_on_exception(self): + try: + with tenant_scope("acme"): + raise ValueError("boom") + except ValueError: + pass + assert current_tenant() == DEFAULT_TENANT + + def test_nested_scopes(self): + with tenant_scope("a"): + with tenant_scope("b"): + assert current_tenant() == "b" + assert current_tenant() == "a" + + def test_set_tenant_token_resets(self): + tok = set_tenant("x") + assert current_tenant() == "x" + from pi_agent_chain.tenant_context import reset_tenant + + reset_tenant(tok) + assert current_tenant() == DEFAULT_TENANT + + def test_invalid_tenant_falls_back_to_default(self): + with tenant_scope("bad/../slash"): + assert current_tenant() == DEFAULT_TENANT + with tenant_scope(""): + assert current_tenant() == DEFAULT_TENANT + + +class TestTenantFromClaims: + def test_extracts_claim(self): + assert tenant_from_claims({"tenant_id": "acme", "sub": "u"}) == "acme" + + def test_none_or_missing_is_default(self): + assert tenant_from_claims(None) == DEFAULT_TENANT + assert tenant_from_claims({"sub": "u"}) == DEFAULT_TENANT + + def test_forged_invalid_claim_is_default(self): + # a malformed tenant claim must never become a real attribution + assert tenant_from_claims({"tenant_id": "../etc"}) == DEFAULT_TENANT diff --git a/tests/unit/pi-agent-chain/test_trace_tenant_stamping.py b/tests/unit/pi-agent-chain/test_trace_tenant_stamping.py new file mode 100644 index 0000000..17d22d2 --- /dev/null +++ b/tests/unit/pi-agent-chain/test_trace_tenant_stamping.py @@ -0,0 +1,61 @@ +"""The orchestrator write paths stamp the current tenant onto ExecutionTrace. + +This closes the write half of the console tenant-isolation finding: previously +PiOrchestrator / ChainExecutionEngine / PipelineDriver built traces with no +tenant_id, so every real audit row defaulted to 'default' and the (correct) +read-side filter never isolated real traffic. Now a trace written inside a +tenant_scope carries that tenant; with no scope it defaults to 'default'. +""" + +from __future__ import annotations + +from pydantic import BaseModel + +import pi_micro_agents.orchestrator.chain_engine as ce +from pi_agent_chain.ledger import StateLedger +from pi_agent_chain.pipeline import PipelineDriver +from pi_agent_chain.tenant_context import tenant_scope +from pi_micro_agents.orchestrator.chain_engine import ChainExecutionEngine +from pi_micro_agents.orchestrator.core import PiOrchestrator +from pi_micro_agents.orchestrator.router import AgentRoute + + +class _In(BaseModel): + goal: str = "" + + +class TestOrchestratorWritePath: + def test_stamps_current_tenant(self, tmp_path): + led = StateLedger(str(tmp_path / "o.db")) + orch = PiOrchestrator(ledger=led) + with tenant_scope("tenant-a"): + orch._compile_and_log_output(True, "PiX", 0.0, "s", {}, [], "goal") + traces = led.get_all() + assert traces and all(t.tenant_id == "tenant-a" for t in traces) + + def test_defaults_without_scope(self, tmp_path): + led = StateLedger(str(tmp_path / "o2.db")) + PiOrchestrator(ledger=led)._compile_and_log_output(True, "PiX", 0.0, "s", {}, [], "goal") + assert led.get_all()[0].tenant_id == "default" + + +class TestPipelineWritePath: + def test_stamps_current_tenant(self, tmp_path): + led = StateLedger(str(tmp_path / "p.db")) + drv = PipelineDriver(ledger=led, base_url="http://localhost") + with tenant_scope("tenant-b"): + drv._log_trace(trace_id="t1", node_name="n", input_hash="h", raw_output="{}", is_valid=True) + assert led.get_trace("t1")[0].tenant_id == "tenant-b" + + +class TestChainEngineWritePath: + def test_stamps_current_tenant(self, monkeypatch, tmp_path): + # stub the heavy consensus call so we reach the trace-write deterministically + monkeypatch.setattr(ce, "run_with_consensus", lambda *a, **k: (True, 0.0, "ok", {}, [])) + led = StateLedger(str(tmp_path / "c.db")) + engine = ChainExecutionEngine(orchestrator=object(), ledger=led) + route = AgentRoute(agent_name="PiX", keywords=[], agent_class=BaseModel, input_factory=lambda g, c: _In(goal=g)) + with tenant_scope("tenant-c"): + engine.execute_chain([route], goal="g", context={}) + traces = led.get_all() + assert traces and all(t.tenant_id == "tenant-c" for t in traces) diff --git a/tests/unit/pi-micro-agents/test_strict_mode_consolidation.py b/tests/unit/pi-micro-agents/test_strict_mode_consolidation.py new file mode 100644 index 0000000..dc027f7 --- /dev/null +++ b/tests/unit/pi-micro-agents/test_strict_mode_consolidation.py @@ -0,0 +1,83 @@ +"""The last real per-agent is_strict_mode() resolvers delegate to the central one. + +Commit 986d17c centralized ~205 resolvers onto pi_micro_agents.strict_mode. +resolve_strict_mode but explicitly deferred the outliers. Of the remaining real +(importable) outliers, six still resolved strict mode inline AND read +~/.antigravitycli directly — the last copies of the scattered-config footgun. +This pins that they now (a) delegate to the central resolver, (b) pass the right +env key, (c) preserve env-var behavior, and (d) no longer read the config path +themselves. + +(The other ~19 "outliers" are dead, unparseable string-literal blob files that +do not import or ship — out of scope here; a separate dead-file cleanup.) +""" + +from __future__ import annotations + +import importlib +from pathlib import Path + +import pytest + +# real modules whose module-level is_strict_mode() is consolidated +CONSOLIDATED = { + "pi_micro_agents.orchestrator.core": "PI_ORCHESTRATOR_STRICT_MODE", + "pi_micro_agents.pi_cot_shadow": "PI_COT_STRICT_MODE", + "pi_micro_agents.pi_prompt_shield": "PI_SHIELD_STRICT_MODE", + "pi_micro_agents.pi_schema_ghost": "PI_GHOST_STRICT_MODE", + "pi_micro_agents.pi_surplus_orchestrator": "PI_SURPLUS_STRICT_MODE", +} + +# files (relative to src/pi_micro_agents) that must no longer read the config path +NO_CONFIG_FILES = [ + "orchestrator/core.py", + "pi_cot_shadow.py", + "pi_prompt_shield.py", + "pi_schema_ghost.py", + "pi_surplus_orchestrator.py", + "utils.py", +] + + +@pytest.mark.parametrize("modname,env_key", list(CONSOLIDATED.items())) +def test_module_delegates_to_central_resolver(monkeypatch, modname, env_key): + mod = importlib.import_module(modname) + seen = {} + + def fake(key, default=True): + seen["key"] = key + return False + + monkeypatch.setattr(mod, "resolve_strict_mode", fake) + assert mod.is_strict_mode() is False # honors the central resolver + assert seen["key"] == env_key + + +def test_utils_parameterized_delegates(monkeypatch): + from pi_micro_agents import utils + + seen = {} + + def fake(key, default=True): + seen["key"] = key + return False + + monkeypatch.setattr(utils, "resolve_strict_mode", fake) + assert utils.is_strict_mode("PI_GAS_STRICT_MODE") is False + assert seen["key"] == "PI_GAS_STRICT_MODE" + + +@pytest.mark.parametrize("modname,env_key", list(CONSOLIDATED.items())) +def test_env_var_still_honored(monkeypatch, modname, env_key): + mod = importlib.import_module(modname) + monkeypatch.setenv(env_key, "false") + assert mod.is_strict_mode() is False + monkeypatch.setenv(env_key, "true") + assert mod.is_strict_mode() is True + + +def test_consolidated_files_no_longer_read_config_path(): + base = Path(importlib.import_module("pi_micro_agents").__file__).parent + for rel in NO_CONFIG_FILES: + text = (base / rel).read_text(encoding="utf-8") + assert "antigravitycli" not in text, f"{rel} still reads ~/.antigravitycli directly"