diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index c241e831a..15328ccab 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -12,7 +12,7 @@ import re from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Awaitable, Callable -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from dataclasses import dataclass from http import HTTPStatus from typing import Any @@ -21,7 +21,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse -from starlette.requests import Request +from starlette.requests import ClientDisconnect, Request from starlette.responses import Response from starlette.types import Receive, Scope, Send @@ -644,6 +644,15 @@ async def sse_writer(): await sse_stream_reader.aclose() await self._clean_up_memory_streams(request_id) + except ClientDisconnect: + # Client went away mid-request (network timeout, cancel, LB drop). Not a + # server error — log at WARNING and skip the response: the socket is gone. + # Notify the writer so the inner session task can unblock cleanly. + logger.warning("Client disconnected during POST request") + if writer is not None: + with suppress(Exception): + await writer.send(ClientDisconnect()) + return except Exception as err: # pragma: no cover logger.exception("Error handling POST request") response = self._create_error_response( diff --git a/tests/server/streamable_http/__init__.py b/tests/server/streamable_http/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/streamable_http/test_client_disconnect_post.py b/tests/server/streamable_http/test_client_disconnect_post.py new file mode 100644 index 000000000..0e2b9f4de --- /dev/null +++ b/tests/server/streamable_http/test_client_disconnect_post.py @@ -0,0 +1,199 @@ +"""Tests for ClientDisconnect handling in StreamableHTTPServerTransport._handle_post_request. + +Regression test for pattern 1: ClientDisconnect raised during POST should log at WARNING +(not ERROR) and should not attempt to send a response to the closed socket. + +Inspired by upstream PRs: +- https://github.com/modelcontextprotocol/python-sdk/pull/1647 (scope: POST only) +- https://github.com/modelcontextprotocol/python-sdk/pull/1947 (semantics: notify writer, skip response) +""" + +from __future__ import annotations as _annotations + +import logging +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from starlette.requests import ClientDisconnect, Request + +from mcp.server.streamable_http import StreamableHTTPServerTransport + + +class TestClientDisconnectDuringPOST: + """ClientDisconnect during POST should be handled gracefully.""" + + def _make_scope(self, headers: dict[str, bytes] | None = None) -> dict[str, Any]: + """Build a minimal ASGI scope for a POST request.""" + return { + "type": "http", + "method": "POST", + "path": "/mcp", + "query_string": b"", + "headers": list((headers or {}).items()) if headers else [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ], + } + + @pytest.mark.anyio + async def test_client_disconnect_logs_warning_not_error(self, caplog): + """ClientDisconnect should produce a WARNING, not an ERROR.""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + # Set up a dummy writer so the transport passes the None check + mock_writer = MagicMock() + mock_writer.send = AsyncMock() + transport._read_stream_writer = mock_writer + + # Mock request.body() to raise ClientDisconnect (simulates client going away + # mid-request body upload). + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + send_calls: list[Any] = [] + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + async def dummy_send(message): + send_calls.append(message) + + with caplog.at_level(logging.DEBUG, logger="mcp.server.streamable_http"): + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # Should log a WARNING, not an ERROR + warning_records = [ + r for r in caplog.records if r.levelno == logging.WARNING + and "Client disconnected" in r.getMessage() + ] + error_records = [ + r for r in caplog.records if r.levelno == logging.ERROR + ] + assert len(warning_records) == 1, ( + f"Expected exactly 1 WARNING with 'Client disconnected', got {len(warning_records)}" + ) + assert len(error_records) == 0, ( + f"Expected 0 ERROR logs, got {len(error_records)}: {[r.getMessage() for r in error_records]}" + ) + + @pytest.mark.anyio + async def test_client_disconnect_does_not_send_response(self): + """After ClientDisconnect, no response should be sent (socket is closed).""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + # Set up a dummy writer so the transport passes the None check + mock_writer = MagicMock() + mock_writer.send = AsyncMock() + transport._read_stream_writer = mock_writer + + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + send_calls: list[Any] = [] + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + async def dummy_send(message): + send_calls.append(message) + + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # No HTTP response should be sent to the closed socket + assert len(send_calls) == 0, ( + f"Expected no ASGI sends after ClientDisconnect, got {len(send_calls)}" + ) + + @pytest.mark.anyio + async def test_client_disconnect_notifies_writer(self): + """Writer should receive ClientDisconnect so the inner session task can unblock.""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + # Capture what the writer receives + writer_messages: list[Any] = [] + + async def capture_writer(msg): + writer_messages.append(msg) + + # Patch the internal writer + with patch.object(transport, "_read_stream_writer", MagicMock(send=capture_writer)): + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + send_calls: list[Any] = [] + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + async def dummy_send(message): + send_calls.append(message) + + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # Writer should have been notified with ClientDisconnect + assert len(writer_messages) == 1, ( + f"Expected writer to receive 1 message, got {len(writer_messages)}" + ) + assert isinstance(writer_messages[0], ClientDisconnect), ( + f"Expected ClientDisconnect sent to writer, got {type(writer_messages[0])}" + ) + + @pytest.mark.anyio + async def test_client_disconnect_writer_suppresses_errors(self): + """If the writer itself is broken, we should not crash (suppress(Exception)).""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + broken_send = AsyncMock(side_effect=RuntimeError("writer is broken")) + + with patch.object(transport, "_read_stream_writer", MagicMock(send=broken_send)): + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + send_calls: list[Any] = [] + + async def dummy_send(message): + send_calls.append(message) + + # Should not raise even though writer.send() fails + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # The broken writer.send was called once + broken_send.assert_called_once() + # No response was sent (socket is closed) + assert len(send_calls) == 0 diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 8e92cfa43..23fd86c3b 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -445,7 +445,7 @@ async def test_stateless_get_returns_405(): error_data = json.loads(response_body) assert error_data["jsonrpc"] == "2.0" - assert error_data["id"] is None + assert error_data["id"] == "" assert error_data["error"]["code"] == INVALID_REQUEST assert "GET" in error_data["error"]["message"] assert "stateless" in error_data["error"]["message"].lower() @@ -464,7 +464,7 @@ async def test_stateless_delete_returns_405(): error_data = json.loads(response_body) assert error_data["jsonrpc"] == "2.0" - assert error_data["id"] is None + assert error_data["id"] == "" assert error_data["error"]["code"] == INVALID_REQUEST assert "DELETE" in error_data["error"]["message"]