From 140ebecf5435620a122d032f7cdf2511376988a5 Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Thu, 30 Apr 2026 16:07:20 +0200 Subject: [PATCH] Refactor of responses models dumping --- src/app/endpoints/responses.py | 7 +- src/models/api/requests/responses_openai.py | 12 ++ .../common/responses/responses_api_params.py | 26 ++-- src/models/utils.py | 39 ++++++ tests/unit/app/endpoints/test_responses.py | 34 ++++-- tests/unit/models/test_utils.py | 111 ++++++++++++++++++ tests/unit/utils/test_types.py | 6 +- 7 files changed, 202 insertions(+), 33 deletions(-) create mode 100644 src/models/utils.py create mode 100644 tests/unit/models/test_utils.py diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 56a98c188..8f50ee938 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -598,12 +598,7 @@ async def responses_endpoint_handler( original_request.input, inline_rag_context.context_text ) - api_params = ResponsesApiParams.model_validate( - { - **updated_request.model_dump(exclude={"tools"}), - "tools": updated_request.tools, - } - ) + api_params = ResponsesApiParams.model_validate(updated_request.model_dump()) context = ResponsesContext( client=client, auth=auth, diff --git a/src/models/api/requests/responses_openai.py b/src/models/api/requests/responses_openai.py index 2b05f4767..5c99a58c3 100644 --- a/src/models/api/requests/responses_openai.py +++ b/src/models/api/requests/responses_openai.py @@ -23,6 +23,7 @@ from constants import RESPONSES_REQUEST_MAX_SIZE from models.common.query import SolrVectorSearchRequest from models.common.responses.types import IncludeParameter, ResponseInput +from models.utils import add_mcp_authorizations from utils import suid @@ -176,3 +177,14 @@ def check_previous_response_id(cls, value: Optional[str]) -> Optional[str]: if value is not None and value.startswith("modr"): raise ValueError("You cannot provide context by moderation response.") return value + + def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + """Serialize to a request body dict. + + Returns: + Serializable dict with MCP authorizations preserved. + """ + result = super().model_dump(*args, **kwargs) + if result.get("tools") is not None and self.tools is not None: + result["tools"] = add_mcp_authorizations(result["tools"], self.tools) + return result diff --git a/src/models/common/responses/responses_api_params.py b/src/models/common/responses/responses_api_params.py index acb219c89..75eda77ba 100644 --- a/src/models/common/responses/responses_api_params.py +++ b/src/models/common/responses/responses_api_params.py @@ -24,6 +24,7 @@ from pydantic import BaseModel, Field from models.common.responses.types import IncludeParameter, ResponseInput +from models.utils import add_mcp_authorizations from utils.tool_formatter import translate_vector_store_ids_to_user_facing # Attribute names that are echoed back in the response. @@ -126,28 +127,19 @@ class ResponsesApiParams(BaseModel): ) def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - """Serialize params, re-injecting MCP authorization stripped by exclude=True. + """Serialize to a request body dict. - llama-stack-api marks ``InputToolMCP.authorization`` with - ``Field(exclude=True)`` to prevent token leakage in API responses. - The base ``model_dump()`` therefore strips the field, but we need it - in the request payload so llama-stack server can authenticate with - MCP servers. See LCORE-1414 / GitHub issue #1269. + Omits conversation when previous_response_id is set; restores MCP + authorization on dumped tool rows. + + Returns: + Serializable dict for the Responses API request body. """ result = super().model_dump(*args, **kwargs) - # Only one context option is allowed, previous_response_id has priority - # Turn is added to conversation manually if previous_response_id is used if self.previous_response_id: result.pop("conversation", None) - dumped_tools = result.get("tools") - if not self.tools or not isinstance(dumped_tools, list): - return result - if len(dumped_tools) != len(self.tools): - return result - for tool, dumped_tool in zip(self.tools, dumped_tools): - authorization = getattr(tool, "authorization", None) - if authorization is not None and isinstance(dumped_tool, dict): - dumped_tool["authorization"] = authorization + if self.tools is not None and result.get("tools") is not None: + result["tools"] = add_mcp_authorizations(result["tools"], self.tools) return result def echoed_params(self, rag_id_mapping: Mapping[str, str]) -> dict[str, Any]: diff --git a/src/models/utils.py b/src/models/utils.py new file mode 100644 index 000000000..7d757c967 --- /dev/null +++ b/src/models/utils.py @@ -0,0 +1,39 @@ +"""Utility functions for models.""" + +from typing import Any + +from llama_stack_api.openai_responses import OpenAIResponseInputTool as InputTool + + +def add_mcp_authorizations( + dumped_tools: list[dict[str, Any]], + tools: list[InputTool], +) -> list[dict[str, Any]]: + """Merge MCP authorization into serialized tool dicts keyed by server_label. + + Args: + dumped_tools: Serialized tools. + tools: Live tool models. MCP entries with authorization are mapped by + server_label. + + Returns: + A new list of dicts. For MCP rows, authorization is set only when a + matching non-None token exists. + """ + authorizations = { + tool.server_label: tool.authorization + for tool in tools + if tool.type == "mcp" and tool.authorization is not None + } # server_labels are unique by design + result: list[dict[str, Any]] = [] + for dumped in dumped_tools: + row = dict(dumped) + if ( + row.get("type") == "mcp" + and (label := row.get("server_label")) is not None + and (token := authorizations.get(label)) is not None + ): + row["authorization"] = token + + result.append(row) + return result diff --git a/tests/unit/app/endpoints/test_responses.py b/tests/unit/app/endpoints/test_responses.py index f6b477e78..352f83876 100644 --- a/tests/unit/app/endpoints/test_responses.py +++ b/tests/unit/app/endpoints/test_responses.py @@ -12,7 +12,12 @@ from llama_stack_api.openai_responses import ( OpenAIResponseInputToolChoiceMode as ToolChoiceMode, ) -from llama_stack_api.openai_responses import OpenAIResponseMessage +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolMCP as InputToolMCP, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseMessage, +) from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient from pytest_mock import MockerFixture @@ -71,12 +76,7 @@ def build_api_params_and_context( # pylint: disable=too-many-arguments user_agent: Optional[str] = None, ) -> tuple[ResponsesApiParams, ResponsesContext]: """Build api_params/context for direct helper invocation tests.""" - api_params = ResponsesApiParams.model_validate( - { - **updated_request.model_dump(exclude={"tools"}), - "tools": updated_request.tools, - } - ) + api_params = ResponsesApiParams.model_validate(updated_request.model_dump()) context = ResponsesContext.model_construct( client=client, auth=auth, @@ -94,6 +94,26 @@ def build_api_params_and_context( # pylint: disable=too-many-arguments return api_params, context +def test_responses_api_params_preserves_mcp_authorization() -> None: + """After model_validate, MCP tool authorization from model_dump is kept on api_params.tools.""" + token = "secret-token" + req = ResponsesRequest( + input="x", + model=MODEL, + conversation=VALID_CONV_ID, + tools=[ + InputToolMCP( + server_label="alpha", + server_url="http://alpha", + require_approval="never", + authorization=token, + ) + ], + ) + api = ResponsesApiParams.model_validate(req.model_dump()) + assert api.tools is not None and api.tools[0].authorization == token + + def _patch_base(mocker: MockerFixture, config: AppConfig) -> None: """Patch configuration and mandatory checks for responses endpoint.""" mocker.patch(f"{MODULE}.configuration", config) diff --git a/tests/unit/models/test_utils.py b/tests/unit/models/test_utils.py new file mode 100644 index 000000000..53919af0e --- /dev/null +++ b/tests/unit/models/test_utils.py @@ -0,0 +1,111 @@ +"""Unit tests for models.utils (mirrors src/models/utils.py).""" + +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolFileSearch as InputToolFileSearch, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolMCP as InputToolMCP, +) + +from models.utils import add_mcp_authorizations + + +class TestAddMcpAuthorizations: + """Tests for add_mcp_authorizations with realistic MCP tool rows. + + Assumes server_label is present on MCP dicts and unique across configured + servers; see InputToolMCP in llama-stack-api. + """ + + def test_merges_authorization_by_server_label(self) -> None: + """MCP model_dump omits authorization; the helper restores it by server_label.""" + live = InputToolMCP( + server_label="alpha", + server_url="http://alpha", + require_approval="never", + authorization="secret-token", + ) + dumped = [live.model_dump()] + assert "authorization" not in dumped[0] + + out = add_mcp_authorizations(dumped, [live]) + assert len(out) == 1 + assert out[0]["authorization"] == "secret-token" + assert out[0]["server_label"] == "alpha" + + def test_two_mcp_servers_distinct_tokens(self) -> None: + """Each server_label receives its own authorization.""" + a = InputToolMCP( + server_label="srv-a", + server_url="http://a", + require_approval="never", + authorization="token-a", + ) + b = InputToolMCP( + server_label="srv-b", + server_url="http://b", + require_approval="never", + authorization="token-b", + ) + dumped = [a.model_dump(), b.model_dump()] + assert "authorization" not in dumped[0] + assert "authorization" not in dumped[1] + + out = add_mcp_authorizations(dumped, [a, b]) + assert out[0]["authorization"] == "token-a" + assert out[1]["authorization"] == "token-b" + + def test_file_search_row_unchanged_no_authorization_merge(self) -> None: + """Non-MCP rows are copied; MCP row still gets auth from live list.""" + mcp = InputToolMCP( + server_label="m", + server_url="http://m", + require_approval="never", + authorization="mcp-secret", + ) + fs = InputToolFileSearch(type="file_search", vector_store_ids=["vs-1"]) + dumped = [fs.model_dump(), mcp.model_dump()] + assert "authorization" not in dumped[1] + + out = add_mcp_authorizations(dumped, [fs, mcp]) + assert out[0]["type"] == "file_search" + assert "authorization" not in out[0] + assert out[1]["authorization"] == "mcp-secret" + + def test_subset_dumped_rows_still_match_live_by_label(self) -> None: + """When only some MCP tools appear in dumped_tools, labels still align.""" + first = InputToolMCP( + server_label="one", + server_url="http://one", + require_approval="never", + authorization="tok-one", + ) + second = InputToolMCP( + server_label="two", + server_url="http://two", + require_approval="never", + authorization="tok-two", + ) + dumped = [second.model_dump()] + assert "authorization" not in dumped[0] + + out = add_mcp_authorizations(dumped, [first, second]) + assert len(out) == 1 + assert out[0]["authorization"] == "tok-two" + + def test_does_not_mutate_input_list_or_dicts(self) -> None: + """Output is new containers; inputs stay as provided.""" + live = InputToolMCP( + server_label="s", + server_url="http://s", + require_approval="never", + authorization="t", + ) + dumped = [live.model_dump()] + row = dumped[0] + assert "authorization" not in row + + out = add_mcp_authorizations(dumped, [live]) + assert out is not dumped + assert out[0] is not row + assert "authorization" not in row diff --git a/tests/unit/utils/test_types.py b/tests/unit/utils/test_types.py index 8447054f0..3a6f707e7 100644 --- a/tests/unit/utils/test_types.py +++ b/tests/unit/utils/test_types.py @@ -240,8 +240,8 @@ def test_multiple_mcp_tools_each_preserves_authorization(self) -> None: assert dumped["tools"][0]["authorization"] == "token-a" assert dumped["tools"][1]["authorization"] == "token-b" - def test_exclude_changing_tool_list_shape_skips_reinjection(self) -> None: - """Test that exclude removing tool indices does not mis-assign authorization.""" + def test_partial_tool_dump_reinjects_auth_by_server_label(self) -> None: + """When exclude drops some tools, remaining MCP rows still get auth by label.""" tool_a = InputToolMCP( server_label="server-a", server_url="http://a:3000", @@ -258,7 +258,7 @@ def test_exclude_changing_tool_list_shape_skips_reinjection(self) -> None: dumped = params.model_dump(exclude={"tools": {0}}) assert len(dumped["tools"]) == 1 assert dumped["tools"][0]["server_label"] == "server-b" - assert "authorization" not in dumped["tools"][0] + assert dumped["tools"][0]["authorization"] == "token-b" def test_no_tools_does_not_error(self) -> None: """Test that model_dump() works when tools is None."""