Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any
Expand All @@ -21,7 +21,7 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -644,6 +644,15 @@ async def sse_writer():
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(request_id)

except ClientDisconnect:
# Client went away mid-request (network timeout, cancel, LB drop). Not a
# server error — log at WARNING and skip the response: the socket is gone.
# Notify the writer so the inner session task can unblock cleanly.
logger.warning("Client disconnected during POST request")
if writer is not None:
with suppress(Exception):
await writer.send(ClientDisconnect())
return
except Exception as err: # pragma: no cover
logger.exception("Error handling POST request")
response = self._create_error_response(
Expand Down
Empty file.
199 changes: 199 additions & 0 deletions tests/server/streamable_http/test_client_disconnect_post.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Tests for ClientDisconnect handling in StreamableHTTPServerTransport._handle_post_request.

Regression test for pattern 1: ClientDisconnect raised during POST should log at WARNING
(not ERROR) and should not attempt to send a response to the closed socket.

Inspired by upstream PRs:
- https://github.com/modelcontextprotocol/python-sdk/pull/1647 (scope: POST only)
- https://github.com/modelcontextprotocol/python-sdk/pull/1947 (semantics: notify writer, skip response)
"""

from __future__ import annotations as _annotations

import logging
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from starlette.requests import ClientDisconnect, Request

from mcp.server.streamable_http import StreamableHTTPServerTransport


class TestClientDisconnectDuringPOST:
"""ClientDisconnect during POST should be handled gracefully."""

def _make_scope(self, headers: dict[str, bytes] | None = None) -> dict[str, Any]:
"""Build a minimal ASGI scope for a POST request."""
return {
"type": "http",
"method": "POST",
"path": "/mcp",
"query_string": b"",
"headers": list((headers or {}).items()) if headers else [
(b"content-type", b"application/json"),
(b"accept", b"application/json, text/event-stream"),
],
}

@pytest.mark.anyio
async def test_client_disconnect_logs_warning_not_error(self, caplog):
"""ClientDisconnect should produce a WARNING, not an ERROR."""
transport = StreamableHTTPServerTransport(mcp_session_id=None)
scope = self._make_scope()

# Set up a dummy writer so the transport passes the None check
mock_writer = MagicMock()
mock_writer.send = AsyncMock()
transport._read_stream_writer = mock_writer

# Mock request.body() to raise ClientDisconnect (simulates client going away
# mid-request body upload).
mock_request = MagicMock(spec=Request)
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
mock_request.headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
mock_request.scope = scope

send_calls: list[Any] = []

async def dummy_receive():
return {"type": "http.request", "body": b""}

async def dummy_send(message):
send_calls.append(message)

with caplog.at_level(logging.DEBUG, logger="mcp.server.streamable_http"):
await transport._handle_post_request(
scope, mock_request, dummy_receive, dummy_send
)

# Should log a WARNING, not an ERROR
warning_records = [
r for r in caplog.records if r.levelno == logging.WARNING
and "Client disconnected" in r.getMessage()
]
error_records = [
r for r in caplog.records if r.levelno == logging.ERROR
]
assert len(warning_records) == 1, (
f"Expected exactly 1 WARNING with 'Client disconnected', got {len(warning_records)}"
)
assert len(error_records) == 0, (
f"Expected 0 ERROR logs, got {len(error_records)}: {[r.getMessage() for r in error_records]}"
)

@pytest.mark.anyio
async def test_client_disconnect_does_not_send_response(self):
"""After ClientDisconnect, no response should be sent (socket is closed)."""
transport = StreamableHTTPServerTransport(mcp_session_id=None)
scope = self._make_scope()

# Set up a dummy writer so the transport passes the None check
mock_writer = MagicMock()
mock_writer.send = AsyncMock()
transport._read_stream_writer = mock_writer

mock_request = MagicMock(spec=Request)
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
mock_request.headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
mock_request.scope = scope

send_calls: list[Any] = []

async def dummy_receive():
return {"type": "http.request", "body": b""}

async def dummy_send(message):
send_calls.append(message)

await transport._handle_post_request(
scope, mock_request, dummy_receive, dummy_send
)

# No HTTP response should be sent to the closed socket
assert len(send_calls) == 0, (
f"Expected no ASGI sends after ClientDisconnect, got {len(send_calls)}"
)

@pytest.mark.anyio
async def test_client_disconnect_notifies_writer(self):
"""Writer should receive ClientDisconnect so the inner session task can unblock."""
transport = StreamableHTTPServerTransport(mcp_session_id=None)
scope = self._make_scope()

# Capture what the writer receives
writer_messages: list[Any] = []

async def capture_writer(msg):
writer_messages.append(msg)

# Patch the internal writer
with patch.object(transport, "_read_stream_writer", MagicMock(send=capture_writer)):
mock_request = MagicMock(spec=Request)
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
mock_request.headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
mock_request.scope = scope

send_calls: list[Any] = []

async def dummy_receive():
return {"type": "http.request", "body": b""}

async def dummy_send(message):
send_calls.append(message)

await transport._handle_post_request(
scope, mock_request, dummy_receive, dummy_send
)

# Writer should have been notified with ClientDisconnect
assert len(writer_messages) == 1, (
f"Expected writer to receive 1 message, got {len(writer_messages)}"
)
assert isinstance(writer_messages[0], ClientDisconnect), (
f"Expected ClientDisconnect sent to writer, got {type(writer_messages[0])}"
)

@pytest.mark.anyio
async def test_client_disconnect_writer_suppresses_errors(self):
"""If the writer itself is broken, we should not crash (suppress(Exception))."""
transport = StreamableHTTPServerTransport(mcp_session_id=None)
scope = self._make_scope()

broken_send = AsyncMock(side_effect=RuntimeError("writer is broken"))

with patch.object(transport, "_read_stream_writer", MagicMock(send=broken_send)):
mock_request = MagicMock(spec=Request)
mock_request.body = AsyncMock(side_effect=ClientDisconnect())
mock_request.headers = {
"content-type": "application/json",
"accept": "application/json, text/event-stream",
}
mock_request.scope = scope

async def dummy_receive():
return {"type": "http.request", "body": b""}

send_calls: list[Any] = []

async def dummy_send(message):
send_calls.append(message)

# Should not raise even though writer.send() fails
await transport._handle_post_request(
scope, mock_request, dummy_receive, dummy_send
)

# The broken writer.send was called once
broken_send.assert_called_once()
# No response was sent (socket is closed)
assert len(send_calls) == 0
4 changes: 2 additions & 2 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ async def test_stateless_get_returns_405():

error_data = json.loads(response_body)
assert error_data["jsonrpc"] == "2.0"
assert error_data["id"] is None
assert error_data["id"] == ""
assert error_data["error"]["code"] == INVALID_REQUEST
assert "GET" in error_data["error"]["message"]
assert "stateless" in error_data["error"]["message"].lower()
Expand All @@ -464,7 +464,7 @@ async def test_stateless_delete_returns_405():

error_data = json.loads(response_body)
assert error_data["jsonrpc"] == "2.0"
assert error_data["id"] is None
assert error_data["id"] == ""
assert error_data["error"]["code"] == INVALID_REQUEST
assert "DELETE" in error_data["error"]["message"]

Expand Down
Loading