diff --git a/.pr-body.md b/.pr-body.md new file mode 100644 index 000000000..52eed5ae6 --- /dev/null +++ b/.pr-body.md @@ -0,0 +1,40 @@ +## Problem + +When a client disconnects mid-POST (network timeout, cancel, LB drop), the MCP SDK's `_handle_post_request` catches `ClientDisconnect` in its generic `except Exception` handler, which: + +1. Logs an **ERROR** ("Error handling POST request") to Datadog — this is pattern 1, **67% of all prod ERRORs** in ai-codegen-api +2. Synthesizes a 500 response and sends it to the already-closed socket +3. Does not notify the inner session writer, so stateful session tasks can hang + +## Solution + +Catch `ClientDisconnect` before the generic `except Exception` handler in `_handle_post_request`: + +- **Log at WARNING** — honest level: we aborted handling a request, but it wasn't our fault +- **Notify the session writer** with `ClientDisconnect()` so inner session tasks unblock cleanly (stateful sessions) +- **Skip sending a response** — the socket is closed; sending to it is just noise +- **Scope: POST only** — Datadog confirms 100% of pattern 1 stacks originate from `_handle_post_request` + +## Tests + +4 new unit tests in `tests/server/streamable_http/test_client_disconnect_post.py`: +- `test_client_disconnect_logs_warning_not_error` — verifies WARNING, not ERROR +- `test_client_disconnect_does_not_send_response` — verifies no ASGI sends to closed socket +- `test_client_disconnect_notifies_writer` — verifies writer receives ClientDisconnect +- `test_client_disconnect_writer_suppresses_errors` — verifies broken writer doesn't crash us + +All existing streamable_http tests (49 + 18) pass. + +## Upstream context + +This patch combines the approach of two open upstream PRs (neither merged): +- [modelcontextprotocol/python-sdk#1647](https://github.com/modelcontextprotocol/python-sdk/pull/1647) — right scope (POST only) but wrong semantics (still sends response, doesn't notify writer) +- [modelcontextprotocol/python-sdk#1947](https://github.com/modelcontextprotocol/python-sdk/pull/1947) — right semantics (notify writer, skip response) but broader scope than needed (covers GET/replay paths) + +This is a tight ~10-line change taking the spirit of #1647 + semantics of #1947. Tracking issue: [modelcontextprotocol/python-sdk#1648](https://github.com/modelcontextprotocol/python-sdk/issues/1648). + +## Follow-up + +After this lands in the fork and is verified in prod via ai-codegen-api bump: +- Open mirror PR upstream +- Consider dropping `ClientDisconnectHandlerMiddleware` in ai-codegen-api (no longer load-bearing for POST path) diff --git a/pyproject.toml b/pyproject.toml index 7818775c5..2566cb9da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mcp" -version = "1.27.0.post1" +version = "1.27.0.post2" description = "Model Context Protocol SDK" readme = "README.md" requires-python = ">=3.10" diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index c241e831a..15328ccab 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -12,7 +12,7 @@ import re from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Awaitable, Callable -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from dataclasses import dataclass from http import HTTPStatus from typing import Any @@ -21,7 +21,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from sse_starlette import EventSourceResponse -from starlette.requests import Request +from starlette.requests import ClientDisconnect, Request from starlette.responses import Response from starlette.types import Receive, Scope, Send @@ -644,6 +644,15 @@ async def sse_writer(): await sse_stream_reader.aclose() await self._clean_up_memory_streams(request_id) + except ClientDisconnect: + # Client went away mid-request (network timeout, cancel, LB drop). Not a + # server error — log at WARNING and skip the response: the socket is gone. + # Notify the writer so the inner session task can unblock cleanly. + logger.warning("Client disconnected during POST request") + if writer is not None: + with suppress(Exception): + await writer.send(ClientDisconnect()) + return except Exception as err: # pragma: no cover logger.exception("Error handling POST request") response = self._create_error_response( diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 8b7b5281c..8bb8b36e1 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -201,6 +201,29 @@ async def _handle_stateless_request( receive: ASGI receive function send: ASGI send function """ + # In stateless mode, only POST is meaningful. GET (SSE stream) and DELETE + # (session termination) both require session state that stateless mode does + # not maintain, so reject them with 405 before creating a transport. + request = Request(scope, receive) + if request.method in ("GET", "DELETE"): + logger.debug(f"Stateless mode: rejecting {request.method} with 405") + error_response = JSONRPCError( + jsonrpc="2.0", + id="", + error=ErrorData( + code=INVALID_REQUEST, + message=(f"Method Not Allowed: {request.method} is not supported in stateless mode"), + ), + ) + response = Response( + content=error_response.model_dump_json(by_alias=True, exclude_unset=True), + status_code=HTTPStatus.METHOD_NOT_ALLOWED, + headers={"Allow": "POST"}, + media_type="application/json", + ) + await response(scope, receive, send) + return + logger.debug("Stateless mode: Creating new transport for this request") # No session ID needed in stateless mode http_transport = StreamableHTTPServerTransport( diff --git a/tests/server/streamable_http/__init__.py b/tests/server/streamable_http/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/server/streamable_http/test_client_disconnect_post.py b/tests/server/streamable_http/test_client_disconnect_post.py new file mode 100644 index 000000000..0e2b9f4de --- /dev/null +++ b/tests/server/streamable_http/test_client_disconnect_post.py @@ -0,0 +1,199 @@ +"""Tests for ClientDisconnect handling in StreamableHTTPServerTransport._handle_post_request. + +Regression test for pattern 1: ClientDisconnect raised during POST should log at WARNING +(not ERROR) and should not attempt to send a response to the closed socket. + +Inspired by upstream PRs: +- https://github.com/modelcontextprotocol/python-sdk/pull/1647 (scope: POST only) +- https://github.com/modelcontextprotocol/python-sdk/pull/1947 (semantics: notify writer, skip response) +""" + +from __future__ import annotations as _annotations + +import logging +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from starlette.requests import ClientDisconnect, Request + +from mcp.server.streamable_http import StreamableHTTPServerTransport + + +class TestClientDisconnectDuringPOST: + """ClientDisconnect during POST should be handled gracefully.""" + + def _make_scope(self, headers: dict[str, bytes] | None = None) -> dict[str, Any]: + """Build a minimal ASGI scope for a POST request.""" + return { + "type": "http", + "method": "POST", + "path": "/mcp", + "query_string": b"", + "headers": list((headers or {}).items()) if headers else [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ], + } + + @pytest.mark.anyio + async def test_client_disconnect_logs_warning_not_error(self, caplog): + """ClientDisconnect should produce a WARNING, not an ERROR.""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + # Set up a dummy writer so the transport passes the None check + mock_writer = MagicMock() + mock_writer.send = AsyncMock() + transport._read_stream_writer = mock_writer + + # Mock request.body() to raise ClientDisconnect (simulates client going away + # mid-request body upload). + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + send_calls: list[Any] = [] + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + async def dummy_send(message): + send_calls.append(message) + + with caplog.at_level(logging.DEBUG, logger="mcp.server.streamable_http"): + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # Should log a WARNING, not an ERROR + warning_records = [ + r for r in caplog.records if r.levelno == logging.WARNING + and "Client disconnected" in r.getMessage() + ] + error_records = [ + r for r in caplog.records if r.levelno == logging.ERROR + ] + assert len(warning_records) == 1, ( + f"Expected exactly 1 WARNING with 'Client disconnected', got {len(warning_records)}" + ) + assert len(error_records) == 0, ( + f"Expected 0 ERROR logs, got {len(error_records)}: {[r.getMessage() for r in error_records]}" + ) + + @pytest.mark.anyio + async def test_client_disconnect_does_not_send_response(self): + """After ClientDisconnect, no response should be sent (socket is closed).""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + # Set up a dummy writer so the transport passes the None check + mock_writer = MagicMock() + mock_writer.send = AsyncMock() + transport._read_stream_writer = mock_writer + + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + send_calls: list[Any] = [] + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + async def dummy_send(message): + send_calls.append(message) + + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # No HTTP response should be sent to the closed socket + assert len(send_calls) == 0, ( + f"Expected no ASGI sends after ClientDisconnect, got {len(send_calls)}" + ) + + @pytest.mark.anyio + async def test_client_disconnect_notifies_writer(self): + """Writer should receive ClientDisconnect so the inner session task can unblock.""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + # Capture what the writer receives + writer_messages: list[Any] = [] + + async def capture_writer(msg): + writer_messages.append(msg) + + # Patch the internal writer + with patch.object(transport, "_read_stream_writer", MagicMock(send=capture_writer)): + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + send_calls: list[Any] = [] + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + async def dummy_send(message): + send_calls.append(message) + + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # Writer should have been notified with ClientDisconnect + assert len(writer_messages) == 1, ( + f"Expected writer to receive 1 message, got {len(writer_messages)}" + ) + assert isinstance(writer_messages[0], ClientDisconnect), ( + f"Expected ClientDisconnect sent to writer, got {type(writer_messages[0])}" + ) + + @pytest.mark.anyio + async def test_client_disconnect_writer_suppresses_errors(self): + """If the writer itself is broken, we should not crash (suppress(Exception)).""" + transport = StreamableHTTPServerTransport(mcp_session_id=None) + scope = self._make_scope() + + broken_send = AsyncMock(side_effect=RuntimeError("writer is broken")) + + with patch.object(transport, "_read_stream_writer", MagicMock(send=broken_send)): + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(side_effect=ClientDisconnect()) + mock_request.headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + mock_request.scope = scope + + async def dummy_receive(): + return {"type": "http.request", "body": b""} + + send_calls: list[Any] = [] + + async def dummy_send(message): + send_calls.append(message) + + # Should not raise even though writer.send() fails + await transport._handle_post_request( + scope, mock_request, dummy_receive, dummy_send + ) + + # The broken writer.send was called once + broken_send.assert_called_once() + # No response was sent (socket is closed) + assert len(send_calls) == 0 diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 33bcb5f2a..8e92cfa43 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -390,3 +390,116 @@ def test_session_idle_timeout_rejects_non_positive(): def test_session_idle_timeout_rejects_stateless(): with pytest.raises(RuntimeError, match="not supported in stateless"): StreamableHTTPSessionManager(app=Server("test"), session_idle_timeout=30, stateless=True) + + +async def _collect_stateless_response( + method: str, +) -> tuple[Message | None, bytes]: + """Send a request of the given method to a stateless manager and return + (response.start message, response body).""" + app = Server("test-stateless-method") + manager = StreamableHTTPSessionManager(app=app, stateless=True) + + sent_messages: list[Message] = [] + response_body = b"" + + async def mock_send(message: Message): + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + scope = { + "type": "http", + "method": method, + "path": "/mcp", + "headers": [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ], + } + + async def mock_receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + async with manager.run(): + await manager.handle_request(scope, mock_receive, mock_send) + + response_start = next( + (msg for msg in sent_messages if msg["type"] == "http.response.start"), + None, + ) + return response_start, response_body + + +@pytest.mark.anyio +async def test_stateless_get_returns_405(): + """GET requests return 405 in stateless mode since SSE streams require session state.""" + response_start, response_body = await _collect_stateless_response("GET") + + assert response_start is not None + assert response_start["status"] == 405 + + headers = {name.decode().lower(): value.decode() for name, value in response_start.get("headers", [])} + assert headers.get("allow") == "POST" + + error_data = json.loads(response_body) + assert error_data["jsonrpc"] == "2.0" + assert error_data["id"] is None + assert error_data["error"]["code"] == INVALID_REQUEST + assert "GET" in error_data["error"]["message"] + assert "stateless" in error_data["error"]["message"].lower() + + +@pytest.mark.anyio +async def test_stateless_delete_returns_405(): + """DELETE requests return 405 in stateless mode since there is no session to terminate.""" + response_start, response_body = await _collect_stateless_response("DELETE") + + assert response_start is not None + assert response_start["status"] == 405 + + headers = {name.decode().lower(): value.decode() for name, value in response_start.get("headers", [])} + assert headers.get("allow") == "POST" + + error_data = json.loads(response_body) + assert error_data["jsonrpc"] == "2.0" + assert error_data["id"] is None + assert error_data["error"]["code"] == INVALID_REQUEST + assert "DELETE" in error_data["error"]["message"] + + +@pytest.mark.anyio +async def test_stateless_get_does_not_create_transport(): + """A GET in stateless mode should short-circuit without spinning up a transport.""" + app = Server("test-stateless-no-transport") + manager = StreamableHTTPSessionManager(app=app, stateless=True) + + created_transports: list[StreamableHTTPServerTransport] = [] + original_constructor = StreamableHTTPServerTransport + + def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: + transport = original_constructor(*args, **kwargs) # pragma: no cover + created_transports.append(transport) # pragma: no cover + return transport # pragma: no cover + + with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=track_transport): + async with manager.run(): + sent_messages: list[Message] = [] + + async def mock_send(message: Message): + sent_messages.append(message) + + scope = { + "type": "http", + "method": "GET", + "path": "/mcp", + "headers": [(b"accept", b"text/event-stream")], + } + + async def mock_receive(): # pragma: no cover + return {"type": "http.request", "body": b"", "more_body": False} + + await manager.handle_request(scope, mock_receive, mock_send) + + assert created_transports == [], "Stateless GET must not create a transport"