Skip to content

Commit b07fa39

Browse files
committed
fix(streamable_http): handle ClientDisconnect during POST at WARNING level
Catch ClientDisconnect in _handle_post_request so client disconnections (network timeout, cancel, LB drop) log at WARNING instead of ERROR and do not attempt to send a response to the closed socket. This fixes pattern 1 (67% of all prod ERRORs in ai-codegen-api): when the client disconnects mid-POST, the current except-Exception handler synthesizes a 500 response to a closed socket and logs ERROR-level noise to Datadog. Changes: - Catch ClientDisconnect before the generic except-Exception handler - Log at WARNING (honest level: we aborted but it wasn't our fault) - Notify the session writer so inner session tasks unblock cleanly - Suppress errors from writer.send() since the writer may already be closed - Skip sending a response (the socket is gone) Upstream context: - modelcontextprotocol#1647 (POST-only scope, right approach) - modelcontextprotocol#1947 (writer-notification semantics, skip-response) - modelcontextprotocol#1648 (tracking issue, still open, no PRs merged) This patch takes the spirit of modelcontextprotocol#1647 (POST-only scope) + the semantics of modelcontextprotocol#1947 (notify writer, skip response) in a tight ~10-line change. Github-Issue:modelcontextprotocol#1648
1 parent 221aa6d commit b07fa39

3 files changed

Lines changed: 210 additions & 2 deletions

File tree

src/mcp/server/streamable_http.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import re
1313
from abc import ABC, abstractmethod
1414
from collections.abc import AsyncGenerator, Awaitable, Callable
15-
from contextlib import asynccontextmanager
15+
from contextlib import asynccontextmanager, suppress
1616
from dataclasses import dataclass
1717
from http import HTTPStatus
1818
from typing import Any
@@ -21,7 +21,7 @@
2121
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2222
from pydantic import ValidationError
2323
from sse_starlette import EventSourceResponse
24-
from starlette.requests import Request
24+
from starlette.requests import ClientDisconnect, Request
2525
from starlette.responses import Response
2626
from starlette.types import Receive, Scope, Send
2727

@@ -644,6 +644,15 @@ async def sse_writer():
644644
await sse_stream_reader.aclose()
645645
await self._clean_up_memory_streams(request_id)
646646

647+
except ClientDisconnect:
648+
# Client went away mid-request (network timeout, cancel, LB drop). Not a
649+
# server error — log at WARNING and skip the response: the socket is gone.
650+
# Notify the writer so the inner session task can unblock cleanly.
651+
logger.warning("Client disconnected during POST request")
652+
if writer is not None:
653+
with suppress(Exception):
654+
await writer.send(ClientDisconnect())
655+
return
647656
except Exception as err: # pragma: no cover
648657
logger.exception("Error handling POST request")
649658
response = self._create_error_response(

tests/server/streamable_http/__init__.py

Whitespace-only changes.
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
"""Tests for ClientDisconnect handling in StreamableHTTPServerTransport._handle_post_request.
2+
3+
Regression test for pattern 1: ClientDisconnect raised during POST should log at WARNING
4+
(not ERROR) and should not attempt to send a response to the closed socket.
5+
6+
Inspired by upstream PRs:
7+
- https://github.com/modelcontextprotocol/python-sdk/pull/1647 (scope: POST only)
8+
- https://github.com/modelcontextprotocol/python-sdk/pull/1947 (semantics: notify writer, skip response)
9+
"""
10+
11+
from __future__ import annotations as _annotations
12+
13+
import logging
14+
from typing import Any
15+
from unittest.mock import AsyncMock, MagicMock, patch
16+
17+
import pytest
18+
from starlette.requests import ClientDisconnect, Request
19+
20+
from mcp.server.streamable_http import StreamableHTTPServerTransport
21+
22+
23+
class TestClientDisconnectDuringPOST:
24+
"""ClientDisconnect during POST should be handled gracefully."""
25+
26+
def _make_scope(self, headers: dict[str, bytes] | None = None) -> dict[str, Any]:
27+
"""Build a minimal ASGI scope for a POST request."""
28+
return {
29+
"type": "http",
30+
"method": "POST",
31+
"path": "/mcp",
32+
"query_string": b"",
33+
"headers": list((headers or {}).items()) if headers else [
34+
(b"content-type", b"application/json"),
35+
(b"accept", b"application/json, text/event-stream"),
36+
],
37+
}
38+
39+
@pytest.mark.anyio
40+
async def test_client_disconnect_logs_warning_not_error(self, caplog):
41+
"""ClientDisconnect should produce a WARNING, not an ERROR."""
42+
transport = StreamableHTTPServerTransport(mcp_session_id=None)
43+
scope = self._make_scope()
44+
45+
# Set up a dummy writer so the transport passes the None check
46+
mock_writer = MagicMock()
47+
mock_writer.send = AsyncMock()
48+
transport._read_stream_writer = mock_writer
49+
50+
# Mock request.body() to raise ClientDisconnect (simulates client going away
51+
# mid-request body upload).
52+
mock_request = MagicMock(spec=Request)
53+
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
54+
mock_request.headers = {
55+
"content-type": "application/json",
56+
"accept": "application/json, text/event-stream",
57+
}
58+
mock_request.scope = scope
59+
60+
send_calls: list[Any] = []
61+
62+
async def dummy_receive():
63+
return {"type": "http.request", "body": b""}
64+
65+
async def dummy_send(message):
66+
send_calls.append(message)
67+
68+
with caplog.at_level(logging.DEBUG, logger="mcp.server.streamable_http"):
69+
await transport._handle_post_request(
70+
scope, mock_request, dummy_receive, dummy_send
71+
)
72+
73+
# Should log a WARNING, not an ERROR
74+
warning_records = [
75+
r for r in caplog.records if r.levelno == logging.WARNING
76+
and "Client disconnected" in r.getMessage()
77+
]
78+
error_records = [
79+
r for r in caplog.records if r.levelno == logging.ERROR
80+
]
81+
assert len(warning_records) == 1, (
82+
f"Expected exactly 1 WARNING with 'Client disconnected', got {len(warning_records)}"
83+
)
84+
assert len(error_records) == 0, (
85+
f"Expected 0 ERROR logs, got {len(error_records)}: {[r.getMessage() for r in error_records]}"
86+
)
87+
88+
@pytest.mark.anyio
89+
async def test_client_disconnect_does_not_send_response(self):
90+
"""After ClientDisconnect, no response should be sent (socket is closed)."""
91+
transport = StreamableHTTPServerTransport(mcp_session_id=None)
92+
scope = self._make_scope()
93+
94+
# Set up a dummy writer so the transport passes the None check
95+
mock_writer = MagicMock()
96+
mock_writer.send = AsyncMock()
97+
transport._read_stream_writer = mock_writer
98+
99+
mock_request = MagicMock(spec=Request)
100+
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
101+
mock_request.headers = {
102+
"content-type": "application/json",
103+
"accept": "application/json, text/event-stream",
104+
}
105+
mock_request.scope = scope
106+
107+
send_calls: list[Any] = []
108+
109+
async def dummy_receive():
110+
return {"type": "http.request", "body": b""}
111+
112+
async def dummy_send(message):
113+
send_calls.append(message)
114+
115+
await transport._handle_post_request(
116+
scope, mock_request, dummy_receive, dummy_send
117+
)
118+
119+
# No HTTP response should be sent to the closed socket
120+
assert len(send_calls) == 0, (
121+
f"Expected no ASGI sends after ClientDisconnect, got {len(send_calls)}"
122+
)
123+
124+
@pytest.mark.anyio
125+
async def test_client_disconnect_notifies_writer(self):
126+
"""Writer should receive ClientDisconnect so the inner session task can unblock."""
127+
transport = StreamableHTTPServerTransport(mcp_session_id=None)
128+
scope = self._make_scope()
129+
130+
# Capture what the writer receives
131+
writer_messages: list[Any] = []
132+
133+
async def capture_writer(msg):
134+
writer_messages.append(msg)
135+
136+
# Patch the internal writer
137+
with patch.object(transport, "_read_stream_writer", MagicMock(send=capture_writer)):
138+
mock_request = MagicMock(spec=Request)
139+
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
140+
mock_request.headers = {
141+
"content-type": "application/json",
142+
"accept": "application/json, text/event-stream",
143+
}
144+
mock_request.scope = scope
145+
146+
send_calls: list[Any] = []
147+
148+
async def dummy_receive():
149+
return {"type": "http.request", "body": b""}
150+
151+
async def dummy_send(message):
152+
send_calls.append(message)
153+
154+
await transport._handle_post_request(
155+
scope, mock_request, dummy_receive, dummy_send
156+
)
157+
158+
# Writer should have been notified with ClientDisconnect
159+
assert len(writer_messages) == 1, (
160+
f"Expected writer to receive 1 message, got {len(writer_messages)}"
161+
)
162+
assert isinstance(writer_messages[0], ClientDisconnect), (
163+
f"Expected ClientDisconnect sent to writer, got {type(writer_messages[0])}"
164+
)
165+
166+
@pytest.mark.anyio
167+
async def test_client_disconnect_writer_suppresses_errors(self):
168+
"""If the writer itself is broken, we should not crash (suppress(Exception))."""
169+
transport = StreamableHTTPServerTransport(mcp_session_id=None)
170+
scope = self._make_scope()
171+
172+
broken_send = AsyncMock(side_effect=RuntimeError("writer is broken"))
173+
174+
with patch.object(transport, "_read_stream_writer", MagicMock(send=broken_send)):
175+
mock_request = MagicMock(spec=Request)
176+
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
177+
mock_request.headers = {
178+
"content-type": "application/json",
179+
"accept": "application/json, text/event-stream",
180+
}
181+
mock_request.scope = scope
182+
183+
async def dummy_receive():
184+
return {"type": "http.request", "body": b""}
185+
186+
send_calls: list[Any] = []
187+
188+
async def dummy_send(message):
189+
send_calls.append(message)
190+
191+
# Should not raise even though writer.send() fails
192+
await transport._handle_post_request(
193+
scope, mock_request, dummy_receive, dummy_send
194+
)
195+
196+
# The broken writer.send was called once
197+
broken_send.assert_called_once()
198+
# No response was sent (socket is closed)
199+
assert len(send_calls) == 0

0 commit comments

Comments
 (0)