diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 72f88af44..4530ac6e3 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -649,11 +649,8 @@ async def sse_writer(): # server error — log at WARNING and send a response so middleware chains # (e.g. Starlette BaseHTTPMiddleware) don't raise "No response returned". # The ASGI server will drop the response if the socket is already closed. - # Notify the writer so the inner session task can unblock cleanly. + # The connect() context manager will close streams, unblocking the session task. logger.warning("Client disconnected during POST request") - if writer is not None: - with suppress(Exception): - await writer.send(ClientDisconnect()) # 499 = Client Closed Request (nginx convention, not in stdlib HTTPStatus) response = self._create_json_response(None, 499) # type: ignore[arg-type] with suppress(Exception): diff --git a/tests/server/streamable_http/test_client_disconnect_post.py b/tests/server/streamable_http/test_client_disconnect_post.py index cc5d5062c..a20d97542 100644 --- a/tests/server/streamable_http/test_client_disconnect_post.py +++ b/tests/server/streamable_http/test_client_disconnect_post.py @@ -12,7 +12,7 @@ import logging from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest from starlette.requests import ClientDisconnect, Request @@ -125,81 +125,4 @@ async def dummy_send(message): assert send_calls[0]["type"] == "http.response.start" assert send_calls[0]["status"] == 499 - @pytest.mark.anyio - async def test_client_disconnect_notifies_writer(self): - """Writer should receive ClientDisconnect so the inner session task can unblock.""" - transport = StreamableHTTPServerTransport(mcp_session_id=None) - scope = self._make_scope() - - # Capture what the writer receives - writer_messages: list[Any] = [] - - async def capture_writer(msg): - writer_messages.append(msg) - - # Patch the internal writer - with patch.object(transport, "_read_stream_writer", MagicMock(send=capture_writer)): - mock_request = MagicMock(spec=Request) - mock_request.body = AsyncMock(side_effect=ClientDisconnect()) - mock_request.headers = { - "content-type": "application/json", - "accept": "application/json, text/event-stream", - } - mock_request.scope = scope - - send_calls: list[Any] = [] - - async def dummy_receive(): - return {"type": "http.request", "body": b""} - - async def dummy_send(message): - send_calls.append(message) - - await transport._handle_post_request( - scope, mock_request, dummy_receive, dummy_send - ) - - # Writer should have been notified with ClientDisconnect - assert len(writer_messages) == 1, ( - f"Expected writer to receive 1 message, got {len(writer_messages)}" - ) - assert isinstance(writer_messages[0], ClientDisconnect), ( - f"Expected ClientDisconnect sent to writer, got {type(writer_messages[0])}" - ) - @pytest.mark.anyio - async def test_client_disconnect_writer_suppresses_errors(self): - """If the writer itself is broken, we should not crash (suppress(Exception)).""" - transport = StreamableHTTPServerTransport(mcp_session_id=None) - scope = self._make_scope() - - broken_send = AsyncMock(side_effect=RuntimeError("writer is broken")) - - with patch.object(transport, "_read_stream_writer", MagicMock(send=broken_send)): - mock_request = MagicMock(spec=Request) - mock_request.body = AsyncMock(side_effect=ClientDisconnect()) - mock_request.headers = { - "content-type": "application/json", - "accept": "application/json, text/event-stream", - } - mock_request.scope = scope - - async def dummy_receive(): - return {"type": "http.request", "body": b""} - - send_calls: list[Any] = [] - - async def dummy_send(message): - send_calls.append(message) - - # Should not raise even though writer.send() fails - await transport._handle_post_request( - scope, mock_request, dummy_receive, dummy_send - ) - - # The broken writer.send was called once (suppressed) - broken_send.assert_called_once() - # Response is still sent even though writer was broken - assert len(send_calls) >= 1 - assert send_calls[0]["type"] == "http.response.start" - assert send_calls[0]["status"] == 499