diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d34e438fc..341df0abb 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,6 @@ on: branches: ["main", "v1.x"] tags: ["v*.*.*"] pull_request: - branches: ["main", "v1.x"] permissions: contents: read diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index bb5639a13..27443ec87 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -20,6 +20,7 @@ from typing import Any import anyio +import anyio.abc from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError @@ -101,10 +102,17 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") await self._peer._dispatch_notify(method, params) - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: self._on_request = on_request self._on_notify = on_notify self._ready.set() + task_status.started() await self._closed.wait() def close(self) -> None: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index ee02e2389..20c090323 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -20,6 +20,7 @@ from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable import anyio +import anyio.abc from mcp.shared.transport_context import TransportContext @@ -136,11 +137,21 @@ class Dispatcher(Outbound, Protocol[TransportT_co]): receive loop, per-request concurrency, and cancellation/progress wiring. """ - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: """Drive the receive loop until the underlying channel closes. Each inbound request is dispatched to ``on_request`` in its own task; the returned dict (or raised ``MCPError``) is sent back as the response. Inbound notifications go to ``on_notify``. + + ``task_status.started()`` is called once the dispatcher is ready to + accept ``send_request``/``notify`` calls, so callers can use + ``await tg.start(dispatcher.run, on_request, on_notify)``. """ ... diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py new file mode 100644 index 000000000..f1e7b3675 --- /dev/null +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -0,0 +1,543 @@ +"""JSON-RPC `Dispatcher` implementation. + +Consumes the existing `SessionMessage`-based stream contract that all current +transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, +the receive loop, per-request task isolation, cancellation/progress wiring, and +the single exception-to-wire boundary. + +The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and +sees only `(ctx, method, params) -> dict`. Transports sit below and see only +`SessionMessage` reads/writes. + +The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and +dicts — but it intercepts ``notifications/cancelled`` and +``notifications/progress`` because request correlation, cancellation and +progress are exactly the wiring this layer exists to provide. Those few wire +shapes are extracted with structural ``match`` patterns (no casts, no +``mcp.types`` model coupling); a malformed payload simply fails to match and +the correlation is skipped. +""" + +from __future__ import annotations + +import contextvars +import logging +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeVar, cast, overload + +import anyio +import anyio.abc +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError + +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import ( + ClientMessageMetadata, + MessageMetadata, + ServerMessageMetadata, + SessionMessage, +) +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + REQUEST_CANCELLED, + REQUEST_TIMEOUT, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ProgressToken, + RequestId, +) + +__all__ = ["JSONRPCDispatcher"] + +logger = logging.getLogger(__name__) + +TransportT = TypeVar("TransportT", bound=TransportContext) + +PeerCancelMode = Literal["interrupt", "signal"] +"""How inbound ``notifications/cancelled`` is applied to a running handler. + +``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets +``ctx.cancel_requested`` and lets the handler observe it cooperatively. +""" + +TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext] +"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and +the `SessionMessage.metadata` the transport attached. Defaults to a plain +`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" + + +@dataclass(slots=True) +class _Pending: + """An outbound request awaiting its response.""" + + send: MemoryObjectSendStream[dict[str, Any] | ErrorData] + receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData] + on_progress: ProgressFnT | None = None + + +@dataclass(slots=True) +class _InFlight(Generic[TransportT]): + """An inbound request currently being handled.""" + + scope: anyio.CancelScope + dctx: _JSONRPCDispatchContext[TransportT] + cancelled_by_peer: bool = False + + +@dataclass +class _JSONRPCDispatchContext(Generic[TransportT]): + """Concrete `DispatchContext` produced for each inbound JSON-RPC message.""" + + transport: TransportT + _dispatcher: JSONRPCDispatcher[TransportT] + _request_id: RequestId | None + _progress_token: ProgressToken | None = None + _closed: bool = False + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request and not self._closed + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._dispatcher.notify(method, params, _related_request_id=self._request_id) + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.can_send_request: + raise NoBackChannelError(method) + return await self._dispatcher.send_raw_request(method, params, opts, _related_request_id=self._request_id) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._progress_token is None: + return + params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress} + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + await self.notify("notifications/progress", params) + + def close(self) -> None: + self._closed = True + + +def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + +def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: + """Choose the `SessionMessage.metadata` for an outgoing request/notification. + + `ServerMessageMetadata` tags a server-to-client message with the inbound + request it belongs to (so streamable-HTTP can route it onto that request's + SSE stream). `ClientMessageMetadata` carries resumption hints to the + client transport. ``None`` is the common case. + """ + if related_request_id is not None: + return ServerMessageMetadata(related_request_id=related_request_id) + if opts: + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") + if token is not None or on_token is not None: + return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) + return None + + +class JSONRPCDispatcher(Dispatcher[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract. + + Inherits the `Dispatcher` Protocol explicitly so pyright checks + conformance at the class definition rather than at first use. + """ + + @overload + def __init__( + self: JSONRPCDispatcher[TransportContext], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + ) -> None: ... + @overload + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT], + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: ... + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + # The overloads guarantee that when `transport_builder` is omitted, + # `TransportT` is `TransportContext`, so the default is type-correct; + # pyright can't see across overloads, hence the cast. + self._transport_builder = cast( + "Callable[[RequestId | None, MessageMetadata], TransportT]", + transport_builder or _default_transport_builder, + ) + self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode + self._raise_handler_exceptions = raise_handler_exceptions + + self._next_id = 0 + self._pending: dict[RequestId, _Pending] = {} + self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._tg: anyio.abc.TaskGroup | None = None + self._running = False + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: RequestId | None = None, + ) -> dict[str, Any]: + """Send a JSON-RPC request and await its response. + + ``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a + handler makes a server-to-client request mid-flight; it routes the + outgoing message onto the correct per-request SSE stream (SHTTP) via + `ServerMessageMetadata`. Top-level callers leave it ``None``. + + Raises: + MCPError: The peer responded with a JSON-RPC error; or + ``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or + ``CONNECTION_CLOSED`` if the dispatcher shut down while + awaiting the response. + RuntimeError: Called before ``run()`` has started or after it has + finished. + """ + if not self._running: + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") + opts = opts or {} + request_id = self._allocate_id() + out_params = dict(params) if params is not None else None + on_progress = opts.get("on_progress") + if on_progress is not None: + # The caller wants progress updates. The spec mechanism is: include + # `_meta.progressToken` on the request; the peer echoes that token on + # any `notifications/progress` it sends. We use the request id as the + # token so the receive loop can find this `_Pending.on_progress` by + # `_pending[token]` without a second lookup table. + meta = dict((out_params or {}).get("_meta") or {}) + meta["progressToken"] = request_id + out_params = {**(out_params or {}), "_meta": meta} + + # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from + # `_resolve_pending`/`_fan_out_closed` means the waiter already has an + # outcome and dropping the late/redundant signal is correct. buffer=0 + # is unsafe — there's a window between registering `_pending[id]` and + # parking in `receive()` where a close signal would be lost. + send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + pending = _Pending(send=send, receive=receive, on_progress=on_progress) + self._pending[request_id] = pending + + metadata = _outbound_metadata(_related_request_id, opts) + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + try: + await self._write(msg, metadata) + with anyio.fail_after(opts.get("timeout")): + outcome = await receive.receive() + except TimeoutError: + # Spec-recommended courtesy: tell the peer we've given up so it can + # stop work and free resources. v1's BaseSession.send_request does + # NOT do this; it's new behaviour. + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s") + raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None + except anyio.get_cancelled_exc_class(): + # Our caller's scope was cancelled. We're already inside a cancelled + # scope, so any bare `await` here re-raises immediately — shield to + # let the courtesy cancel notification go out before we propagate. + with anyio.CancelScope(shield=True): + await self._cancel_outbound(request_id, "caller cancelled") + raise + finally: + # Always remove the waiter, even on cancel/timeout, so a late + # response from the peer (race) hits a closed stream and is dropped + # in `_dispatch` rather than leaking. + self._pending.pop(request_id, None) + send.close() + receive.close() + + if isinstance(outcome, ErrorData): + raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data) + return outcome + + async def notify( + self, + method: str, + params: Mapping[str, Any] | None, + *, + _related_request_id: RequestId | None = None, + ) -> None: + msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) + await self._write(msg, _outbound_metadata(_related_request_id, None)) + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the read stream closes. + + Each inbound request is handled in its own task in an internal task + group; ``task_status.started()`` fires once that group is open, so + ``await tg.start(dispatcher.run, ...)`` resumes when ``send_raw_request`` + is usable. + """ + try: + async with anyio.create_task_group() as tg: + self._tg = tg + self._running = True + task_status.started() + async with self._read_stream: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None) + self._dispatch(item, on_request, on_notify, sender_ctx) + # Read stream EOF: wake any blocked `send_raw_request` waiters now, + # *before* the task group joins, so handlers parked in + # `dctx.send_raw_request()` can unwind and the join doesn't deadlock. + self._running = False + self._fan_out_closed() + finally: + # Covers the cancel/crash paths where the inline fan-out above is + # never reached. Idempotent. + self._running = False + self._tg = None + self._fan_out_closed() + + def _dispatch( + self, + item: SessionMessage | Exception, + on_request: OnRequest, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + """Route one inbound item. Synchronous: never awaits. + + Everything here is `send_nowait` or `_spawn`. An `await` would let one + slow message head-of-line block the entire read loop. + """ + if isinstance(item, Exception): + logger.debug("transport yielded exception: %r", item) + return + metadata = item.metadata + msg = item.message + match msg: + case JSONRPCRequest(): + self._dispatch_request(msg, metadata, on_request, sender_ctx) + case JSONRPCNotification(): + self._dispatch_notification(msg, metadata, on_notify, sender_ctx) + case JSONRPCResponse(): + self._resolve_pending(msg.id, msg.result) + case JSONRPCError(): # pragma: no branch + # `id` may be None per JSON-RPC (parse error before id known). + # The match is exhaustive over JSONRPCMessage; the no-match arc + # on this final case is unreachable. + self._resolve_pending(msg.id, msg.error) + + def _dispatch_request( + self, + req: JSONRPCRequest, + metadata: MessageMetadata, + on_request: OnRequest, + sender_ctx: contextvars.Context | None, + ) -> None: + progress_token: ProgressToken | None + match req.params: + case {"_meta": {"progressToken": str() | int() as progress_token}}: + pass + case _: + progress_token = None + transport_ctx = self._transport_builder(req.id, metadata) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, + _dispatcher=self, + _request_id=req.id, + _progress_token=progress_token, + ) + scope = anyio.CancelScope() + self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) + + def _dispatch_notification( + self, + msg: JSONRPCNotification, + metadata: MessageMetadata, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + if msg.method == "notifications/cancelled": + match msg.params: + case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: + in_flight.cancelled_by_peer = True + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() + case _: + pass + return + if msg.method == "notifications/progress": + match msg.params: + case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( + pending := self._pending.get(token) + ) is not None and pending.on_progress is not None: + total = msg.params.get("total") + message = msg.params.get("message") + self._spawn( + pending.on_progress, + float(progress), + float(total) if isinstance(total, int | float) else None, + message if isinstance(message, str) else None, + sender_ctx=sender_ctx, + ) + case _: + pass + # fall through: progress is also teed to on_notify + transport_ctx = self._transport_builder(None, metadata) + dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None) + self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + + def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: + pending = self._pending.get(request_id) if request_id is not None else None + if pending is None: + logger.debug("dropping response for unknown/late request id %r", request_id) + return + try: + pending.send.send_nowait(outcome) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("waiter for request id %r already gone", request_id) + + def _spawn( + self, + fn: Callable[..., Awaitable[Any]], + *args: object, + sender_ctx: contextvars.Context | None, + ) -> None: + """Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars. + + ASGI middleware (auth, OTel) sets contextvars on the request task that + wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes + the spawned handler inherit *that* context instead of the receive + loop's, so ``auth_context_var`` and OTel spans survive. + """ + assert self._tg is not None + if sender_ctx is not None: + sender_ctx.run(self._tg.start_soon, fn, *args) + else: + self._tg.start_soon(fn, *args) + + def _fan_out_closed(self) -> None: + """Wake every pending ``send_raw_request`` waiter with ``CONNECTION_CLOSED``. + + Synchronous (uses ``send_nowait``) because it's called from ``finally`` + which may be inside a cancelled scope. Idempotent. + """ + closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + for pending in self._pending.values(): + try: + pending.send.send_nowait(closed) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + pass + self._pending.clear() + + async def _handle_request( + self, + req: JSONRPCRequest, + dctx: _JSONRPCDispatchContext[TransportT], + scope: anyio.CancelScope, + on_request: OnRequest, + ) -> None: + """Run ``on_request`` for one inbound request and write its response. + + This is the single exception-to-wire boundary: handler exceptions are + caught here and serialized to ``JSONRPCError``. Nothing above this in + the stack constructs wire errors. + """ + try: + with scope: + try: + result = await on_request(dctx, req.method, req.params) + finally: + # Close the back-channel the moment the handler exits + # (success or raise), before the response write — a handler + # spawning detached work that later calls + # `dctx.send_raw_request()` should see `NoBackChannelError`. + dctx.close() + await self._write_result(req.id, result) + # Peer-cancel: `_dispatch_notification` cancelled this scope. anyio + # swallows a scope's *own* cancel at __exit__, so the result write + # (or the handler) is interrupted and execution lands here without + # reaching the `except cancelled` arm below. Spec SHOULD: send no + # response — fall through to `finally`. + except anyio.get_cancelled_exc_class(): + # Outer-cancel: run()'s task group is shutting down. Any bare + # `await` here re-raises immediately, so shield the courtesy write. + with anyio.CancelScope(shield=True): + await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + raise + except MCPError as e: + await self._write_error(req.id, e.error) + except ValidationError as e: + await self._write_error(req.id, ErrorData(code=INVALID_PARAMS, message=str(e))) + except Exception as e: + logger.exception("handler for %r raised", req.method) + await self._write_error(req.id, ErrorData(code=INTERNAL_ERROR, message=str(e))) + if self._raise_handler_exceptions: + raise + finally: + self._in_flight.pop(req.id, None) + + def _allocate_id(self) -> int: + self._next_id += 1 + return self._next_id + + async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + + async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + try: + await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped result for %r: write stream closed", request_id) + + async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + try: + await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped error for %r: write stream closed", request_id) + + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: + try: + await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + pass diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py new file mode 100644 index 000000000..1222c05ab --- /dev/null +++ b/tests/shared/conftest.py @@ -0,0 +1,61 @@ +"""Shared fixtures for `Dispatcher` contract tests. + +The `pair_factory` fixture parametrizes contract tests over every `Dispatcher` +implementation, so the same behavioral assertions run against `DirectDispatcher` +(in-memory) and `JSONRPCDispatcher` (over crossed anyio memory streams). +""" + +from collections.abc import Callable + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import Dispatcher +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext + +DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]] +PairFactory = Callable[..., DispatcherTriple] + + +def direct_pair(*, can_send_request: bool = True) -> DispatcherTriple: + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + + def close() -> None: + client.close() + server.close() + + return client, server, close + + +def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: + """Two `JSONRPCDispatcher`s wired over crossed in-memory streams.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=can_send_request) + + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, transport_builder=builder) + + def close() -> None: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + return client, server, close + + +@pytest.fixture( + params=[ + pytest.param(direct_pair, id="direct"), + pytest.param(jsonrpc_pair, id="jsonrpc"), + ] +) +def pair_factory(request: pytest.FixtureRequest) -> PairFactory: + return request.param + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 784ef6698..bdadd4cda 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -1,8 +1,9 @@ -"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. +"""Behavioral tests for the Dispatcher Protocol. -These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using -the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in -``test_jsonrpc_dispatcher.py``. +The contract tests are parametrized over every `Dispatcher` implementation via +the `pair_factory` fixture (see ``conftest.py``); they must pass for both +`DirectDispatcher` and `JSONRPCDispatcher`. Implementation-specific tests pass +a concrete factory directly. """ from collections.abc import AsyncIterator, Mapping @@ -14,10 +15,12 @@ from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound -from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT +from .conftest import PairFactory, direct_pair + class Recorder: def __init__(self) -> None: @@ -44,31 +47,34 @@ async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: @asynccontextmanager async def running_pair( + factory: PairFactory, *, server_on_request: OnRequest | None = None, server_on_notify: OnNotify | None = None, client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, can_send_request: bool = True, -) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: +) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" - client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client, server, close = factory(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() c_req, c_notify = echo_handlers(client_rec) s_req, s_notify = echo_handlers(server_rec) - async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) - try: - yield client, server, client_rec, server_rec - finally: - client.close() - server.close() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, client_on_request or c_req, client_on_notify or c_notify) + await tg.start(server.run, server_on_request or s_req, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + tg.cancel_scope.cancel() + finally: + close() @pytest.mark.anyio -async def test_send_raw_request_returns_result_from_peer_on_request(): - async with running_pair() as (client, _server, _crec, srec): +async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} @@ -76,13 +82,13 @@ async def test_send_raw_request_returns_result_from_peer_on_request(): @pytest.mark.anyio -async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise MCPError(code=INVALID_PARAMS, message="bad cursor") - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS @@ -90,36 +96,22 @@ async def on_request( @pytest.mark.anyio -async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): - async def on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - raise ValueError("oops") - - async with running_pair(server_on_request=on_request) as (client, *_): - with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_raw_request("tools/list", {}) - assert exc.value.error.code == INTERNAL_ERROR - assert isinstance(exc.value.__cause__, ValueError) - - -@pytest.mark.anyio -async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(pair_factory: PairFactory): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() raise NotImplementedError - async with running_pair(server_on_request=on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @pytest.mark.anyio -async def test_notify_invokes_peer_on_notify(): - async with running_pair() as (client, _server, _crec, srec): +async def test_notify_invokes_peer_on_notify(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): await client.notify("notifications/initialized", {"v": 1}) await srec.notified.wait() @@ -127,7 +119,7 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio -async def test_ctx_send_raw_request_round_trips_to_calling_side(): +async def test_ctx_send_raw_request_round_trips_to_calling_side(pair_factory: PairFactory): """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" async def server_on_request( @@ -136,7 +128,7 @@ async def server_on_request( sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): result = await client.send_raw_request("tools/call", None) assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] @@ -144,28 +136,27 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: return await ctx.send_raw_request("sampling/createMessage", None) - async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): - with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + async with running_pair(pair_factory, server_on_request=server_on_request, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("tools/call", None) - assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @pytest.mark.anyio -async def test_ctx_notify_invokes_calling_side_on_notify(): +async def test_ctx_notify_invokes_calling_side_on_notify(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.notify("notifications/message", {"level": "info"}) return {} - async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) await crec.notified.wait() @@ -173,7 +164,7 @@ async def server_on_request( @pytest.mark.anyio -async def test_ctx_progress_invokes_caller_on_progress_callback(): +async def test_ctx_progress_invokes_caller_on_progress_callback(pair_factory: PairFactory): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -185,14 +176,44 @@ async def server_on_request( async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with running_pair(server_on_request=server_on_request) as (client, *_): + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): + """DirectDispatcher-specific: the original exception is chained via __cause__.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_direct_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() s_req, s_notify = echo_handlers(Recorder()) c_req, c_notify = echo_handlers(Recorder()) @@ -212,21 +233,7 @@ async def late_start(): @pytest.mark.anyio -async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): - async def server_on_request( - ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None - ) -> dict[str, Any]: - await ctx.progress(0.5) - return {"ok": True} - - async with running_pair(server_on_request=server_on_request) as (client, *_): - with anyio.fail_after(5): - result = await client.send_raw_request("tools/call", None) - assert result == {"ok": True} - - -@pytest.mark.anyio -async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_direct_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): await d.send_raw_request("ping", None) @@ -235,7 +242,7 @@ async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_conne @pytest.mark.anyio -async def test_close_makes_run_return(): +async def test_direct_close_makes_run_return(): client, server = create_direct_dispatcher_pair() on_request, on_notify = echo_handlers(Recorder()) with anyio.fail_after(5): diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py new file mode 100644 index 000000000..7f9f11718 --- /dev/null +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -0,0 +1,531 @@ +"""JSON-RPC-specific Dispatcher tests. + +Behaviors with no `DirectDispatcher` analog: request-id correlation, the +exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. +The contract tests shared with `DirectDispatcher` live in +``test_dispatcher.py``. +""" + +import contextvars +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] + JSONRPCDispatcher, + _outbound_metadata, + _Pending, +) +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + ErrorData, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + Tool, +) + +from .conftest import jsonrpc_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): + release_first = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await release_first.wait() + return {"m": method} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + results: dict[str, dict[str, Any]] = {} + + async def call(method: str) -> None: + results[method] = await client.send_raw_request(method, None) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(call, "first") + await anyio.sleep(0) + tg.start_soon(call, "second") + await anyio.sleep(0) + # second resolves while first is still parked + assert "first" not in results + release_first.set() + assert results == {"first": {"m": "first"}, "second": {"m": "second"}} + + +@pytest.mark.anyio +async def test_handler_raising_exception_sends_internal_error_with_str_message(): + """Per design: INTERNAL_ERROR carries str(e), not a scrubbed message.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("kaboom") + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "kaboom" + assert exc.value.__cause__ is None # cause does not survive the wire + + +@pytest.mark.anyio +async def test_peer_cancel_interrupt_mode_sets_cancel_requested_and_sends_no_response(): + handler_started = anyio.Event() + handler_exited = anyio.Event() + seen_ctx: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen_ctx.append(ctx) + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call_then_record() -> None: + with pytest.raises(MCPError): # we'll cancel via tg below + await client.send_raw_request("slow", None) + + tg.start_soon(call_then_record) + await handler_started.wait() + # cancel just the handler (peer-cancel), not our caller + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + # Handler torn down, no response was written; caller is still parked. + # Cancel the caller's task to end the test. + tg.cancel_scope.cancel() + assert seen_ctx[0].cancel_requested.is_set() + + +@pytest.mark.anyio +async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): + handler_started = anyio.Event() + cancel_seen = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + cancel_seen.set() + return {"finished": True} + + def factory(*, can_send_request: bool = True): + client, server, close = jsonrpc_pair(can_send_request=can_send_request) + # Reach in to set signal mode on the server side. + assert isinstance(server, JSONRPCDispatcher) + server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] + return client, server, close + + result_box: list[dict[str, Any]] = [] + async with running_pair(factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result_box.append(await client.send_raw_request("slow", None)) + + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await cancel_seen.wait() + assert result_box == [{"finished": True}] + + +@pytest.mark.anyio +async def test_send_raw_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_raw_request is woken with CONNECTION_CLOSED when run() exits.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + + tg.start_soon(caller) + await anyio.sleep(0) + # No server: simulate the peer dropping by closing the read side. + s2c_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_late_response_after_timeout_is_dropped_without_crashing(): + handler_started = anyio.Event() + proceed = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await proceed.wait() + return {"late": True} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await client.send_raw_request("slow", None, {"timeout": 0}) + # The server handler is still running; let it finish and write a + # response for an id the client has already discarded. + await handler_started.wait() + proceed.set() + # One more round-trip proves the dispatcher is still healthy. + assert await client.send_raw_request("ping", None) == {"late": True} + + +@pytest.mark.anyio +async def test_raise_handler_exceptions_true_propagates_out_of_run(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_rid: object, _meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, transport_builder=builder, raise_handler_exceptions=True + ) + + async def boom(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("propagate me") + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + with pytest.raises(BaseException) as exc: + async with anyio.create_task_group() as tg: + await tg.start(server.run, boom, on_notify) + # Inject a request directly onto the server's read stream. + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) + ) + assert exc.group_contains(RuntimeError, match="propagate me") + # The error response was still written before re-raising. + sent = s2c_recv.receive_nowait() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCError) + assert sent.message.error.code == INTERNAL_ERROR + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_tags_outbound_with_server_message_metadata(): + """Server-to-client requests carry related_request_id for SHTTP routing.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + # Kick the server with an inbound request id=7. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert isinstance(outbound.metadata, ServerMessageMetadata) + assert outbound.metadata.related_request_id == 7 + # Reply so the handler completes cleanly. + await c2s_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) + ) + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.id == 7 + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.25) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None, {"on_progress": on_progress}) + assert received == [(0.25, None, None)] + + +@pytest.mark.anyio +async def test_handler_raising_validation_error_sends_invalid_params(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("t", None) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_send_raw_request_before_run_raises_runtimeerror(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + try: + with pytest.raises(RuntimeError, match="before run"): + await d.send_raw_request("ping", None) + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_transport_exception_in_read_stream_is_logged_and_dropped(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + # Dispatcher must remain healthy after the dropped exception. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/progress", {"progressToken": 999, "progress": 0.5}) + await srec.notified.wait() + assert srec.notifications == [("notifications/progress", {"progressToken": 999, "progress": 0.5})] + + +@pytest.mark.anyio +async def test_cancelled_notification_for_unknown_request_id_is_noop(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/cancelled", {"requestId": 999}) + # No effect; dispatcher remains healthy. + assert await client.send_raw_request("t", None) == {"echoed": "t", "params": {}} + assert srec.notifications == [] # cancelled is fully consumed, never teed + + +_probe: contextvars.ContextVar[str] = contextvars.ContextVar("probe", default="unset") + + +@pytest.mark.anyio +async def test_handler_inherits_sender_contextvars_via_spawn(): + """The handler task sees contextvars set by the task that wrote into the read stream.""" + raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) + read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) + write_send = ContextSendStream[SessionMessage | Exception](raw_send) + out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send) + + seen: list[str] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(_probe.get()) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + + async def sender() -> None: + _probe.set("from-sender") + await write_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + + tg.start_soon(sender) + with anyio.fail_after(5): + resp = await out_recv.receive() + assert isinstance(resp, SessionMessage) + tg.cancel_scope.cancel() + finally: + for s in (raw_send, raw_recv, out_send, out_recv): + s.close() + assert seen == ["from-sender"] + + +@pytest.mark.anyio +async def test_response_write_after_peer_drop_is_swallowed(): + """Handler completes after the write stream is closed; the dropped write doesn't crash run().""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + proceed = anyio.Event() + handlers_done = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await proceed.wait() + if method == "raise": + handlers_done.set() + raise MCPError(code=INTERNAL_ERROR, message="x") + return {"ok": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ok", params=None))) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="raise", params=None)) + ) + await anyio.sleep(0) + # Peer drops: close the receive end so the server's writes hit BrokenResourceError. + s2c_recv.close() + proceed.set() + with anyio.fail_after(5): + await handlers_done.wait() + # run() must still be healthy — close the read side to let it exit cleanly. + c2s_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): + """Courtesy-cancel write hits a closed stream; the error is swallowed and cancellation propagates.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + caller_done = anyio.Event() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + caller_scope = anyio.CancelScope() + + async def caller() -> None: + with caller_scope: + await client.send_raw_request("slow", None) + caller_done.set() + + tg.start_soon(caller) + # Deterministic proof the request write completed: pull it off the wire. + with anyio.fail_after(5): + sent = await c2s_recv.receive() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCRequest) + # Now safe: close the client's write end, then cancel the caller. The + # shielded `_cancel_outbound` write hits ClosedResourceError and is + # swallowed; cancellation propagates cleanly. + c2s_send.close() + caller_scope.cancel() + with anyio.fail_after(5): + await caller_done.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): + """White-box: a response for an id still in _pending but whose waiter has gone.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + recv.close() # waiter gone — send_nowait will raise BrokenResourceError + d._resolve_pending(1, {"late": True}) # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send): + s.close() + + +def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): + """White-box: the buffer=1 invariant — WouldBlock means waiter already has an outcome.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + # Register a fake pending and pre-fill its single buffer slot. + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + send.send_nowait({"real": "result"}) + d._fan_out_closed() # pyright: ignore[reportPrivateUsage] + # The real result is still there; the close signal was dropped. + assert recv.receive_nowait() == {"real": "result"} + assert d._pending == {} # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send, recv): + s.close() + + +def test_outbound_metadata_with_resumption_token_returns_client_metadata(): + md = _outbound_metadata(None, {"resumption_token": "abc"}) + assert isinstance(md, ClientMessageMetadata) + assert md.resumption_token == "abc" + assert _outbound_metadata(None, None) is None + assert _outbound_metadata(None, {}) is None + + +@pytest.mark.anyio +async def test_jsonrpc_error_response_with_null_id_is_dropped(): + """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + await s2c_send.send( + SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) + ) + await anyio.sleep(0) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close()