Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,11 +649,8 @@ async def sse_writer():
# server error — log at WARNING and send a response so middleware chains
# (e.g. Starlette BaseHTTPMiddleware) don't raise "No response returned".
# The ASGI server will drop the response if the socket is already closed.
# Notify the writer so the inner session task can unblock cleanly.
# The connect() context manager will close streams, unblocking the session task.
logger.warning("Client disconnected during POST request")
if writer is not None:
with suppress(Exception):
await writer.send(ClientDisconnect())
# 499 = Client Closed Request (nginx convention, not in stdlib HTTPStatus)
response = self._create_json_response(None, 499) # type: ignore[arg-type]
with suppress(Exception):
Expand Down
79 changes: 1 addition & 78 deletions tests/server/streamable_http/test_client_disconnect_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import logging
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock

import pytest
from starlette.requests import ClientDisconnect, Request
Expand Down Expand Up @@ -125,81 +125,4 @@ async def dummy_send(message):
assert send_calls[0]["type"] == "http.response.start"
assert send_calls[0]["status"] == 499

@pytest.mark.anyio
async def test_client_disconnect_notifies_writer(self):
"""Writer should receive ClientDisconnect so the inner session task can unblock."""
transport = StreamableHTTPServerTransport(mcp_session_id=None)
scope = self._make_scope()

# Capture what the writer receives
writer_messages: list[Any] = []

async def capture_writer(msg):
writer_messages.append(msg)

# Patch the internal writer
with patch.object(transport, "_read_stream_writer", MagicMock(send=capture_writer)):
mock_request = MagicMock(spec=Request)
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
mock_request.headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
mock_request.scope = scope

send_calls: list[Any] = []

async def dummy_receive():
return {"type": "http.request", "body": b""}

async def dummy_send(message):
send_calls.append(message)

await transport._handle_post_request(
scope, mock_request, dummy_receive, dummy_send
)

# Writer should have been notified with ClientDisconnect
assert len(writer_messages) == 1, (
f"Expected writer to receive 1 message, got {len(writer_messages)}"
)
assert isinstance(writer_messages[0], ClientDisconnect), (
f"Expected ClientDisconnect sent to writer, got {type(writer_messages[0])}"
)

@pytest.mark.anyio
async def test_client_disconnect_writer_suppresses_errors(self):
"""If the writer itself is broken, we should not crash (suppress(Exception))."""
transport = StreamableHTTPServerTransport(mcp_session_id=None)
scope = self._make_scope()

broken_send = AsyncMock(side_effect=RuntimeError("writer is broken"))

with patch.object(transport, "_read_stream_writer", MagicMock(send=broken_send)):
mock_request = MagicMock(spec=Request)
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
mock_request.headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
mock_request.scope = scope

async def dummy_receive():
return {"type": "http.request", "body": b""}

send_calls: list[Any] = []

async def dummy_send(message):
send_calls.append(message)

# Should not raise even though writer.send() fails
await transport._handle_post_request(
scope, mock_request, dummy_receive, dummy_send
)

# The broken writer.send was called once (suppressed)
broken_send.assert_called_once()
# Response is still sent even though writer was broken
assert len(send_calls) >= 1
assert send_calls[0]["type"] == "http.response.start"
assert send_calls[0]["status"] == 499
Loading