Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/app/endpoints/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,7 @@ async def responses_endpoint_handler(
original_request.input, inline_rag_context.context_text
)

api_params = ResponsesApiParams.model_validate(
{
**updated_request.model_dump(exclude={"tools"}),
"tools": updated_request.tools,
}
)
api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
context = ResponsesContext(
client=client,
auth=auth,
Expand Down
12 changes: 12 additions & 0 deletions src/models/api/requests/responses_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from constants import RESPONSES_REQUEST_MAX_SIZE
from models.common.query import SolrVectorSearchRequest
from models.common.responses.types import IncludeParameter, ResponseInput
from models.utils import add_mcp_authorizations
from utils import suid


Expand Down Expand Up @@ -176,3 +177,14 @@ def check_previous_response_id(cls, value: Optional[str]) -> Optional[str]:
if value is not None and value.startswith("modr"):
raise ValueError("You cannot provide context by moderation response.")
return value

def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
"""Serialize to a request body dict.

Returns:
Serializable dict with MCP authorizations preserved.
"""
result = super().model_dump(*args, **kwargs)
if result.get("tools") is not None and self.tools is not None:
result["tools"] = add_mcp_authorizations(result["tools"], self.tools)
return result
Comment thread
asimurka marked this conversation as resolved.
26 changes: 9 additions & 17 deletions src/models/common/responses/responses_api_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pydantic import BaseModel, Field

from models.common.responses.types import IncludeParameter, ResponseInput
from models.utils import add_mcp_authorizations
from utils.tool_formatter import translate_vector_store_ids_to_user_facing

# Attribute names that are echoed back in the response.
Expand Down Expand Up @@ -126,28 +127,19 @@ class ResponsesApiParams(BaseModel):
)

def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
"""Serialize params, re-injecting MCP authorization stripped by exclude=True.
"""Serialize to a request body dict.

llama-stack-api marks ``InputToolMCP.authorization`` with
``Field(exclude=True)`` to prevent token leakage in API responses.
The base ``model_dump()`` therefore strips the field, but we need it
in the request payload so llama-stack server can authenticate with
MCP servers. See LCORE-1414 / GitHub issue #1269.
Omits conversation when previous_response_id is set; restores MCP
authorization on dumped tool rows.

Returns:
Serializable dict for the Responses API request body.
"""
result = super().model_dump(*args, **kwargs)
# Only one context option is allowed, previous_response_id has priority
# Turn is added to conversation manually if previous_response_id is used
if self.previous_response_id:
result.pop("conversation", None)
dumped_tools = result.get("tools")
if not self.tools or not isinstance(dumped_tools, list):
return result
if len(dumped_tools) != len(self.tools):
return result
for tool, dumped_tool in zip(self.tools, dumped_tools):
authorization = getattr(tool, "authorization", None)
if authorization is not None and isinstance(dumped_tool, dict):
dumped_tool["authorization"] = authorization
if self.tools is not None and result.get("tools") is not None:
result["tools"] = add_mcp_authorizations(result["tools"], self.tools)
return result

def echoed_params(self, rag_id_mapping: Mapping[str, str]) -> dict[str, Any]:
Expand Down
39 changes: 39 additions & 0 deletions src/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Utility functions for models."""

from typing import Any

from llama_stack_api.openai_responses import OpenAIResponseInputTool as InputTool


def add_mcp_authorizations(
dumped_tools: list[dict[str, Any]],
tools: list[InputTool],
) -> list[dict[str, Any]]:
"""Merge MCP authorization into serialized tool dicts keyed by server_label.

Args:
dumped_tools: Serialized tools.
tools: Live tool models. MCP entries with authorization are mapped by
server_label.

Returns:
A new list of dicts. For MCP rows, authorization is set only when a
matching non-None token exists.
"""
authorizations = {
tool.server_label: tool.authorization
for tool in tools
if tool.type == "mcp" and tool.authorization is not None
} # server_labels are unique by design
result: list[dict[str, Any]] = []
for dumped in dumped_tools:
row = dict(dumped)
if (
row.get("type") == "mcp"
and (label := row.get("server_label")) is not None
and (token := authorizations.get(label)) is not None
):
row["authorization"] = token

result.append(row)
return result
34 changes: 27 additions & 7 deletions tests/unit/app/endpoints/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from llama_stack_api.openai_responses import (
OpenAIResponseInputToolChoiceMode as ToolChoiceMode,
)
from llama_stack_api.openai_responses import OpenAIResponseMessage
from llama_stack_api.openai_responses import (
OpenAIResponseInputToolMCP as InputToolMCP,
)
from llama_stack_api.openai_responses import (
OpenAIResponseMessage,
)
from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient
from pytest_mock import MockerFixture

