diff --git a/docs/openapi.json b/docs/openapi.json index 8ff6e171e..1486b9af6 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -19574,15 +19574,15 @@ "round": { "type": "integer", "title": "Round", - "description": "Round number or step of tool execution" + "description": "Round number or step of tool execution", + "default": 1 } }, "type": "object", "required": [ "id", "status", - "content", - "round" + "content" ], "title": "ToolResultSummary", "description": "Model representing a result from a tool call (for tool_results list)." diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 6079d2aa4..90dff71ce 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -4,34 +4,14 @@ import asyncio import datetime -import json from collections.abc import AsyncIterator -from typing import Annotated, Any, Optional, cast +from typing import Annotated, Any, cast from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse from llama_stack_api import ( - OpenAIResponseObject, OpenAIResponseObjectStream, ) -from llama_stack_api import ( - OpenAIResponseObjectStreamResponseMcpCallArgumentsDone as MCPArgsDoneChunk, -) -from llama_stack_api import ( - OpenAIResponseObjectStreamResponseOutputItemAdded as OutputItemAddedChunk, -) -from llama_stack_api import ( - OpenAIResponseObjectStreamResponseOutputItemDone as OutputItemDoneChunk, -) -from llama_stack_api import ( - OpenAIResponseObjectStreamResponseOutputTextDelta as TextDeltaChunk, -) -from llama_stack_api import ( - OpenAIResponseObjectStreamResponseOutputTextDone as TextDoneChunk, -) -from llama_stack_api import ( - OpenAIResponseOutputMessageMCPCall as MCPCall, -) from llama_stack_client import ( APIConnectionError, ) @@ -49,10 +29,6 @@ from constants import ( ENDPOINT_PATH_STREAMING_QUERY, INTERRUPTED_RESPONSE_MESSAGE, - LLM_TOKEN_EVENT, - LLM_TOOL_CALL_EVENT, - LLM_TOOL_RESULT_EVENT, - LLM_TURN_COMPLETE_EVENT, MEDIA_TYPE_EVENT_STREAM, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT, @@ -63,7 +39,6 @@ from models.api.requests import QueryRequest from models.api.responses.constants import UNAUTHORIZED_OPENAPI_EXAMPLES_WITH_MCP_OAUTH from models.api.responses.error import ( - AbstractErrorResponse, ForbiddenResponse, InternalServerErrorResponse, NotFoundResponse, @@ -75,7 +50,7 @@ ) from models.api.responses.successful import StreamingQueryResponse from models.common.responses.responses_api_params import ResponsesApiParams -from models.common.turn_summary import ReferencedDocument, TurnSummary +from models.common.turn_summary import TurnSummary from models.config import Action from models.context import ResponseGeneratorContext from utils.conversations import append_turn_items_to_conversation @@ -99,9 +74,6 @@ ) from utils.quota import check_tokens_available, get_available_quotas from utils.responses import ( - build_mcp_tool_call_from_arguments_done, - build_tool_call_summary, - build_tool_result_from_mcp_output_item_done, deduplicate_referenced_documents, extract_token_usage, extract_vector_store_ids_from_tools, @@ -116,8 +88,17 @@ validate_shield_ids_override, ) from utils.stream_interrupts import get_stream_interrupt_registry +from utils.streaming.chunk_dispatchers import dispatch_stream_chunk +from utils.streaming.event_serializers import ( + serialize_end_event, + serialize_event, + serialize_http_error_event, + serialize_interrupted_event, + serialize_start_event, +) +from utils.streaming.state import StreamDispatchState 
+from utils.streaming.stream_payloads import LlmTokenChunkData, LlmTokenStreamPayload from utils.suid import get_suid, normalize_conversation_id -from utils.token_counter import TokenCounter from utils.vector_search import build_rag_context logger = get_logger(__name__) @@ -593,7 +574,7 @@ async def generate_response( stream_completed = False try: - yield stream_start_event( + yield serialize_start_event( conversation_id=context.conversation_id, request_id=context.request_id, ) @@ -611,16 +592,22 @@ async def generate_response( if is_context_length_error(str(e)) else InternalServerErrorResponse.generic() ) - yield stream_http_error_event(error_response, context.query_request.media_type) + yield serialize_http_error_event( + error_response, context.query_request.media_type + ) except APIConnectionError as e: error_response = ServiceUnavailableResponse( backend_name="Llama Stack", cause=str(e), ) - yield stream_http_error_event(error_response, context.query_request.media_type) + yield serialize_http_error_event( + error_response, context.query_request.media_type + ) except (LLSApiStatusError, OpenAIAPIStatusError) as e: error_response = handle_known_apistatus_errors(e, responses_params.model) - yield stream_http_error_event(error_response, context.query_request.media_type) + yield serialize_http_error_event( + error_response, context.query_request.media_type + ) except asyncio.CancelledError: logger.info("Streaming request %s interrupted by user", context.request_id) current_task = asyncio.current_task() @@ -630,7 +617,7 @@ async def generate_response( persist_guard[0] = True turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE await _persist_interrupted_turn(context, responses_params, turn_summary) - yield stream_interrupted_event(context.request_id) + yield serialize_interrupted_event(context.request_id) finally: get_stream_interrupt_registry().deregister_stream(context.request_id) @@ -664,7 +651,7 @@ async def generate_response( quota_limiters=configuration.quota_limiters, user_id=context.user_id ) - yield stream_end_event( + yield serialize_end_event( turn_summary.token_usage, available_quotas, turn_summary.referenced_documents, @@ -688,7 +675,7 @@ async def generate_response( ) -async def response_generator( # pylint: disable=too-many-branches,too-many-statements,too-many-locals +async def response_generator( turn_response: AsyncIterator[OpenAIResponseObjectStream], context: ResponseGeneratorContext, turn_summary: TurnSummary, @@ -710,157 +697,29 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat SSE-formatted strings for tokens, tool calls, tool results, turn completion, and error events. 
""" - chunk_id = 0 media_type = context.query_request.media_type or MEDIA_TYPE_JSON - text_parts: list[str] = [] - mcp_calls: dict[int, tuple[str, str]] = ( - {} - ) # output_index -> (mcp_call_id, mcp_call_name) - latest_response_object: Optional[OpenAIResponseObject] = None + dispatch_state = StreamDispatchState() logger.debug("Starting streaming response (Responses API) processing") - async for chunk in turn_response: - event_type = getattr(chunk, "type", None) - logger.debug("Processing chunk %d, type: %s", chunk_id, event_type) - - # Content part started - emit an empty token to kick off UI streaming - if event_type == "response.content_part.added": - yield stream_event( - { - "id": chunk_id, - "token": "", - }, - LLM_TOKEN_EVENT, - media_type, - ) - chunk_id += 1 - - # Store MCP call item info for later lookup when arguments.done event occurs - elif event_type == "response.output_item.added": - item_added_chunk = cast(OutputItemAddedChunk, chunk) - if item_added_chunk.item.type == "mcp_call": - mcp_call_item = cast(MCPCall, item_added_chunk.item) - mcp_calls[item_added_chunk.output_index] = ( - mcp_call_item.id, - mcp_call_item.name, - ) - - # Text streaming - emit token delta - elif event_type == "response.output_text.delta": - delta_chunk = cast(TextDeltaChunk, chunk) - text_parts.append(delta_chunk.delta) - yield stream_event( - { - "id": chunk_id, - "token": delta_chunk.delta, - }, - LLM_TOKEN_EVENT, - media_type, - ) - chunk_id += 1 - - # Final text of the output (capture, but emit at response.completed) - elif event_type == "response.output_text.done": - text_done_chunk = cast(TextDoneChunk, chunk) - turn_summary.llm_response = text_done_chunk.text - - # Emit tool call when MCP call arguments are done - elif event_type == "response.mcp_call.arguments.done": - mcp_arguments_done_chunk = cast(MCPArgsDoneChunk, chunk) - tool_call = build_mcp_tool_call_from_arguments_done( - mcp_arguments_done_chunk.output_index, - mcp_arguments_done_chunk.arguments, - mcp_calls, - ) - if tool_call: - turn_summary.tool_calls.append(tool_call) - yield stream_event( - tool_call.model_dump(), - LLM_TOOL_CALL_EVENT, - media_type, - ) - - # Process tool calls and results when output items are done - # For mcp_call, only emit result (call was already emitted when arguments.done) - # For other types, emit both call and result - elif event_type == "response.output_item.done": - output_item_done_chunk = cast(OutputItemDoneChunk, chunk) - item_type = output_item_done_chunk.item.type - # Skip message items as they are parsed separately - if item_type == "message": - continue - - output_index = output_item_done_chunk.output_index - - # For mcp_call, only emit result if call was already emitted when arguments.done - # (indicated by output_index not being in mcp_calls dict) - # If output_index is in dict, process in else branch (emit both call and result) - if item_type == "mcp_call" and output_index not in mcp_calls: - # Call was already emitted during arguments.done, only emit result - mcp_call_item = cast(MCPCall, output_item_done_chunk.item) - tool_result = build_tool_result_from_mcp_output_item_done(mcp_call_item) - turn_summary.tool_results.append(tool_result) - yield stream_event( - tool_result.model_dump(), - LLM_TOOL_RESULT_EVENT, - media_type, - ) - else: - # For all other types (and mcp_call when arguments.done didn't happen), - # emit both call and result together - tool_call, tool_result = build_tool_call_summary( - output_item_done_chunk.item - ) - if tool_call: - 
turn_summary.tool_calls.append(tool_call)
-                    yield stream_event(
-                        tool_call.model_dump(),
-                        LLM_TOOL_CALL_EVENT,
-                        media_type,
-                    )
-                if tool_result:
-                    turn_summary.tool_results.append(tool_result)
-                    yield stream_event(
-                        tool_result.model_dump(),
-                        LLM_TOOL_RESULT_EVENT,
-                        media_type,
-                    )
-
-        # Completed response - capture final text and response object
-        elif event_type == "response.completed":
-            latest_response_object = cast(
-                OpenAIResponseObject,
-                getattr(chunk, "response"),  # noqa: B009
-            )
-            turn_summary.llm_response = turn_summary.llm_response or "".join(text_parts)
-            yield stream_event(
-                {
-                    "id": chunk_id,
-                    "token": turn_summary.llm_response,
-                },
-                LLM_TURN_COMPLETE_EVENT,
-                media_type,
-            )
-            chunk_id += 1
+    async for chunk in turn_response:
+        logger.debug(
+            "Processing chunk %d, type: %s",
+            dispatch_state.chunk_id,
+            chunk.type,
+        )
+        dispatch_result = dispatch_stream_chunk(
+            chunk,
+            dispatch_state,
+            media_type,
+            context.model_id,
+        )
+        dispatch_state = dispatch_result.state
+        for event in dispatch_result.events:
+            yield event
 
-        # Incomplete or failed response - emit error
-        elif event_type in ("response.incomplete", "response.failed"):
-            latest_response_object = cast(
-                OpenAIResponseObject,
-                getattr(chunk, "response"),  # noqa: B009
-            )
-            error_message = (
-                latest_response_object.error.message
-                if latest_response_object.error
-                else "An unexpected error occurred while processing the request."
-            )
-            error_response = (
-                PromptTooLongResponse(model=context.model_id)
-                if is_context_length_error(error_message)
-                else InternalServerErrorResponse.query_failed(error_message)
-            )
-            yield stream_http_error_event(error_response, media_type)
+    turn_summary.llm_response = dispatch_state.llm_response
+    turn_summary.tool_calls.extend(dispatch_state.tool_calls)
+    turn_summary.tool_results.extend(dispatch_state.tool_results)
 
     logger.debug(
         "Streaming complete - Tool calls: %d, Response chars: %d",
@@ -869,15 +728,16 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
     )
 
     # Extract token usage and referenced documents from the final response object
-    if not latest_response_object:
+    if dispatch_state.latest_response_object is None:
         return
+    final_response = dispatch_state.latest_response_object
     turn_summary.token_usage = extract_token_usage(
-        latest_response_object.usage, context.model_id, endpoint_path
+        final_response.usage, context.model_id, endpoint_path
     )
 
     # Parse tool-based referenced documents from the final response object
     tool_rag_docs = parse_referenced_documents(
-        latest_response_object,
+        final_response,
         vector_store_ids=context.vector_store_ids,
         rag_id_mapping=context.rag_id_mapping,
    )
@@ -886,191 +746,13 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
         context.inline_rag_context.referenced_documents + tool_rag_docs
     )
     tool_rag_chunks = parse_rag_chunks(
-        latest_response_object,
+        final_response,
         vector_store_ids=context.vector_store_ids,
         rag_id_mapping=context.rag_id_mapping,
     )
     turn_summary.rag_chunks = context.inline_rag_context.rag_chunks + tool_rag_chunks
 
 
-def stream_http_error_event(
-    error: AbstractErrorResponse, media_type: Optional[str] = MEDIA_TYPE_JSON
-) -> str:
-    """
-    Create an SSE-formatted error response for generic LLM or API errors.
-
-    Args:
-        error: An AbstractErrorResponse instance representing the error.
-        media_type: The media type for the response format. Defaults to MEDIA_TYPE_JSON if None.
-
-    Returns:
-        str: A Server-Sent Events (SSE) formatted error message containing
-            the serialized error details.
- """ - logger.error("Error while obtaining answer for user question") - media_type = media_type or MEDIA_TYPE_JSON - if media_type == MEDIA_TYPE_TEXT: - return f"Status: {error.status_code} - {error.detail.response} - {error.detail.cause}" - - return format_stream_data( - { - "event": "error", - "data": { - "status_code": error.status_code, - "response": error.detail.response, - "cause": error.detail.cause, - }, - } - ) - - -def format_stream_data(d: dict) -> str: - """ - Create a response generator function for Responses API streaming. - - Parameters: - ---------- - d (dict): The data to be formatted as an SSE event. - - Returns: - ------- - str: The formatted SSE data string. - """ - data = json.dumps(d) - return f"data: {data}\n\n" - - -def stream_start_event(conversation_id: str, request_id: str) -> str: - """Format an SSE start event for a streaming response. - - The payload contains both the conversation ID and the request ID - so the client can correlate the stream with a conversation and - use the request ID to issue an interrupt if needed. - - Parameters: - ---------- - conversation_id (str): Unique identifier for the conversation. - request_id (str): Unique SUID for this streaming request, - returned to the client for interrupt support. - - Returns: - ------- - str: SSE-formatted string representing the start event. - """ - return format_stream_data( - { - "event": "start", - "data": { - "conversation_id": conversation_id, - "request_id": request_id, - }, - } - ) - - -def stream_interrupted_event(request_id: str) -> str: - """Format an SSE event indicating the stream was interrupted. - - Emitted to the client just before the generator closes so the - frontend can distinguish an intentional user-initiated interruption - from an unexpected connection drop. - - Parameters: - ---------- - request_id (str): Unique identifier for the interrupted request. - - Returns: - ------- - str: SSE-formatted string representing the interrupted event. - """ - return format_stream_data( - { - "event": "interrupted", - "data": { - "request_id": request_id, - }, - } - ) - - -def stream_end_event( - token_usage: TokenCounter, - available_quotas: dict[str, int], - referenced_documents: list[ReferencedDocument], - media_type: str = MEDIA_TYPE_JSON, -) -> str: - """ - Yield the end of the data stream. - - Format and return the end event for a streaming response, - including referenced document metadata and token usage information. - - Parameters: - ---------- - token_usage (TokenCounter): Token usage information. - available_quotas (dict[str, int]): Available quotas for the user. - referenced_documents (list[ReferencedDocument]): List of referenced documents. - media_type (str): The media type for the response format. - - Returns: - ------- - str: A Server-Sent Events (SSE) formatted string - representing the end of the data stream. 
- """ - if media_type == MEDIA_TYPE_TEXT: - ref_docs_string = "\n".join( - f"{doc.doc_title}: {doc.doc_url}" - for doc in referenced_documents - if doc.doc_url and doc.doc_title - ) - return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" - - referenced_docs_dict = [doc.model_dump(mode="json") for doc in referenced_documents] - - return format_stream_data( - { - "event": "end", - "data": { - "referenced_documents": referenced_docs_dict, - "truncated": None, - "input_tokens": token_usage.input_tokens, - "output_tokens": token_usage.output_tokens, - }, - "available_quotas": available_quotas, - } - ) - - -def stream_event(data: dict, event_type: str, media_type: str) -> str: - """Build an item to yield based on media type. - - Args: - data: Dictionary containing the event data - event_type: Type of event (token, tool call, etc.) - media_type: The media type for the response format - - Returns: - SSE-formatted string representing the event - """ - if media_type == MEDIA_TYPE_TEXT: - if event_type == LLM_TOKEN_EVENT: - return data.get("token", "") - if event_type == LLM_TOOL_CALL_EVENT: - return f"[Tool Call: {data.get('function_name', 'unknown')}]\n" - if event_type == LLM_TOOL_RESULT_EVENT: - return "[Tool Result]\n" - if event_type == LLM_TURN_COMPLETE_EVENT: - return "" - return "" - - return format_stream_data( - { - "event": event_type, - "data": data, - } - ) - - async def shield_violation_generator( violation_message: str, media_type: str = MEDIA_TYPE_TEXT, @@ -1089,11 +771,9 @@ async def shield_violation_generator( Yields: str: SSE-formatted strings for start, token, and end events. """ - yield stream_event( - { - "id": 0, - "token": violation_message, - }, - LLM_TOKEN_EVENT, + yield serialize_event( + LlmTokenStreamPayload( + data=LlmTokenChunkData(id=0, token=violation_message), + ), media_type, ) diff --git a/src/models/common/turn_summary.py b/src/models/common/turn_summary.py index 920a17c71..cf4b27b09 100644 --- a/src/models/common/turn_summary.py +++ b/src/models/common/turn_summary.py @@ -78,6 +78,33 @@ class ToolCallSummary(BaseModel): type: str = Field("tool_call", description="Type indicator for tool call") +class ToolInfoSummary(BaseModel): + """Model representing metadata for a single tool exposed by MCP list tools.""" + + name: str = Field(description="Tool name") + description: Optional[str] = Field( + default=None, + description="Human-readable tool description", + ) + input_schema: Optional[dict[str, Any]] = Field( + default=None, + description="JSON schema for the tool input", + ) + + +class MCPListToolsSummary(BaseModel): + """Model representing MCP list tools payload serialized into tool results.""" + + server_label: Optional[str] = Field( + default=None, + description="MCP server label associated with the tool list", + ) + tools: list[ToolInfoSummary] = Field( + default_factory=list, + description="Tools exposed by the MCP server", + ) + + class ToolResultSummary(BaseModel): """Model representing a result from a tool call (for tool_results list).""" @@ -89,7 +116,7 @@ class ToolResultSummary(BaseModel): ) content: str = Field(..., description="Content/result returned from the tool") type: str = Field("tool_result", description="Type indicator for tool result") - round: int = Field(..., description="Round number or step of tool execution") + round: int = Field(default=1, description="Round number or step of tool execution") class TurnSummary(BaseModel): diff --git a/src/utils/streaming/__init__.py b/src/utils/streaming/__init__.py new file mode 100644 
index 000000000..b281a3256 --- /dev/null +++ b/src/utils/streaming/__init__.py @@ -0,0 +1,58 @@ +"""Streaming utilities package.""" + +from utils.streaming.chunk_dispatchers import dispatch_stream_chunk +from utils.streaming.event_serializers import ( + format_stream_data, + serialize_end_event, + serialize_event, + serialize_http_error_event, + serialize_interrupted_event, + serialize_start_event, +) +from utils.streaming.output_item_dispatchers import dispatch_output_item_done +from utils.streaming.state import ChunkDispatchResult, StreamDispatchState +from utils.streaming.stream_payloads import ( + EndEventData, + EndStreamPayload, + ErrorEventData, + ErrorStreamPayload, + InterruptedEventData, + InterruptedStreamPayload, + LlmTokenChunkData, + LlmTokenStreamPayload, + LlmToolCallStreamPayload, + LlmToolResultStreamPayload, + LlmTurnCompleteStreamPayload, + StartEventData, + StartStreamPayload, + StreamLlmEventPayload, + StreamPayloadBase, +) + +__all__ = [ + "ChunkDispatchResult", + "EndEventData", + "EndStreamPayload", + "ErrorEventData", + "ErrorStreamPayload", + "InterruptedEventData", + "InterruptedStreamPayload", + "LlmTokenChunkData", + "LlmTokenStreamPayload", + "LlmToolCallStreamPayload", + "LlmToolResultStreamPayload", + "LlmTurnCompleteStreamPayload", + "StartEventData", + "StartStreamPayload", + "StreamDispatchState", + "StreamLlmEventPayload", + "StreamPayloadBase", + "dispatch_output_item_done", + "dispatch_stream_chunk", + "format_stream_data", + "serialize_end_event", + "serialize_event", + "serialize_http_error_event", + "serialize_interrupted_event", + "serialize_start_event", +] diff --git a/src/utils/streaming/chunk_dispatchers.py b/src/utils/streaming/chunk_dispatchers.py new file mode 100644 index 000000000..364cbfc81 --- /dev/null +++ b/src/utils/streaming/chunk_dispatchers.py @@ -0,0 +1,243 @@ +"""Dispatchers for streaming response chunks.""" + +from dataclasses import replace +from functools import singledispatch +from typing import cast + +from llama_stack_api import OpenAIResponseObjectStream +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseCompleted as CompletedChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseContentPartAdded as ContentPartAddedChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseFailed as FailedChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseIncomplete as IncompleteChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone as MCPArgsDoneChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseOutputItemAdded as OutputItemAddedChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseOutputItemDone as OutputItemDoneChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseOutputTextDelta as TextDeltaChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseOutputTextDone as TextDoneChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageMCPCall as MCPCall, +) + +from log import get_logger +from models.api.responses.error import ( + InternalServerErrorResponse, + PromptTooLongResponse, +) +from models.common.turn_summary import ToolCallSummary +from utils.query import is_context_length_error +from utils.responses import parse_arguments_string +from 
utils.streaming.event_serializers import ( + serialize_event, +) +from utils.streaming.output_item_dispatchers import dispatch_output_item_done +from utils.streaming.state import ChunkDispatchResult, StreamDispatchState +from utils.streaming.stream_payloads import ( + ErrorEventData, + ErrorStreamPayload, + LlmTokenChunkData, + LlmTokenStreamPayload, + LlmToolCallStreamPayload, + LlmTurnCompleteStreamPayload, +) + +logger = get_logger(__name__) + + +@singledispatch +def dispatch_stream_chunk( + chunk: OpenAIResponseObjectStream, + state: StreamDispatchState, + _media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Fallback dispatcher for unknown chunk types.""" + logger.debug( + "Ignoring unsupported chunk type=%s", + getattr(chunk, "type", None), + ) + return ChunkDispatchResult(state=state) + + +@dispatch_stream_chunk.register +def _( + _chunk: ContentPartAddedChunk, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Handle content part start by emitting an empty token.""" + payload = LlmTokenStreamPayload( + data=LlmTokenChunkData(id=state.chunk_id, token=""), + ) + return ChunkDispatchResult( + state=replace(state, chunk_id=state.chunk_id + 1), + events=[serialize_event(payload, media_type)], + ) + + +@dispatch_stream_chunk.register +def _( + chunk: OutputItemAddedChunk, + state: StreamDispatchState, + _media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Track MCP call metadata for arguments.done events.""" + if chunk.item.type != "mcp_call": + return ChunkDispatchResult(state=state) + + mcp_call_item = cast(MCPCall, chunk.item) + next_mcp_calls = { + **state.mcp_calls, + chunk.output_index: (mcp_call_item.id, mcp_call_item.name), + } + return ChunkDispatchResult(state=replace(state, mcp_calls=next_mcp_calls)) + + +@dispatch_stream_chunk.register +def _( + chunk: TextDeltaChunk, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Handle token delta chunks.""" + state.text_parts.append(chunk.delta) + payload = LlmTokenStreamPayload( + data=LlmTokenChunkData(id=state.chunk_id, token=chunk.delta) + ) + return ChunkDispatchResult( + state=replace(state, chunk_id=state.chunk_id + 1), + events=[serialize_event(payload, media_type)], + ) + + +@dispatch_stream_chunk.register +def _( + chunk: TextDoneChunk, + state: StreamDispatchState, + _media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Store final generated text from output_text.done.""" + return ChunkDispatchResult(state=replace(state, llm_response=chunk.text)) + + +@dispatch_stream_chunk.register +def _( + chunk: MCPArgsDoneChunk, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit MCP tool call when arguments are complete.""" + next_mcp_calls = dict(state.mcp_calls) + item_info = next_mcp_calls.pop(chunk.output_index, None) + if item_info is None: + return ChunkDispatchResult(state=replace(state, mcp_calls=next_mcp_calls)) + + item_id, item_name = item_info + tool_call = ToolCallSummary( + id=item_id, + name=item_name, + args=parse_arguments_string(chunk.arguments), + type="mcp_call", + ) + payload = LlmToolCallStreamPayload(data=tool_call) + return ChunkDispatchResult( + state=replace( + state, + mcp_calls=next_mcp_calls, + tool_calls=[*state.tool_calls, tool_call], + ), + events=[serialize_event(payload, media_type)], + ) + + +@dispatch_stream_chunk.register +def _( + chunk: OutputItemDoneChunk, + state: StreamDispatchState, + media_type: str, + 
model_id: str, +) -> ChunkDispatchResult: + """Handle output item completion for tool calls and results.""" + return dispatch_output_item_done( + chunk.item, + chunk.output_index, + state, + media_type, + model_id, + ) + + +@dispatch_stream_chunk.register +def _( + chunk: CompletedChunk, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Handle successful response completion.""" + final_text = state.llm_response or "".join(state.text_parts) + payload = LlmTurnCompleteStreamPayload( + data=LlmTokenChunkData(id=state.chunk_id, token=final_text), + ) + return ChunkDispatchResult( + state=replace( + state, + chunk_id=state.chunk_id + 1, + latest_response_object=chunk.response, + llm_response=final_text, + ), + events=[serialize_event(payload, media_type)], + ) + + +@dispatch_stream_chunk.register +def _( + chunk: IncompleteChunk | FailedChunk, + state: StreamDispatchState, + media_type: str, + model_id: str, +) -> ChunkDispatchResult: + """Handle incomplete or failed response.""" + error_message = ( + chunk.response.error.message + if chunk.response.error is not None + else "An unexpected error occurred while processing the request." + ) + error_response = ( + PromptTooLongResponse(model=model_id) + if is_context_length_error(error_message) + else InternalServerErrorResponse.query_failed(error_message) + ) + payload = ErrorStreamPayload( + data=ErrorEventData( + status_code=error_response.status_code, + response=error_response.detail.response, + cause=error_response.detail.cause, # pylint: disable=no-member + ), + ) + return ChunkDispatchResult( + state=replace(state, latest_response_object=chunk.response), + events=[serialize_event(payload, media_type)], + ) diff --git a/src/utils/streaming/event_serializers.py b/src/utils/streaming/event_serializers.py new file mode 100644 index 000000000..d12b97d0a --- /dev/null +++ b/src/utils/streaming/event_serializers.py @@ -0,0 +1,164 @@ +"""Shared streaming event formatting utilities.""" + +import json +from functools import singledispatch +from typing import Optional + +from constants import MEDIA_TYPE_JSON +from log import get_logger +from models.api.responses.error import ( + AbstractErrorResponse, +) +from models.common.turn_summary import ReferencedDocument +from utils.streaming.stream_payloads import ( + EndEventData, + EndStreamPayload, + ErrorEventData, + ErrorStreamPayload, + InterruptedEventData, + InterruptedStreamPayload, + LlmTokenStreamPayload, + LlmToolCallStreamPayload, + LlmToolResultStreamPayload, + LlmTurnCompleteStreamPayload, + StartEventData, + StartStreamPayload, + StreamLlmEventPayload, + StreamPayloadBase, +) +from utils.token_counter import TokenCounter + +logger = get_logger(__name__) + + +def format_stream_data(data: StreamPayloadBase) -> str: + """Format a Pydantic payload as an SSE ``data:`` line.""" + return f"data: {json.dumps(data.model_dump(mode='json'))}\n\n" + + +def serialize_http_error_event( + error: AbstractErrorResponse, + media_type: Optional[str] = None, +) -> str: + """Serialize an API error to an SSE or plain-text client response.""" + logger.error("Error while obtaining answer for user question") + resolved_media_type = media_type or MEDIA_TYPE_JSON + payload = ErrorStreamPayload( + data=ErrorEventData( + status_code=error.status_code, + response=error.detail.response, + cause=error.detail.cause, + ), + ) + return serialize_event(payload, resolved_media_type) + + +def serialize_start_event( + conversation_id: str, + request_id: str, + media_type: str = 
MEDIA_TYPE_JSON, +) -> str: + """Serialize the stream start payload to an SSE line.""" + payload = StartStreamPayload( + data=StartEventData( + conversation_id=conversation_id, + request_id=request_id, + ), + ) + return serialize_event(payload, media_type) + + +def serialize_interrupted_event( + request_id: str, media_type: str = MEDIA_TYPE_JSON +) -> str: + """Serialize an interrupted-stream payload to an SSE line.""" + payload = InterruptedStreamPayload( + data=InterruptedEventData(request_id=request_id), + ) + return serialize_event(payload, media_type) + + +def serialize_end_event( + token_usage: TokenCounter, + available_quotas: dict[str, int], + referenced_documents: list[ReferencedDocument], + media_type: Optional[str] = None, +) -> str: + """Serialize the stream end payload for JSON SSE or plain-text clients.""" + resolved_media_type = media_type or MEDIA_TYPE_JSON + payload = EndStreamPayload( + data=EndEventData( + referenced_documents=referenced_documents, + truncated=None, + input_tokens=token_usage.input_tokens, + output_tokens=token_usage.output_tokens, + ), + available_quotas=available_quotas, + ) + return serialize_event(payload, resolved_media_type) + + +def serialize_event( + payload: StreamLlmEventPayload, + media_type: str = MEDIA_TYPE_JSON, +) -> str: + """Serialize an LLM stream payload (token, tool, turn complete) for the client.""" + if media_type == MEDIA_TYPE_JSON: + return format_stream_data(payload) + return serialize_event_text(payload) + + +@singledispatch +def serialize_event_text(_payload: StreamPayloadBase) -> str: + """Serialize stream payload to plain text for text media type.""" + return "" + + +@serialize_event_text.register +def _(payload: LlmTokenStreamPayload) -> str: + """Serialize token stream payload to plain text.""" + return str(payload.data.token) + + +@serialize_event_text.register +def _(_payload: LlmTurnCompleteStreamPayload) -> str: + """Serialize turn complete stream payload to plain text.""" + return "" + + +@serialize_event_text.register +def _(payload: LlmToolCallStreamPayload) -> str: + """Serialize tool call stream payload to plain text.""" + return f"[Tool Call: {payload.data.name}]\n" + + +@serialize_event_text.register +def _(_payload: LlmToolResultStreamPayload) -> str: + """Serialize tool result stream payload to plain text.""" + return "[Tool Result]\n" + + +@serialize_event_text.register +def _(payload: EndStreamPayload) -> str: + """Serialize end stream payload to plain text.""" + ref_docs_string = "\n".join( + f"{doc.doc_title}: {doc.doc_url}" + for doc in payload.data.referenced_documents + if doc.doc_url and doc.doc_title + ) + return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" + + +@serialize_event_text.register +def _(payload: ErrorStreamPayload) -> str: + """Serialize error stream payload to plain text.""" + cause_part = payload.data.cause if payload.data.cause is not None else "" + return ( + f"Status: {payload.data.status_code} - {payload.data.response} - {cause_part}" + ) + + +@serialize_event_text.register +def _(_payload: StartStreamPayload) -> str: + """Serialize start stream payload to plain text.""" + return "" diff --git a/src/utils/streaming/output_item_dispatchers.py b/src/utils/streaming/output_item_dispatchers.py new file mode 100644 index 000000000..d923d1ac9 --- /dev/null +++ b/src/utils/streaming/output_item_dispatchers.py @@ -0,0 +1,319 @@ +"""Dispatchers for response output items.""" + +import json +from dataclasses import replace +from functools import singledispatch +from typing import 
Any, Optional + +from llama_stack_api.openai_responses import ( + OpenAIResponseInputFunctionToolCallOutput as FunctionToolCallOutput, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseMCPApprovalRequest as MCPApprovalRequest, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseMCPApprovalResponse as MCPApprovalResponse, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseMessage as ResponseMessage, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageFileSearchToolCall as FileSearchCall, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageFunctionToolCall as FunctionCall, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageMCPCall as MCPCall, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageMCPListTools as MCPListTools, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageWebSearchToolCall as WebSearchCall, +) + +from constants import DEFAULT_RAG_TOOL +from log import get_logger +from models.common.responses.types import ResponseItem +from models.common.turn_summary import ( + MCPListToolsSummary, + ToolCallSummary, + ToolInfoSummary, + ToolResultSummary, +) +from utils.responses import parse_arguments_string +from utils.streaming.event_serializers import serialize_event +from utils.streaming.state import ChunkDispatchResult, StreamDispatchState +from utils.streaming.stream_payloads import ( + LlmToolCallStreamPayload, + LlmToolResultStreamPayload, +) + +logger = get_logger(__name__) + + +def _serialize_tool_summary_events( + media_type: str, + tool_call: Optional[ToolCallSummary], + tool_result: Optional[ToolResultSummary], +) -> list[str]: + """Serialize tool summaries to SSE event strings (no state updates).""" + events: list[str] = [] + if tool_call: + events.append( + serialize_event(LlmToolCallStreamPayload(data=tool_call), media_type) + ) + if tool_result: + events.append( + serialize_event(LlmToolResultStreamPayload(data=tool_result), media_type) + ) + return events + + +def _stringify_function_tool_output(output: str | list[Any]) -> str: + """Coerce API function_call_output content to a string (matches summary models).""" + if isinstance(output, str): + return output + return json.dumps([part.model_dump() for part in output]) + + +@singledispatch +def dispatch_output_item_done( + item: ResponseItem, + _output_index: int, + state: StreamDispatchState, + _media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Dispatch output_item.done processing by concrete output item class.""" + logger.debug("Ignoring unsupported output item class=%s", type(item).__name__) + return ChunkDispatchResult(state=state) + + +@dispatch_output_item_done.register +def _( + _item: ResponseMessage, + _output_index: int, + state: StreamDispatchState, + _media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Skip message output items (they are parsed elsewhere).""" + return ChunkDispatchResult(state=state) + + +@dispatch_output_item_done.register +def _( + item: FunctionCall, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit function call summary only.""" + tool_call = ToolCallSummary( + id=item.call_id, + name=item.name, + args=parse_arguments_string(item.arguments), + type="function_call", + ) + return ChunkDispatchResult( + state=replace(state, tool_calls=[*state.tool_calls, tool_call]), + 
events=_serialize_tool_summary_events(media_type, tool_call, None), + ) + + +@dispatch_output_item_done.register +def _( + item: FunctionToolCallOutput, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit function tool output as tool result only.""" + tool_result = ToolResultSummary( + id=item.call_id, + status=item.status or "success", + content=_stringify_function_tool_output(item.output), + type="function_call_output", + ) + return ChunkDispatchResult( + state=replace(state, tool_results=[*state.tool_results, tool_result]), + events=_serialize_tool_summary_events(media_type, None, tool_result), + ) + + +@dispatch_output_item_done.register +def _( + item: FileSearchCall, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit both call and result for file search call.""" + response_payload: Optional[dict[str, Any]] = None + if item.results is not None: + response_payload = {"results": [result.model_dump() for result in item.results]} + + tool_call = ToolCallSummary( + id=item.id, + name=DEFAULT_RAG_TOOL, + args={"queries": item.queries}, + type="file_search_call", + ) + tool_result = ToolResultSummary( + id=item.id, + status=item.status, + content=json.dumps(response_payload) if response_payload else "", + type="file_search_call", + ) + return ChunkDispatchResult( + state=replace( + state, + tool_calls=[*state.tool_calls, tool_call], + tool_results=[*state.tool_results, tool_result], + ), + events=_serialize_tool_summary_events(media_type, tool_call, tool_result), + ) + + +@dispatch_output_item_done.register +def _( + item: WebSearchCall, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit both call and result for web search call.""" + tool_call = ToolCallSummary( + id=item.id, + name="web_search", + args={}, + type="web_search_call", + ) + tool_result = ToolResultSummary( + id=item.id, + status=item.status, + content="", + type="web_search_call", + ) + return ChunkDispatchResult( + state=replace( + state, + tool_calls=[*state.tool_calls, tool_call], + tool_results=[*state.tool_results, tool_result], + ), + events=_serialize_tool_summary_events(media_type, tool_call, tool_result), + ) + + +@dispatch_output_item_done.register +def _( + item: MCPCall, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Handle MCP call output item completion.""" + content = item.error or (item.output or "") + tool_result = ToolResultSummary( + id=item.id, + status="success" if item.error is None else "failure", + content=content, + type="mcp_call", + ) + return ChunkDispatchResult( + state=replace(state, tool_results=[*state.tool_results, tool_result]), + events=_serialize_tool_summary_events(media_type, None, tool_result), + ) + + +@dispatch_output_item_done.register +def _( + item: MCPListTools, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit both call and result for MCP list tools events.""" + tool_call = ToolCallSummary( + id=item.id, + name="mcp_list_tools", + args={"server_label": item.server_label}, + type="mcp_list_tools", + ) + tools_info = [ + ToolInfoSummary( + name=tool.name, + description=tool.description, + input_schema=tool.input_schema, + ) + for tool in item.tools + ] + tool_result = ToolResultSummary( + id=item.id, + 
status="success", + content=json.dumps( + MCPListToolsSummary( + server_label=item.server_label, + tools=tools_info, + ).model_dump() + ), + type="mcp_list_tools", + ) + return ChunkDispatchResult( + state=replace( + state, + tool_calls=[*state.tool_calls, tool_call], + tool_results=[*state.tool_results, tool_result], + ), + events=_serialize_tool_summary_events(media_type, tool_call, tool_result), + ) + + +@dispatch_output_item_done.register +def _( + item: MCPApprovalRequest, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit approval request as tool call only.""" + tool_call = ToolCallSummary( + id=item.id, + name=item.name, + args=parse_arguments_string(item.arguments), + type="mcp_approval_request", + ) + return ChunkDispatchResult( + state=replace(state, tool_calls=[*state.tool_calls, tool_call]), + events=_serialize_tool_summary_events(media_type, tool_call, None), + ) + + +@dispatch_output_item_done.register +def _( + item: MCPApprovalResponse, + _output_index: int, + state: StreamDispatchState, + media_type: str, + _model_id: str, +) -> ChunkDispatchResult: + """Emit approval response as tool result only.""" + tool_result = ToolResultSummary( + id=item.approval_request_id, + status="success" if item.approve else "denied", + content=json.dumps({"reason": item.reason} if item.reason else {}), + type="mcp_approval_response", + ) + return ChunkDispatchResult( + state=replace(state, tool_results=[*state.tool_results, tool_result]), + events=_serialize_tool_summary_events(media_type, None, tool_result), + ) diff --git a/src/utils/streaming/state.py b/src/utils/streaming/state.py new file mode 100644 index 000000000..d2426ab3b --- /dev/null +++ b/src/utils/streaming/state.py @@ -0,0 +1,29 @@ +"""State models for streaming dispatch.""" + +from dataclasses import dataclass, field +from typing import Optional + +from llama_stack_api import OpenAIResponseObject + +from models.common.turn_summary import ToolCallSummary, ToolResultSummary + + +@dataclass(slots=True) +class StreamDispatchState: + """Streaming reducer state built incrementally from chunk events.""" + + chunk_id: int = 0 + text_parts: list[str] = field(default_factory=list) + llm_response: str = "" + tool_calls: list[ToolCallSummary] = field(default_factory=list) + tool_results: list[ToolResultSummary] = field(default_factory=list) + mcp_calls: dict[int, tuple[str, str]] = field(default_factory=dict) + latest_response_object: Optional[OpenAIResponseObject] = None + + +@dataclass(slots=True) +class ChunkDispatchResult: + """Result returned by chunk handlers.""" + + state: StreamDispatchState + events: list[str] = field(default_factory=list) diff --git a/src/utils/streaming/stream_payloads.py b/src/utils/streaming/stream_payloads.py new file mode 100644 index 000000000..9a716c16d --- /dev/null +++ b/src/utils/streaming/stream_payloads.py @@ -0,0 +1,120 @@ +"""Typed JSON bodies for SSE streaming events.""" + +from typing import Annotated, Literal, Optional, TypeAlias + +from pydantic import BaseModel, ConfigDict, Field + +from models.common import ReferencedDocument, ToolCallSummary, ToolResultSummary + + +class StreamPayloadBase(BaseModel): + """Base for streaming SSE JSON payloads.""" + + model_config = ConfigDict(extra="forbid") + + +class ErrorEventData(StreamPayloadBase): + """Payload for event: "error".""" + + status_code: int + response: str + cause: str + + +class StartEventData(StreamPayloadBase): + """Payload for event: "start".""" + + 
conversation_id: str + request_id: str + + +class InterruptedEventData(StreamPayloadBase): + """Payload for event: "interrupted".""" + + request_id: str + + +class EndEventData(StreamPayloadBase): + """Nested data for event: "end".""" + + referenced_documents: list[ReferencedDocument] + truncated: Optional[bool] + input_tokens: int + output_tokens: int + + +class ErrorStreamPayload(StreamPayloadBase): + """SSE error event body (event + typed data).""" + + event: Literal["error"] = "error" + data: ErrorEventData + + +class StartStreamPayload(StreamPayloadBase): + """SSE stream start body.""" + + event: Literal["start"] = "start" + data: StartEventData + + +class InterruptedStreamPayload(StreamPayloadBase): + """SSE interrupted stream body.""" + + event: Literal["interrupted"] = "interrupted" + data: InterruptedEventData + + +class EndStreamPayload(StreamPayloadBase): + """SSE end-of-stream body (includes available_quotas beside data).""" + + event: Literal["end"] = "end" + data: EndEventData + available_quotas: dict[str, int] + + +class LlmTokenChunkData(StreamPayloadBase): + """Structured data for token and turn-complete stream lines.""" + + id: int + token: str + + +class LlmTokenStreamPayload(StreamPayloadBase): + """SSE token delta (event: "token").""" + + event: Literal["token"] = "token" + data: LlmTokenChunkData + + +class LlmTurnCompleteStreamPayload(StreamPayloadBase): + """SSE turn completion (same data shape as token).""" + + event: Literal["turn_complete"] = "turn_complete" + data: LlmTokenChunkData + + +class LlmToolCallStreamPayload(StreamPayloadBase): + """SSE tool call summary.""" + + event: Literal["tool_call"] = "tool_call" + data: ToolCallSummary + + +class LlmToolResultStreamPayload(StreamPayloadBase): + """SSE tool result summary.""" + + event: Literal["tool_result"] = "tool_result" + data: ToolResultSummary + + +StreamLlmEventPayload: TypeAlias = Annotated[ + LlmTokenStreamPayload + | LlmTurnCompleteStreamPayload + | LlmToolCallStreamPayload + | LlmToolResultStreamPayload + | EndStreamPayload + | ErrorStreamPayload + | InterruptedStreamPayload + | StartStreamPayload, + Field(discriminator="event"), +] diff --git a/tests/integration/endpoints/test_streaming_query_byok_integration.py b/tests/integration/endpoints/test_streaming_query_byok_integration.py index 3cb5b878e..5def718b6 100644 --- a/tests/integration/endpoints/test_streaming_query_byok_integration.py +++ b/tests/integration/endpoints/test_streaming_query_byok_integration.py @@ -9,7 +9,24 @@ import pytest from fastapi import Request, status from fastapi.responses import StreamingResponse -from llama_stack_api.openai_responses import OpenAIResponseObject +from llama_stack_api.openai_responses import ( + OpenAIResponseObject, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseCompleted as CompletedChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseOutputItemDone as OutputItemDoneChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseOutputTextDone as TextDoneChunk, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageFileSearchToolCall as FileSearchCall, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageFileSearchToolCallResults as FileSearchCallResult, +) from pytest_mock import AsyncMockType, MockerFixture import constants @@ -155,65 +172,67 @@ def mock_streaming_byok_tool_client_fixture( # pylint: disable=too-many-stateme # Build a 
streaming response with file_search and completion events async def _mock_tool_stream() -> AsyncIterator[Any]: # file_search output item done - item_done_chunk = mocker.MagicMock() - item_done_chunk.type = "response.output_item.done" - item_done_chunk.output_index = 0 - - mock_item = mocker.MagicMock() - mock_item.type = "file_search_call" - mock_item.id = "call-fs-stream-1" - mock_item.queries = ["What is OpenShift?"] - mock_item.status = "completed" - - mock_result = mocker.MagicMock() + mock_result = mocker.Mock(spec=FileSearchCallResult) mock_result.file_id = "doc-ocp-1" mock_result.filename = "openshift-docs.txt" mock_result.score = 0.92 mock_result.text = "OpenShift is a Kubernetes distribution by Red Hat." mock_result.attributes = { "doc_url": "https://docs.redhat.com/ocp/overview", + "doc_title": "OpenShift Overview", } mock_result.model_dump = mocker.Mock( return_value={ "file_id": "doc-ocp-1", "filename": "openshift-docs.txt", "score": 0.92, - "text": "OpenShift is a Kubernetes distribution.", - "attributes": {"doc_url": "https://docs.redhat.com/ocp/overview"}, + "text": "OpenShift is a Kubernetes distribution by Red Hat.", + "attributes": { + "doc_url": "https://docs.redhat.com/ocp/overview", + "doc_title": "OpenShift Overview", + }, } ) + mock_item = mocker.Mock(spec=FileSearchCall) + mock_item.type = "file_search_call" + mock_item.id = "call-fs-stream-1" + mock_item.queries = ["What is OpenShift?"] + mock_item.status = "completed" mock_item.results = [mock_result] + item_done_chunk = mocker.Mock(spec=OutputItemDoneChunk) + item_done_chunk.type = "response.output_item.done" + item_done_chunk.response_id = "response-tool-stream" item_done_chunk.item = mock_item + item_done_chunk.output_index = 0 + item_done_chunk.sequence_number = 1 yield item_done_chunk # Text done - text_done_chunk = mocker.MagicMock() + text_done_chunk = mocker.Mock(spec=TextDoneChunk) text_done_chunk.type = "response.output_text.done" + text_done_chunk.content_index = 0 text_done_chunk.text = ( "Based on the documentation, OpenShift is a Kubernetes distribution." 
) + text_done_chunk.item_id = "msg-tool-stream" + text_done_chunk.output_index = 1 + text_done_chunk.sequence_number = 2 yield text_done_chunk # Response completed - completed_chunk = mocker.MagicMock() - completed_chunk.type = "response.completed" - mock_final_response = mocker.MagicMock(spec=OpenAIResponseObject) + mock_final_response = mocker.Mock(spec=OpenAIResponseObject) mock_final_response.id = "response-tool-stream" mock_final_response.error = None - - mock_usage = mocker.MagicMock() + mock_usage = mocker.Mock() mock_usage.input_tokens = 60 mock_usage.output_tokens = 25 mock_final_response.usage = mock_usage + mock_final_response.output = [mock_item] - # file_search results in the final response output - mock_fs_output = mocker.MagicMock() - mock_fs_output.type = "file_search_call" - mock_fs_output.id = "call-fs-stream-1" - mock_fs_output.results = [mock_result] - mock_final_response.output = [mock_fs_output] - + completed_chunk = mocker.Mock(spec=CompletedChunk) + completed_chunk.type = "response.completed" completed_chunk.response = mock_final_response + completed_chunk.sequence_number = 3 yield completed_chunk async def _responses_create(**kwargs: Any) -> Any: diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 1894190cf..746bbaadc 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -17,6 +17,9 @@ from llama_stack_api.openai_responses import ( OpenAIResponseObjectStreamResponseCompleted as CompletedChunk, ) +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStreamResponseContentPartAdded as ContentPartAddedChunk, +) from llama_stack_api.openai_responses import ( OpenAIResponseObjectStreamResponseFailed as FailedChunk, ) @@ -38,6 +41,12 @@ from llama_stack_api.openai_responses import ( OpenAIResponseObjectStreamResponseOutputTextDone as TextDoneChunk, ) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageFileSearchToolCall as FileSearchCall, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseOutputMessageFunctionToolCall as FunctionCall, +) from llama_stack_api.openai_responses import ( OpenAIResponseOutputMessageMCPCall as MCPCall, ) @@ -49,21 +58,15 @@ generate_response, response_generator, retrieve_response_generator, + serialize_end_event, + serialize_event, + serialize_http_error_event, + serialize_start_event, shield_violation_generator, - stream_end_event, - stream_event, - stream_http_error_event, - stream_start_event, streaming_query_endpoint_handler, ) from configuration import AppConfig -from constants import ( - LLM_TOKEN_EVENT, - LLM_TOOL_CALL_EVENT, - LLM_TOOL_RESULT_EVENT, - MEDIA_TYPE_JSON, - MEDIA_TYPE_TEXT, -) +from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT from models.api.requests import QueryRequest from models.api.responses.error import InternalServerErrorResponse from models.common.moderation import ShieldModerationPassed @@ -73,11 +76,20 @@ RAGChunk, RAGContext, ReferencedDocument, + ToolCallSummary, + ToolResultSummary, TurnSummary, ) from models.config import Action from models.context import ResponseGeneratorContext from utils.stream_interrupts import StreamInterruptRegistry +from utils.streaming.stream_payloads import ( + LlmTokenChunkData, + LlmTokenStreamPayload, + LlmToolCallStreamPayload, + LlmToolResultStreamPayload, + LlmTurnCompleteStreamPayload, +) from utils.token_counter import TokenCounter MOCK_AUTH_STREAMING = ( @@ -129,75 +141,87 @@ class 
TestOLSStreamEventFormatting: def test_stream_event_json_token(self) -> None: """Test token event formatting for JSON media type.""" - data = {"id": 0, "token": "Hello"} - result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_JSON) + payload = LlmTokenStreamPayload(data=LlmTokenChunkData(id=0, token="Hello")) + result = serialize_event(payload, MEDIA_TYPE_JSON) expected = 'data: {"event": "token", "data": {"id": 0, "token": "Hello"}}\n\n' assert result == expected def test_stream_event_text_token(self) -> None: """Test token event formatting for text media type.""" - data = {"id": 0, "token": "Hello"} - result = stream_event(data, LLM_TOKEN_EVENT, MEDIA_TYPE_TEXT) + payload = LlmTokenStreamPayload(data=LlmTokenChunkData(id=0, token="Hello")) + result = serialize_event(payload, MEDIA_TYPE_TEXT) assert result == "Hello" def test_stream_event_json_tool_call(self) -> None: """Test tool call event formatting for JSON media type.""" - data = { - "id": 0, - "token": {"tool_name": "search", "arguments": {"query": "test"}}, - } - result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_JSON) + payload = LlmToolCallStreamPayload( + data=ToolCallSummary( + id="0", + name="search", + args={"query": "test"}, + ), + ) + result = serialize_event(payload, MEDIA_TYPE_JSON) expected = ( - 'data: {"event": "tool_call", "data": {"id": 0, "token": ' - '{"tool_name": "search", "arguments": {"query": "test"}}}}\n\n' + 'data: {"event": "tool_call", "data": {"id": "0", "name": "search", ' + '"args": {"query": "test"}, "type": "tool_call"}}\n\n' ) assert result == expected def test_stream_event_text_tool_call(self) -> None: """Test tool call event formatting for text media type.""" - data = { - "id": 0, - "function_name": "search", - "arguments": {"query": "test"}, - } - result = stream_event(data, LLM_TOOL_CALL_EVENT, MEDIA_TYPE_TEXT) + payload = LlmToolCallStreamPayload( + data=ToolCallSummary( + id="0", + name="search", + args={"query": "test"}, + ), + ) + result = serialize_event(payload, MEDIA_TYPE_TEXT) expected = "[Tool Call: search]\n" assert result == expected def test_stream_event_json_tool_result(self) -> None: """Test tool result event formatting for JSON media type.""" - data = { - "id": 0, - "token": {"tool_name": "search", "response": "Found results"}, - } - result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_JSON) + payload = LlmToolResultStreamPayload( + data=ToolResultSummary( + id="0", status="success", content="Found results", type="tool_result" + ), + ) + result = serialize_event(payload, MEDIA_TYPE_JSON) expected = ( - 'data: {"event": "tool_result", "data": {"id": 0, "token": ' - '{"tool_name": "search", "response": "Found results"}}}\n\n' + 'data: {"event": "tool_result", "data": {"id": "0", "status": "success", ' + '"content": "Found results", "type": "tool_result", "round": 1}}\n\n' ) assert result == expected def test_stream_event_text_tool_result(self) -> None: """Test tool result event formatting for text media type.""" - data = { - "id": 0, - "tool_name": "search", - "response": "Found results", - } - result = stream_event(data, LLM_TOOL_RESULT_EVENT, MEDIA_TYPE_TEXT) + payload = LlmToolResultStreamPayload( + data=ToolResultSummary( + id="0", + status="success", + content="Found results", + type="tool_result", + round=1, + ), + ) + result = serialize_event(payload, MEDIA_TYPE_TEXT) expected = "[Tool Result]\n" assert result == expected - def test_stream_event_unknown_type(self) -> None: - """Test handling of unknown event types.""" - data = {"id": 0, "token": "test"} - 
@@ -217,7 +241,7 @@ def test_stream_end_event_json(self) -> None:
                 doc_url=AnyUrl("https://example.com/doc2"), doc_title="Test Doc 2"
             ),
         ]
-        result = stream_end_event(
+        result = serialize_end_event(
             token_usage,
             available_quotas,
             referenced_documents,
@@ -249,7 +273,7 @@ def test_stream_end_event_text(self) -> None:
                 doc_url=AnyUrl("https://example.com/doc2"), doc_title="Test Doc 2"
             ),
         ]
-        result = stream_end_event(
+        result = serialize_end_event(
             token_usage,
             available_quotas,
             referenced_documents,
@@ -267,7 +291,7 @@ def test_stream_end_event_text_no_docs(self) -> None:
         token_usage = TokenCounter(input_tokens=100, output_tokens=50)
         available_quotas: dict[str, int] = {}
         referenced_documents: list[ReferencedDocument] = []
-        result = stream_end_event(
+        result = serialize_end_event(
             token_usage,
             available_quotas,
             referenced_documents,
@@ -306,7 +330,7 @@ def test_ols_end_event_structure(self) -> None:
                 doc_url=AnyUrl("https://example.com/doc"), doc_title="Test Doc"
             ),
         ]
-        end_event = stream_end_event(
+        end_event = serialize_end_event(
             token_usage,
             available_quotas,
             referenced_documents,
@@ -1337,7 +1361,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_error_response.status_code = 413
         mock_error_response.detail = mocker.Mock()
         mock_error_response.detail.response = "Prompt too long"
-        mock_error_response.detail.cause = None
+        mock_error_response.detail.cause = "Prompt exceeded model context window"
         mocker.patch(
             "app.endpoints.streaming_query.PromptTooLongResponse",
             return_value=mock_error_response,
@@ -1385,7 +1409,7 @@ async def mock_generator() -> AsyncIterator[str]:
         mock_error_response.status_code = 500
         mock_error_response.detail = mocker.Mock()
         mock_error_response.detail.response = "Internal server error"
-        mock_error_response.detail.cause = None
+        mock_error_response.detail.cause = "Some other error"
         mocker.patch(
             "app.endpoints.streaming_query.InternalServerErrorResponse.generic",
             return_value=mock_error_response,
@@ -1892,7 +1916,7 @@ async def test_response_generator_content_part_added(
         """Test response generator processes content part added events."""

         async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
-            chunk = mocker.Mock()
+            chunk = mocker.Mock(spec=ContentPartAddedChunk)
             chunk.type = "response.content_part.added"
             yield chunk

@@ -2000,8 +2024,11 @@ async def test_response_generator_output_item_done(
         self, mocker: MockerFixture
     ) -> None:
         """Test response generator processes output item done events."""
-        mock_output_item = mocker.Mock()
-        mock_output_item.type = "tool_call"
+        mock_output_item = mocker.Mock(spec=FunctionCall)
+        mock_output_item.type = "function_call"
+        mock_output_item.call_id = "func_call_123"
+        mock_output_item.name = "test_function"
+        mock_output_item.arguments = '{"param":"value"}'

         async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
             chunk = mocker.Mock(spec=OutputItemDoneChunk)
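
The function_call test above fixes which attributes the generator reads off the output item. A hypothetical mapping consistent with those mocked fields and with the ToolCallSummary model used earlier (the real dispatch logic is not shown here):

    import json


    def summarize_function_call_sketch(item: FunctionCall) -> ToolCallSummary:
        # Field mapping inferred from the mock: call_id -> id, name -> name,
        # and the JSON-encoded arguments string -> args dict.
        return ToolCallSummary(
            id=item.call_id,                  # "func_call_123"
            name=item.name,                   # "test_function"
            args=json.loads(item.arguments),  # {"param": "value"}
        )
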
"app.endpoints.streaming_query.build_tool_call_summary", - return_value=(mock_tool_call, None), - ) - mocker.patch( "app.endpoints.streaming_query.extract_token_usage", return_value=TokenCounter(input_tokens=0, output_tokens=0), @@ -2049,8 +2069,12 @@ async def test_response_generator_output_item_done_with_tool_result( self, mocker: MockerFixture ) -> None: """Test response generator processes output item done events with tool result.""" - mock_output_item = mocker.Mock() - mock_output_item.type = "tool_call" + mock_output_item = mocker.Mock(spec=FileSearchCall) + mock_output_item.type = "file_search_call" + mock_output_item.id = "file_search_123" + mock_output_item.queries = ["test query"] + mock_output_item.results = None + mock_output_item.status = "completed" async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: chunk = mocker.Mock(spec=OutputItemDoneChunk) @@ -2070,15 +2094,6 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: mock_turn_summary = TurnSummary() - mock_tool_call = mocker.Mock() - mock_tool_call.model_dump.return_value = {"tool": "test"} - mock_tool_result = mocker.Mock() - mock_tool_result.model_dump.return_value = {"result": "test_result"} - mocker.patch( - "app.endpoints.streaming_query.build_tool_call_summary", - return_value=(mock_tool_call, mock_tool_result), - ) - mocker.patch( "app.endpoints.streaming_query.extract_token_usage", return_value=TokenCounter(input_tokens=0, output_tokens=0), @@ -2095,6 +2110,8 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: assert len(result) > 0 assert len(mock_turn_summary.tool_results) == 1 + assert mock_turn_summary.tool_results[0].type == "file_search_call" + assert mock_turn_summary.tool_results[0].id == "file_search_123" @pytest.mark.asyncio async def test_response_generator_response_completed( @@ -2474,7 +2491,7 @@ def test_stream_http_error_event_json(self, mocker: MockerFixture) -> None: error = InternalServerErrorResponse.query_failed("Test error") mocker.patch("app.endpoints.streaming_query.logger") - result = stream_http_error_event(error, MEDIA_TYPE_JSON) + result = serialize_http_error_event(error, MEDIA_TYPE_JSON) assert "error" in result assert "Test error" in result @@ -2484,7 +2501,7 @@ def test_stream_http_error_event_text(self, mocker: MockerFixture) -> None: error = InternalServerErrorResponse.query_failed("Test error") mocker.patch("app.endpoints.streaming_query.logger") - result = stream_http_error_event(error, MEDIA_TYPE_TEXT) + result = serialize_http_error_event(error, MEDIA_TYPE_TEXT) assert "Status:" in result assert "500" in result @@ -2495,7 +2512,7 @@ def test_stream_http_error_event_default(self, mocker: MockerFixture) -> None: error = InternalServerErrorResponse.query_failed("Test error") mocker.patch("app.endpoints.streaming_query.logger") - result = stream_http_error_event(error) + result = serialize_http_error_event(error) assert "error" in result assert "500" in result or "status_code" in result @@ -2504,9 +2521,11 @@ def test_stream_http_error_event_default(self, mocker: MockerFixture) -> None: class TestStreamStartEvent: # pylint: disable=too-few-public-methods """Tests for stream_start_event function.""" - def test_stream_start_event(self) -> None: + def test_serialize_start_event(self) -> None: """Test start event formatting.""" - result = stream_start_event("conv_123", "123e4567-e89b-12d3-a456-426614174000") + result = serialize_start_event( + "conv_123", "123e4567-e89b-12d3-a456-426614174000" + ) assert 
"start" in result assert "conv_123" in result @@ -2624,16 +2643,6 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: mock_turn_summary = TurnSummary() - mock_tool_call = mocker.Mock() - mock_tool_call.model_dump.return_value = { - "id": "mcp_call_123", - "name": "test_mcp_tool", - } - mocker.patch( - "app.endpoints.streaming_query.build_mcp_tool_call_from_arguments_done", - return_value=mock_tool_call, - ) - mocker.patch( "app.endpoints.streaming_query.extract_token_usage", return_value=TokenCounter(input_tokens=0, output_tokens=0), @@ -2697,36 +2706,6 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: mock_turn_summary = TurnSummary() - mock_tool_call = mocker.Mock() - mock_tool_call.model_dump.return_value = {"id": "mcp_call_123"} - - # Use side_effect to actually remove item from mcp_calls dict - def build_mcp_tool_call_side_effect( - output_index: int, - arguments: str, - mcp_call_items: dict[int, tuple[str, str]], - ) -> Any: - # Remove item from dict to simulate real behavior - # arguments parameter is required by function signature but unused here - _ = arguments - mcp_call_items.pop(output_index, None) - return mock_tool_call - - mocker.patch( - "app.endpoints.streaming_query.build_mcp_tool_call_from_arguments_done", - side_effect=build_mcp_tool_call_side_effect, - ) - - mock_tool_result = mocker.Mock() - mock_tool_result.model_dump.return_value = { - "id": "mcp_call_123", - "status": "success", - } - mocker.patch( - "app.endpoints.streaming_query.build_tool_result_from_mcp_output_item_done", - return_value=mock_tool_result, - ) - mocker.patch( "app.endpoints.streaming_query.extract_token_usage", return_value=TokenCounter(input_tokens=0, output_tokens=0), @@ -2749,7 +2728,7 @@ def build_mcp_tool_call_side_effect( async def test_response_generator_mcp_call_output_item_done_without_arguments_done( self, mocker: MockerFixture ) -> None: - """Test response generator emits both call and result when MCP output_item.done.""" + """Test response generator emits only result on MCP output_item.done.""" mock_mcp_item = mocker.Mock(spec=MCPCall) mock_mcp_item.type = "mcp_call" mock_mcp_item.id = "mcp_call_123" @@ -2767,7 +2746,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: added_chunk.output_index = 0 yield added_chunk - # output_item.done (should emit both call and result since arguments.done didn't happen) + # output_item.done (always emits MCP result only) done_chunk = mocker.Mock(spec=OutputItemDoneChunk) done_chunk.type = "response.output_item.done" done_chunk.item = mock_mcp_item @@ -2785,18 +2764,6 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: mock_turn_summary = TurnSummary() - mock_tool_call = mocker.Mock() - mock_tool_call.model_dump.return_value = {"id": "mcp_call_123"} - mock_tool_result = mocker.Mock() - mock_tool_result.model_dump.return_value = { - "id": "mcp_call_123", - "status": "success", - } - mocker.patch( - "app.endpoints.streaming_query.build_tool_call_summary", - return_value=(mock_tool_call, mock_tool_result), - ) - mocker.patch( "app.endpoints.streaming_query.extract_token_usage", return_value=TokenCounter(input_tokens=0, output_tokens=0), @@ -2811,6 +2778,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: ): result.append(item) - # Should have both tool call and result (fallback behavior) - assert len(mock_turn_summary.tool_calls) == 1 + # Should have result only for MCP output item done chunk + # MCP call was 
@@ -2811,6 +2778,7 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]:
         ):
             result.append(item)

-        # Should have both tool call and result (fallback behavior)
-        assert len(mock_turn_summary.tool_calls) == 1
+        # Should have result only for MCP output item done chunk
+        # MCP call was already emitted on argument.done chunk
+        assert len(mock_turn_summary.tool_calls) == 0
         assert len(mock_turn_summary.tool_results) == 1
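
One detail worth noticing in the formatting tests: ToolResultSummary serializes "round": 1 even when the field is not passed, so the JSON test (which omits round) and the text test (which passes round=1) build equivalent models. A quick check under that assumption:

    a = ToolResultSummary(
        id="0", status="success", content="Found results", type="tool_result"
    )
    b = ToolResultSummary(
        id="0", status="success", content="Found results", type="tool_result", round=1
    )
    # Both dumps include "round": 1, matching the expected SSE string in
    # test_stream_event_json_tool_result.
    assert a.model_dump() == b.model_dump()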