Expand Down Expand Up @@ -71,12 +76,7 @@ def build_api_params_and_context( # pylint: disable=too-many-arguments
user_agent: Optional[str] = None,
) -> tuple[ResponsesApiParams, ResponsesContext]:
"""Build api_params/context for direct helper invocation tests."""
api_params = ResponsesApiParams.model_validate(
{
**updated_request.model_dump(exclude={"tools"}),
"tools": updated_request.tools,
}
)
api_params = ResponsesApiParams.model_validate(updated_request.model_dump())
context = ResponsesContext.model_construct(
client=client,
auth=auth,
Expand All @@ -94,6 +94,26 @@ def build_api_params_and_context( # pylint: disable=too-many-arguments
return api_params, context


def test_responses_api_params_preserves_mcp_authorization() -> None:
"""After model_validate, MCP tool authorization from model_dump is kept on api_params.tools."""
token = "secret-token"
req = ResponsesRequest(
input="x",
model=MODEL,
conversation=VALID_CONV_ID,
tools=[
InputToolMCP(
server_label="alpha",
server_url="http://alpha",
require_approval="never",
authorization=token,
)
],
)
api = ResponsesApiParams.model_validate(req.model_dump())
assert api.tools is not None and api.tools[0].authorization == token


def _patch_base(mocker: MockerFixture, config: AppConfig) -> None:
"""Patch configuration and mandatory checks for responses endpoint."""
mocker.patch(f"{MODULE}.configuration", config)
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/models/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Unit tests for models.utils (mirrors src/models/utils.py)."""

from llama_stack_api.openai_responses import (
OpenAIResponseInputToolFileSearch as InputToolFileSearch,
)
from llama_stack_api.openai_responses import (
OpenAIResponseInputToolMCP as InputToolMCP,
)

from models.utils import add_mcp_authorizations


class TestAddMcpAuthorizations:
"""Tests for add_mcp_authorizations with realistic MCP tool rows.

Assumes server_label is present on MCP dicts and unique across configured
servers; see InputToolMCP in llama-stack-api.
"""

def test_merges_authorization_by_server_label(self) -> None:
"""MCP model_dump omits authorization; the helper restores it by server_label."""
live = InputToolMCP(
server_label="alpha",
server_url="http://alpha",
require_approval="never",
authorization="secret-token",
)
dumped = [live.model_dump()]
assert "authorization" not in dumped[0]

out = add_mcp_authorizations(dumped, [live])
assert len(out) == 1
assert out[0]["authorization"] == "secret-token"
assert out[0]["server_label"] == "alpha"

def test_two_mcp_servers_distinct_tokens(self) -> None:
"""Each server_label receives its own authorization."""
a = InputToolMCP(
server_label="srv-a",
server_url="http://a",
require_approval="never",
authorization="token-a",
)
b = InputToolMCP(
server_label="srv-b",
server_url="http://b",
require_approval="never",
authorization="token-b",
)
dumped = [a.model_dump(), b.model_dump()]
assert "authorization" not in dumped[0]
assert "authorization" not in dumped[1]

out = add_mcp_authorizations(dumped, [a, b])
assert out[0]["authorization"] == "token-a"
assert out[1]["authorization"] == "token-b"

def test_file_search_row_unchanged_no_authorization_merge(self) -> None:
"""Non-MCP rows are copied; MCP row still gets auth from live list."""
mcp = InputToolMCP(
server_label="m",
server_url="http://m",
require_approval="never",
authorization="mcp-secret",
)
fs = InputToolFileSearch(type="file_search", vector_store_ids=["vs-1"])
dumped = [fs.model_dump(), mcp.model_dump()]
assert "authorization" not in dumped[1]

out = add_mcp_authorizations(dumped, [fs, mcp])
assert out[0]["type"] == "file_search"
assert "authorization" not in out[0]
assert out[1]["authorization"] == "mcp-secret"

def test_subset_dumped_rows_still_match_live_by_label(self) -> None:
"""When only some MCP tools appear in dumped_tools, labels still align."""
first = InputToolMCP(
server_label="one",
server_url="http://one",
require_approval="never",
authorization="tok-one",
)
second = InputToolMCP(
server_label="two",
server_url="http://two",
require_approval="never",
authorization="tok-two",
)
dumped = [second.model_dump()]
assert "authorization" not in dumped[0]

out = add_mcp_authorizations(dumped, [first, second])
assert len(out) == 1
assert out[0]["authorization"] == "tok-two"

def test_does_not_mutate_input_list_or_dicts(self) -> None:
"""Output is new containers; inputs stay as provided."""
live = InputToolMCP(
server_label="s",
server_url="http://s",
require_approval="never",
authorization="t",
)
dumped = [live.model_dump()]
row = dumped[0]
assert "authorization" not in row

out = add_mcp_authorizations(dumped, [live])
assert out is not dumped
assert out[0] is not row
assert "authorization" not in row
6 changes: 3 additions & 3 deletions tests/unit/utils/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def test_multiple_mcp_tools_each_preserves_authorization(self) -> None:
assert dumped["tools"][0]["authorization"] == "token-a"
assert dumped["tools"][1]["authorization"] == "token-b"

def test_exclude_changing_tool_list_shape_skips_reinjection(self) -> None:
"""Test that exclude removing tool indices does not mis-assign authorization."""
def test_partial_tool_dump_reinjects_auth_by_server_label(self) -> None:
"""When exclude drops some tools, remaining MCP rows still get auth by label."""
tool_a = InputToolMCP(
server_label="server-a",
server_url="http://a:3000",
Expand All @@ -258,7 +258,7 @@ def test_exclude_changing_tool_list_shape_skips_reinjection(self) -> None:
dumped = params.model_dump(exclude={"tools": {0}})
assert len(dumped["tools"]) == 1
assert dumped["tools"][0]["server_label"] == "server-b"
assert "authorization" not in dumped["tools"][0]
assert dumped["tools"][0]["authorization"] == "token-b"

def test_no_tools_does_not_error(self) -> None:
"""Test that model_dump() works when tools is None."""
Expand Down
Loading