From f332c596aa327ccf0d548015759c4b1849fef0d6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:56:49 +0000 Subject: [PATCH 1/4] feat: add Dispatcher Protocol and DirectDispatcher Introduces the Dispatcher abstraction that decouples MCP request/response handling from JSON-RPC framing. A Dispatcher exposes call/notify for outbound messages and run(on_call, on_notify) for inbound dispatch, with no knowledge of MCP types or wire encoding. - shared/dispatcher.py: Dispatcher, DispatchContext, RequestSender Protocols; CallOptions, OnCall/OnNotify, ProgressFnT, DispatchMiddleware - shared/transport_context.py: TransportContext base dataclass - shared/direct_dispatcher.py: in-memory Dispatcher impl that wires two peers with no transport; serves as a fast test substrate and second-impl proof - shared/exceptions.py: NoBackChannelError(MCPError) for transports without a server-to-client request channel - types: REQUEST_CANCELLED SDK error code The JSON-RPC implementation and ServerRunner that consume this Protocol land in follow-up PRs. --- src/mcp/shared/direct_dispatcher.py | 173 +++++++++++++++++++ src/mcp/shared/dispatcher.py | 167 ++++++++++++++++++ src/mcp/shared/exceptions.py | 21 ++- src/mcp/shared/transport_context.py | 30 ++++ src/mcp/types/__init__.py | 2 + src/mcp/types/jsonrpc.py | 1 + tests/shared/test_dispatcher.py | 253 ++++++++++++++++++++++++++++ 7 files changed, 646 insertions(+), 1 deletion(-) create mode 100644 src/mcp/shared/direct_dispatcher.py create mode 100644 src/mcp/shared/dispatcher.py create mode 100644 src/mcp/shared/transport_context.py create mode 100644 tests/shared/test_dispatcher.py diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py new file mode 100644 index 000000000..465061942 --- /dev/null +++ b/src/mcp/shared/direct_dispatcher.py @@ -0,0 +1,173 @@ +"""In-memory `Dispatcher` that wires two peers together with no transport. + +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a call +on one side directly invokes the other side's `on_call`. There is no +serialization, no JSON-RPC framing, and no streams. It exists to: + +* prove the `Dispatcher` Protocol is implementable without JSON-RPC +* provide a fast substrate for testing the layers above the dispatcher + (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts +* embed a server in-process when the JSON-RPC overhead is unnecessary + +Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly +to the caller — there is no exception-to-`ErrorData` boundary here. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any + +import anyio + +from mcp.shared.dispatcher import CallOptions, OnCall, OnNotify, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT + +__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] + +DIRECT_TRANSPORT_KIND = "direct" + + +_Call = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] + + +@dataclass +class _DirectDispatchContext: + """`DispatchContext` for an inbound call on a `DirectDispatcher`. + + The back-channel callables target the *originating* side, so a handler's + `send_request` reaches the peer that made the inbound call. + """ + + transport: TransportContext + _back_call: _Call + _back_notify: _Notify + _on_progress: ProgressFnT | None = None + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._back_notify(method, params) + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.transport.can_send_request: + raise NoBackChannelError(method) + return await self._back_call(method, params, opts) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._on_progress is not None: + await self._on_progress(progress, total, message) + + +class DirectDispatcher: + """A `Dispatcher` that calls a peer's handlers directly, in-process. + + Two instances are wired together with `create_direct_dispatcher_pair`; each + holds a reference to the other. `call` on one awaits the peer's `on_call`. + `run` parks until `close` is called. + """ + + def __init__(self, transport_ctx: TransportContext): + self._transport_ctx = transport_ctx + self._peer: DirectDispatcher | None = None + self._on_call: OnCall | None = None + self._on_notify: OnNotify | None = None + self._ready = anyio.Event() + self._closed = anyio.Event() + + def connect_to(self, peer: DirectDispatcher) -> None: + self._peer = peer + + async def call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + return await self._peer._dispatch_call(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + await self._peer._dispatch_notify(method, params) + + async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + self._on_call = on_call + self._on_notify = on_notify + self._ready.set() + await self._closed.wait() + + def close(self) -> None: + self._closed.set() + + def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + assert self._peer is not None + peer = self._peer + return _DirectDispatchContext( + transport=self._transport_ctx, + _back_call=lambda m, p, o: peer._dispatch_call(m, p, o), + _back_notify=lambda m, p: peer._dispatch_notify(m, p), + _on_progress=on_progress, + ) + + async def _dispatch_call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None, + ) -> dict[str, Any]: + await self._ready.wait() + assert self._on_call is not None + opts = opts or {} + dctx = self._make_context(on_progress=opts.get("on_progress")) + try: + with anyio.fail_after(opts.get("timeout")): + try: + return await self._on_call(dctx, method, params) + except MCPError: + raise + except Exception as e: + raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + except TimeoutError: + raise MCPError( + code=REQUEST_TIMEOUT, + message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}", + ) from None + + async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._ready.wait() + assert self._on_notify is not None + dctx = self._make_context() + await self._on_notify(dctx, method, params) + + +def create_direct_dispatcher_pair( + *, + can_send_request: bool = True, +) -> tuple[DirectDispatcher, DirectDispatcher]: + """Create two `DirectDispatcher` instances wired to each other. + + Args: + can_send_request: Sets `TransportContext.can_send_request` on both + sides. Pass ``False`` to simulate a transport with no back-channel. + + Returns: + A ``(left, right)`` pair. Conventionally ``left`` is the client side + and ``right`` is the server side, but the wiring is symmetric. + """ + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request) + left = DirectDispatcher(ctx) + right = DirectDispatcher(ctx) + left.connect_to(right) + right.connect_to(left) + return left, right diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py new file mode 100644 index 000000000..09e5e87bb --- /dev/null +++ b/src/mcp/shared/dispatcher.py @@ -0,0 +1,167 @@ +"""Dispatcher Protocol — the call/return boundary between transports and handlers. + +A Dispatcher turns a duplex message channel into two things: + +* an outbound API: ``call(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_call, on_notify)`` that drives the receive loop and + invokes the supplied handlers for each incoming request/notification + +It is deliberately *not* MCP-aware. Method names are strings, params and +results are ``dict[str, Any]``. The MCP type layer (request/result models, +capability negotiation, ``Context``) sits above this; the wire encoding +(JSON-RPC, gRPC, in-process direct calls) sits below it. + +See ``JSONRPCDispatcher`` for the production implementation and +``DirectDispatcher`` for an in-memory implementation used in tests and for +embedding a server in-process. +""" + +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable + +import anyio + +from mcp.shared.transport_context import TransportContext + +__all__ = [ + "CallOptions", + "DispatchContext", + "DispatchMiddleware", + "Dispatcher", + "OnCall", + "OnNotify", + "ProgressFnT", + "RequestSender", +] + +TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) + + +class ProgressFnT(Protocol): + """Callback invoked when a progress notification arrives for a pending call.""" + + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + + +class CallOptions(TypedDict, total=False): + """Per-call options for `RequestSender.send_request` / `Dispatcher.call`. + + All keys are optional. Dispatchers ignore keys they do not understand. + """ + + timeout: float + """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" + + on_progress: ProgressFnT + """Receive ``notifications/progress`` updates for this call.""" + + resumption_token: str + """Opaque token to resume a previously interrupted call (transport-dependent).""" + + on_resumption_token: Callable[[str], Awaitable[None]] + """Receive a resumption token when the transport issues one.""" + + +@runtime_checkable +class RequestSender(Protocol): + """Anything that can send a request and await its result. + + Both `Dispatcher` (for top-level outbound calls) and `DispatchContext` + (for server-to-client calls made *during* an inbound request) satisfy this. + """ + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: ... + + +class DispatchContext(Protocol[TransportT_co]): + """Per-request context handed to ``on_call`` / ``on_notify``. + + Carries the transport metadata for the inbound message and provides the + back-channel for sending requests/notifications to the peer while handling + it. + """ + + @property + def transport(self) -> TransportT_co: + """Transport-specific metadata for this inbound message.""" + ... + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends ``notifications/cancelled`` for this request.""" + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer.""" + ... + + async def send_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel and await its result. + + Raises: + NoBackChannelError: if ``transport.can_send_request`` is ``False``. + """ + ... + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the inbound request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + ... + + +OnCall = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +"""Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" + +OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] +"""Handler for inbound notifications: ``(ctx, method, params)``.""" + +DispatchMiddleware = Callable[[OnCall], OnCall] +"""Wraps an ``OnCall`` to produce another ``OnCall``. Applied outermost-first.""" + + +class Dispatcher(Protocol[TransportT_co]): + """A duplex request/notification channel with call-return semantics. + + Implementations own correlation of outbound calls to inbound results, the + receive loop, per-request concurrency, and cancellation/progress wiring. + """ + + async def call( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request and await its result. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... + + async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + """Drive the receive loop until the underlying channel closes. + + Each inbound request is dispatched to ``on_call`` in its own task; the + returned dict (or raised ``MCPError``) is sent back as the response. + Inbound notifications go to ``on_notify``. + """ + ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319..e9dd2c843 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -2,7 +2,7 @@ from typing import Any, cast -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError +from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError class MCPError(Exception): @@ -41,6 +41,25 @@ def __str__(self) -> str: return self.message +class NoBackChannelError(MCPError): + """Raised when sending a server-initiated request over a transport that cannot deliver it. + + Stateless HTTP and JSON-response-mode HTTP have no channel for the server to + push requests (sampling, elicitation, roots/list) to the client. This is + raised by `DispatchContext.send_request` when `transport.can_send_request` + is ``False``, and serializes to an ``INVALID_REQUEST`` error response. + """ + + def __init__(self, method: str): + super().__init__( + code=INVALID_REQUEST, + message=( + f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests." + ), + ) + self.method = method + + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode. diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py new file mode 100644 index 000000000..31230fda9 --- /dev/null +++ b/src/mcp/shared/transport_context.py @@ -0,0 +1,30 @@ +"""Transport-specific metadata attached to each inbound message. + +`TransportContext` is the base; each transport defines its own subclass with +whatever fields make sense (HTTP request id, ASGI scope, stdio process handle, +etc.). The dispatcher passes it through opaquely; only the layers above the +dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. +""" + +from dataclasses import dataclass + +__all__ = ["TransportContext"] + + +@dataclass(kw_only=True, frozen=True) +class TransportContext: + """Base transport metadata for an inbound message. + + Subclass per transport and add fields as needed. Instances are immutable. + """ + + kind: str + """Short identifier for the transport (e.g. ``"stdio"``, ``"streamable-http"``).""" + + can_send_request: bool + """Whether the transport can deliver server-initiated requests to the peer. + + ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for + stdio, SSE, and stateful streamable HTTP. When ``False``, + `DispatchContext.send_request` raises `NoBackChannelError`. + """ diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b44230393..ca1c32893 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -192,6 +192,7 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, + REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -401,6 +402,7 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", + "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 84304a37c..14743c33b 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,6 +43,7 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 +REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py new file mode 100644 index 000000000..dd8d40721 --- /dev/null +++ b/tests/shared/test_dispatcher.py @@ -0,0 +1,253 @@ +"""Behavioral tests for the Dispatcher Protocol via DirectDispatcher. + +These exercise the `Dispatcher` / `DispatchContext` contract end-to-end using +the in-memory `DirectDispatcher`. JSON-RPC framing is covered separately in +``test_jsonrpc_dispatcher.py``. +""" + +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnCall, OnNotify +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT + + +class Recorder: + def __init__(self) -> None: + self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.contexts: list[DispatchContext[TransportContext]] = [] + self.notified = anyio.Event() + + +def echo_handlers(recorder: Recorder) -> tuple[OnCall, OnNotify]: + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + recorder.calls.append((method, params)) + recorder.contexts.append(ctx) + return {"echoed": method, "params": dict(params or {})} + + async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + recorder.notifications.append((method, params)) + recorder.notified.set() + + return on_call, on_notify + + +@asynccontextmanager +async def running_pair( + *, + server_on_call: OnCall | None = None, + server_on_notify: OnNotify | None = None, + client_on_call: OnCall | None = None, + client_on_notify: OnNotify | None = None, + can_send_request: bool = True, +) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: + """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + client_rec, server_rec = Recorder(), Recorder() + c_call, c_notify = echo_handlers(client_rec) + s_call, s_notify = echo_handlers(server_rec) + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, client_on_call or c_call, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_call or s_call, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + client.close() + server.close() + + +@pytest.mark.anyio +async def test_call_returns_result_from_peer_on_call(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + result = await client.call("tools/list", {"cursor": "abc"}) + assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} + assert srec.calls == [("tools/list", {"cursor": "abc"})] + + +@pytest.mark.anyio +async def test_call_reraises_mcperror_from_handler_unchanged(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise MCPError(code=INVALID_PARAMS, message="bad cursor") + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("tools/list", {}) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" + + +@pytest.mark.anyio +async def test_call_wraps_non_mcperror_exception_as_internal_error(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_call_with_timeout_raises_mcperror_request_timeout(): + async def on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await anyio.sleep_forever() + return {} + + async with running_pair(server_on_call=on_call) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.call("slow", None, {"timeout": 0}) + assert exc.value.error.code == REQUEST_TIMEOUT + + +@pytest.mark.anyio +async def test_notify_invokes_peer_on_notify(): + async with running_pair() as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/initialized", {"v": 1}) + await srec.notified.wait() + assert srec.notifications == [("notifications/initialized", {"v": 1})] + + +@pytest.mark.anyio +async def test_ctx_send_request_round_trips_to_calling_side(): + """A handler's ctx.send_request reaches the side that made the inbound call.""" + + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + return {"sampled": sample} + + async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + with anyio.fail_after(5): + result = await client.call("tools/call", None) + assert crec.calls == [("sampling/createMessage", {"prompt": "hi"})] + assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} + + +@pytest.mark.anyio +async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.send_request("sampling/createMessage", None) + return {} + + async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: + await client.call("tools/call", None) + assert exc.value.method == "sampling/createMessage" + assert exc.value.error.code == INVALID_REQUEST + + +@pytest.mark.anyio +async def test_ctx_notify_invokes_calling_side_on_notify(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.notify("notifications/message", {"level": "info"}) + return {} + + async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.call("tools/call", None) + await crec.notified.wait() + assert crec.notifications == [("notifications/message", {"level": "info"})] + + +@pytest.mark.anyio +async def test_ctx_progress_invokes_caller_on_progress_callback(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5, total=1.0, message="halfway") + return {} + + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with running_pair(server_on_call=server_on_call) as (client, *_): + with anyio.fail_after(5): + await client.call("tools/call", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_call_issued_before_peer_run_blocks_until_peer_ready(): + client, server = create_direct_dispatcher_pair() + s_call, s_notify = echo_handlers(Recorder()) + c_call, c_notify = echo_handlers(Recorder()) + + async def late_start(): + await anyio.sleep(0) + await server.run(s_call, s_notify) + + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, c_call, c_notify) + tg.start_soon(late_start) + with anyio.fail_after(5): + result = await client.call("ping", None) + assert result == {"echoed": "ping", "params": {}} + client.close() + server.close() + + +@pytest.mark.anyio +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): + async def server_on_call( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(server_on_call=server_on_call) as (client, *_): + with anyio.fail_after(5): + result = await client.call("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): + d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + with pytest.raises(RuntimeError, match="no peer"): + await d.call("ping", None) + with pytest.raises(RuntimeError, match="no peer"): + await d.notify("ping", None) + + +@pytest.mark.anyio +async def test_close_makes_run_return(): + client, server = create_direct_dispatcher_pair() + on_call, on_notify = echo_handlers(Recorder()) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(server.run, on_call, on_notify) + tg.start_soon(client.run, on_call, on_notify) + client.close() + server.close() + + +if TYPE_CHECKING: + _dispatcher_check: Dispatcher[TransportContext] = DirectDispatcher( + TransportContext(kind="direct", can_send_request=True) + ) From 5540d807be15cfde7a794ce7709c9fc7d7bb5a3a Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:03:48 +0000 Subject: [PATCH 2/4] fix: address coverage gaps and stale RequestSender docstring - tests: replace unreachable 'return {}' with 'raise NotImplementedError' (already in coverage exclude_also) and collapse send_request+return into one statement - dispatcher: RequestSender docstring no longer claims Dispatcher satisfies it (Dispatcher exposes call(), not send_request()) --- src/mcp/shared/dispatcher.py | 4 ++-- tests/shared/test_dispatcher.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 09e5e87bb..b63c00c0b 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -66,8 +66,8 @@ class CallOptions(TypedDict, total=False): class RequestSender(Protocol): """Anything that can send a request and await its result. - Both `Dispatcher` (for top-level outbound calls) and `DispatchContext` - (for server-to-client calls made *during* an inbound request) satisfy this. + `DispatchContext` satisfies this; `PeerMixin` (and `Connection`/`Peer`) wrap + a `RequestSender` to provide typed request methods. """ async def send_request( diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index dd8d40721..ddfe1f798 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -109,7 +109,7 @@ async def on_call( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() - return {} + raise NotImplementedError async with running_pair(server_on_call=on_call) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: @@ -148,8 +148,7 @@ async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallo async def server_on_call( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - await ctx.send_request("sampling/createMessage", None) - return {} + return await ctx.send_request("sampling/createMessage", None) async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: From 1da25ec1828925fcd61882f46db40ea7ecdda2fe Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 12:52:58 +0000 Subject: [PATCH 3/4] refactor: rename Dispatcher.call to send_request, replace RequestSender with Outbound MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The design doc's `send_request = call` alias only makes the concrete class satisfy RequestSender, not the abstract Dispatcher Protocol — so any consumer typed against `Dispatcher[TT]` (Connection, ServerRunner) couldn't pass it to something expecting a RequestSender without a cast or hand-written bridge. RequestSender was also half a contract: every implementor (Dispatcher, DispatchContext, Connection, Context) has `notify` too, and PeerMixin needs both for its typed sugar (elicit/sample are requests, log is a notification). Outbound(Protocol) declares both methods; Dispatcher and DispatchContext extend it. PeerMixin will wrap an Outbound. One verb everywhere, no aliases, no extra Protocols. - Dispatcher.call -> send_request - OnCall -> OnRequest, on_call -> on_request - RequestSender -> Outbound (now also declares notify) - Dispatcher(Outbound, Protocol[TT]), DispatchContext(Outbound, Protocol[TT]) --- src/mcp/shared/direct_dispatcher.py | 38 ++++----- src/mcp/shared/dispatcher.py | 100 ++++++++++-------------- tests/shared/test_dispatcher.py | 115 ++++++++++++++-------------- 3 files changed, 115 insertions(+), 138 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 465061942..79b68d054 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -1,7 +1,7 @@ """In-memory `Dispatcher` that wires two peers together with no transport. -`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a call -on one side directly invokes the other side's `on_call`. There is no +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a +request on one side directly invokes the other side's `on_request`. There is no serialization, no JSON-RPC framing, and no streams. It exists to: * prove the `Dispatcher` Protocol is implementable without JSON-RPC @@ -21,7 +21,7 @@ import anyio -from mcp.shared.dispatcher import CallOptions, OnCall, OnNotify, ProgressFnT +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, REQUEST_TIMEOUT @@ -31,20 +31,20 @@ DIRECT_TRANSPORT_KIND = "direct" -_Call = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] _Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] @dataclass class _DirectDispatchContext: - """`DispatchContext` for an inbound call on a `DirectDispatcher`. + """`DispatchContext` for an inbound request on a `DirectDispatcher`. The back-channel callables target the *originating* side, so a handler's - `send_request` reaches the peer that made the inbound call. + `send_request` reaches the peer that made the inbound request. """ transport: TransportContext - _back_call: _Call + _back_request: _Request _back_notify: _Notify _on_progress: ProgressFnT | None = None cancel_requested: anyio.Event = field(default_factory=anyio.Event) @@ -60,7 +60,7 @@ async def send_request( ) -> dict[str, Any]: if not self.transport.can_send_request: raise NoBackChannelError(method) - return await self._back_call(method, params, opts) + return await self._back_request(method, params, opts) async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: if self._on_progress is not None: @@ -71,14 +71,14 @@ class DirectDispatcher: """A `Dispatcher` that calls a peer's handlers directly, in-process. Two instances are wired together with `create_direct_dispatcher_pair`; each - holds a reference to the other. `call` on one awaits the peer's `on_call`. - `run` parks until `close` is called. + holds a reference to the other. `send_request` on one awaits the peer's + `on_request`. `run` parks until `close` is called. """ def __init__(self, transport_ctx: TransportContext): self._transport_ctx = transport_ctx self._peer: DirectDispatcher | None = None - self._on_call: OnCall | None = None + self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None self._ready = anyio.Event() self._closed = anyio.Event() @@ -86,7 +86,7 @@ def __init__(self, transport_ctx: TransportContext): def connect_to(self, peer: DirectDispatcher) -> None: self._peer = peer - async def call( + async def send_request( self, method: str, params: Mapping[str, Any] | None, @@ -94,15 +94,15 @@ async def call( ) -> dict[str, Any]: if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") - return await self._peer._dispatch_call(method, params, opts) + return await self._peer._dispatch_request(method, params, opts) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: if self._peer is None: raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") await self._peer._dispatch_notify(method, params) - async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: - self._on_call = on_call + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: + self._on_request = on_request self._on_notify = on_notify self._ready.set() await self._closed.wait() @@ -115,25 +115,25 @@ def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispat peer = self._peer return _DirectDispatchContext( transport=self._transport_ctx, - _back_call=lambda m, p, o: peer._dispatch_call(m, p, o), + _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), _back_notify=lambda m, p: peer._dispatch_notify(m, p), _on_progress=on_progress, ) - async def _dispatch_call( + async def _dispatch_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None, ) -> dict[str, Any]: await self._ready.wait() - assert self._on_call is not None + assert self._on_request is not None opts = opts or {} dctx = self._make_context(on_progress=opts.get("on_progress")) try: with anyio.fail_after(opts.get("timeout")): try: - return await self._on_call(dctx, method, params) + return await self._on_request(dctx, method, params) except MCPError: raise except Exception as e: diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index b63c00c0b..872fb01ea 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -2,9 +2,9 @@ A Dispatcher turns a duplex message channel into two things: -* an outbound API: ``call(method, params)`` and ``notify(method, params)`` -* an inbound pump: ``run(on_call, on_notify)`` that drives the receive loop and - invokes the supplied handlers for each incoming request/notification +* an outbound API: ``send_request(method, params)`` and ``notify(method, params)`` +* an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop + and invokes the supplied handlers for each incoming request/notification It is deliberately *not* MCP-aware. Method names are strings, params and results are ``dict[str, Any]``. The MCP type layer (request/result models, @@ -28,23 +28,23 @@ "DispatchContext", "DispatchMiddleware", "Dispatcher", - "OnCall", "OnNotify", + "OnRequest", + "Outbound", "ProgressFnT", - "RequestSender", ] TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) class ProgressFnT(Protocol): - """Callback invoked when a progress notification arrives for a pending call.""" + """Callback invoked when a progress notification arrives for a pending request.""" async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... class CallOptions(TypedDict, total=False): - """Per-call options for `RequestSender.send_request` / `Dispatcher.call`. + """Per-call options for `Outbound.send_request`. All keys are optional. Dispatchers ignore keys they do not understand. """ @@ -53,21 +53,22 @@ class CallOptions(TypedDict, total=False): """Seconds to wait for a result before raising and sending ``notifications/cancelled``.""" on_progress: ProgressFnT - """Receive ``notifications/progress`` updates for this call.""" + """Receive ``notifications/progress`` updates for this request.""" resumption_token: str - """Opaque token to resume a previously interrupted call (transport-dependent).""" + """Opaque token to resume a previously interrupted request (transport-dependent).""" on_resumption_token: Callable[[str], Awaitable[None]] """Receive a resumption token when the transport issues one.""" @runtime_checkable -class RequestSender(Protocol): - """Anything that can send a request and await its result. +class Outbound(Protocol): + """Anything that can send requests and notifications to the peer. - `DispatchContext` satisfies this; `PeerMixin` (and `Connection`/`Peer`) wrap - a `RequestSender` to provide typed request methods. + Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel + during an inbound request) extend this. `PeerMixin` wraps an `Outbound` to + provide typed MCP request/notification methods. """ async def send_request( @@ -75,15 +76,28 @@ async def send_request( method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, - ) -> dict[str, Any]: ... + ) -> dict[str, Any]: + """Send a request and await its result. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... -class DispatchContext(Protocol[TransportT_co]): - """Per-request context handed to ``on_call`` / ``on_notify``. + +class DispatchContext(Outbound, Protocol[TransportT_co]): + """Per-request context handed to ``on_request`` / ``on_notify``. Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling - it. + it. `send_request` raises `NoBackChannelError` if + ``transport.can_send_request`` is ``False``. """ @property @@ -96,23 +110,6 @@ def cancel_requested(self) -> anyio.Event: """Set when the peer sends ``notifications/cancelled`` for this request.""" ... - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - """Send a notification to the peer.""" - ... - - async def send_request( - self, - method: str, - params: Mapping[str, Any] | None, - opts: CallOptions | None = None, - ) -> dict[str, Any]: - """Send a request to the peer on the back-channel and await its result. - - Raises: - NoBackChannelError: if ``transport.can_send_request`` is ``False``. - """ - ... - async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: """Report progress for the inbound request, if the peer supplied a progress token. @@ -121,47 +118,28 @@ async def progress(self, progress: float, total: float | None = None, message: s ... -OnCall = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] """Handler for inbound requests: ``(ctx, method, params) -> result``. Raise ``MCPError`` to send an error response.""" OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] """Handler for inbound notifications: ``(ctx, method, params)``.""" -DispatchMiddleware = Callable[[OnCall], OnCall] -"""Wraps an ``OnCall`` to produce another ``OnCall``. Applied outermost-first.""" +DispatchMiddleware = Callable[[OnRequest], OnRequest] +"""Wraps an ``OnRequest`` to produce another ``OnRequest``. Applied outermost-first.""" -class Dispatcher(Protocol[TransportT_co]): +class Dispatcher(Outbound, Protocol[TransportT_co]): """A duplex request/notification channel with call-return semantics. - Implementations own correlation of outbound calls to inbound results, the + Implementations own correlation of outbound requests to inbound results, the receive loop, per-request concurrency, and cancellation/progress wiring. """ - async def call( - self, - method: str, - params: Mapping[str, Any] | None, - opts: CallOptions | None = None, - ) -> dict[str, Any]: - """Send a request and await its result. - - Raises: - MCPError: If the peer responded with an error, or the handler - raised. Implementations normalize all handler exceptions to - `MCPError` so callers see a single exception type. - """ - ... - - async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - """Send a fire-and-forget notification.""" - ... - - async def run(self, on_call: OnCall, on_notify: OnNotify) -> None: + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: """Drive the receive loop until the underlying channel closes. - Each inbound request is dispatched to ``on_call`` in its own task; the - returned dict (or raised ``MCPError``) is sent back as the response. + Each inbound request is dispatched to ``on_request`` in its own task; + the returned dict (or raised ``MCPError``) is sent back as the response. Inbound notifications go to ``on_notify``. """ ... diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index ddfe1f798..44ab622ad 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -13,7 +13,7 @@ import pytest from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair -from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnCall, OnNotify +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.transport_context import TransportContext from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT @@ -21,17 +21,17 @@ class Recorder: def __init__(self) -> None: - self.calls: list[tuple[str, Mapping[str, Any] | None]] = [] + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] self.contexts: list[DispatchContext[TransportContext]] = [] self.notified = anyio.Event() -def echo_handlers(recorder: Recorder) -> tuple[OnCall, OnNotify]: - async def on_call( +def echo_handlers(recorder: Recorder) -> tuple[OnRequest, OnNotify]: + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - recorder.calls.append((method, params)) + recorder.requests.append((method, params)) recorder.contexts.append(ctx) return {"echoed": method, "params": dict(params or {})} @@ -39,26 +39,26 @@ async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: recorder.notifications.append((method, params)) recorder.notified.set() - return on_call, on_notify + return on_request, on_notify @asynccontextmanager async def running_pair( *, - server_on_call: OnCall | None = None, + server_on_request: OnRequest | None = None, server_on_notify: OnNotify | None = None, - client_on_call: OnCall | None = None, + client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, can_send_request: bool = True, ) -> AsyncIterator[tuple[DirectDispatcher, DirectDispatcher, Recorder, Recorder]]: """Yield ``(client, server, client_recorder, server_recorder)`` with both ``run()`` loops live.""" client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) client_rec, server_rec = Recorder(), Recorder() - c_call, c_notify = echo_handlers(client_rec) - s_call, s_notify = echo_handlers(server_rec) + c_req, c_notify = echo_handlers(client_rec) + s_req, s_notify = echo_handlers(server_rec) async with anyio.create_task_group() as tg: - tg.start_soon(client.run, client_on_call or c_call, client_on_notify or c_notify) - tg.start_soon(server.run, server_on_call or s_call, server_on_notify or s_notify) + tg.start_soon(client.run, client_on_request or c_req, client_on_notify or c_notify) + tg.start_soon(server.run, server_on_request or s_req, server_on_notify or s_notify) try: yield client, server, client_rec, server_rec finally: @@ -67,53 +67,53 @@ async def running_pair( @pytest.mark.anyio -async def test_call_returns_result_from_peer_on_call(): +async def test_send_request_returns_result_from_peer_on_request(): async with running_pair() as (client, _server, _crec, srec): with anyio.fail_after(5): - result = await client.call("tools/list", {"cursor": "abc"}) + result = await client.send_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} - assert srec.calls == [("tools/list", {"cursor": "abc"})] + assert srec.requests == [("tools/list", {"cursor": "abc"})] @pytest.mark.anyio -async def test_call_reraises_mcperror_from_handler_unchanged(): - async def on_call( +async def test_send_request_reraises_mcperror_from_handler_unchanged(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise MCPError(code=INVALID_PARAMS, message="bad cursor") - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("tools/list", {}) + await client.send_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS assert exc.value.error.message == "bad cursor" @pytest.mark.anyio -async def test_call_wraps_non_mcperror_exception_as_internal_error(): - async def on_call( +async def test_send_request_wraps_non_mcperror_exception_as_internal_error(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: raise ValueError("oops") - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("tools/list", {}) + await client.send_request("tools/list", {}) assert exc.value.error.code == INTERNAL_ERROR assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio -async def test_call_with_timeout_raises_mcperror_request_timeout(): - async def on_call( +async def test_send_request_with_timeout_raises_mcperror_request_timeout(): + async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await anyio.sleep_forever() raise NotImplementedError - async with running_pair(server_on_call=on_call) as (client, *_): + async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.call("slow", None, {"timeout": 0}) + await client.send_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @@ -128,53 +128,53 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio async def test_ctx_send_request_round_trips_to_calling_side(): - """A handler's ctx.send_request reaches the side that made the inbound call.""" + """A handler's ctx.send_request reaches the side that made the inbound request.""" - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} - async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - result = await client.call("tools/call", None) - assert crec.calls == [("sampling/createMessage", {"prompt": "hi"})] + result = await client.send_request("tools/call", None) + assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} @pytest.mark.anyio async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: return await ctx.send_request("sampling/createMessage", None) - async with running_pair(server_on_call=server_on_call, can_send_request=False) as (client, *_): + async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: - await client.call("tools/call", None) + await client.send_request("tools/call", None) assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @pytest.mark.anyio async def test_ctx_notify_invokes_calling_side_on_notify(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.notify("notifications/message", {"level": "info"}) return {} - async with running_pair(server_on_call=server_on_call) as (client, _server, crec, _srec): + async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - await client.call("tools/call", None) + await client.send_request("tools/call", None) await crec.notified.wait() assert crec.notifications == [("notifications/message", {"level": "info"})] @pytest.mark.anyio async def test_ctx_progress_invokes_caller_on_progress_callback(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.progress(0.5, total=1.0, message="halfway") @@ -185,27 +185,27 @@ async def server_on_call( async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with running_pair(server_on_call=server_on_call) as (client, *_): + async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.call("tools/call", None, {"on_progress": on_progress}) + await client.send_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_call_issued_before_peer_run_blocks_until_peer_ready(): +async def test_send_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() - s_call, s_notify = echo_handlers(Recorder()) - c_call, c_notify = echo_handlers(Recorder()) + s_req, s_notify = echo_handlers(Recorder()) + c_req, c_notify = echo_handlers(Recorder()) async def late_start(): await anyio.sleep(0) - await server.run(s_call, s_notify) + await server.run(s_req, s_notify) async with anyio.create_task_group() as tg: - tg.start_soon(client.run, c_call, c_notify) + tg.start_soon(client.run, c_req, c_notify) tg.start_soon(late_start) with anyio.fail_after(5): - result = await client.call("ping", None) + result = await client.send_request("ping", None) assert result == {"echoed": "ping", "params": {}} client.close() server.close() @@ -213,23 +213,23 @@ async def late_start(): @pytest.mark.anyio async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(): - async def server_on_call( + async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: await ctx.progress(0.5) return {"ok": True} - async with running_pair(server_on_call=server_on_call) as (client, *_): + async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - result = await client.call("tools/call", None) + result = await client.send_request("tools/call", None) assert result == {"ok": True} @pytest.mark.anyio -async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_send_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): - await d.call("ping", None) + await d.send_request("ping", None) with pytest.raises(RuntimeError, match="no peer"): await d.notify("ping", None) @@ -237,16 +237,15 @@ async def test_call_and_notify_raise_runtimeerror_when_no_peer_connected(): @pytest.mark.anyio async def test_close_makes_run_return(): client, server = create_direct_dispatcher_pair() - on_call, on_notify = echo_handlers(Recorder()) + on_request, on_notify = echo_handlers(Recorder()) with anyio.fail_after(5): async with anyio.create_task_group() as tg: - tg.start_soon(server.run, on_call, on_notify) - tg.start_soon(client.run, on_call, on_notify) + tg.start_soon(server.run, on_request, on_notify) + tg.start_soon(client.run, on_request, on_notify) client.close() server.close() if TYPE_CHECKING: - _dispatcher_check: Dispatcher[TransportContext] = DirectDispatcher( - TransportContext(kind="direct", can_send_request=True) - ) + _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + _o: Outbound = _d From bfb5a771278e926c6398a9d3f91456dbd7427b36 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 16 Apr 2026 21:38:32 +0000 Subject: [PATCH 4/4] refactor: rename Outbound.send_request to send_raw_request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The dispatcher-layer raw channel is now `send_raw_request(method, params) -> dict`. This frees the `send_request` name for the typed surface (`send_request(req: Request) -> Result`) that Connection/Context/Client add in later PRs. Mechanical rename across Outbound, Dispatcher, DispatchContext, DirectDispatcher, _DirectDispatchContext, and all tests. `can_send_request` (the transport capability flag) is unchanged — it names the capability, not the method. --- src/mcp/shared/direct_dispatcher.py | 8 +++--- src/mcp/shared/dispatcher.py | 15 +++++----- src/mcp/shared/exceptions.py | 2 +- src/mcp/shared/transport_context.py | 2 +- tests/shared/test_dispatcher.py | 44 ++++++++++++++--------------- 5 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 79b68d054..bb5639a13 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -40,7 +40,7 @@ class _DirectDispatchContext: """`DispatchContext` for an inbound request on a `DirectDispatcher`. The back-channel callables target the *originating* side, so a handler's - `send_request` reaches the peer that made the inbound request. + `send_raw_request` reaches the peer that made the inbound request. """ transport: TransportContext @@ -52,7 +52,7 @@ class _DirectDispatchContext: async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: await self._back_notify(method, params) - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, @@ -71,7 +71,7 @@ class DirectDispatcher: """A `Dispatcher` that calls a peer's handlers directly, in-process. Two instances are wired together with `create_direct_dispatcher_pair`; each - holds a reference to the other. `send_request` on one awaits the peer's + holds a reference to the other. `send_raw_request` on one awaits the peer's `on_request`. `run` parks until `close` is called. """ @@ -86,7 +86,7 @@ def __init__(self, transport_ctx: TransportContext): def connect_to(self, peer: DirectDispatcher) -> None: self._peer = peer - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 872fb01ea..ee02e2389 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -2,7 +2,7 @@ A Dispatcher turns a duplex message channel into two things: -* an outbound API: ``send_request(method, params)`` and ``notify(method, params)`` +* an outbound API: ``send_raw_request(method, params)`` and ``notify(method, params)`` * an inbound pump: ``run(on_request, on_notify)`` that drives the receive loop and invokes the supplied handlers for each incoming request/notification @@ -44,7 +44,7 @@ async def __call__(self, progress: float, total: float | None, message: str | No class CallOptions(TypedDict, total=False): - """Per-call options for `Outbound.send_request`. + """Per-call options for `Outbound.send_raw_request`. All keys are optional. Dispatchers ignore keys they do not understand. """ @@ -67,17 +67,18 @@ class Outbound(Protocol): """Anything that can send requests and notifications to the peer. Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel - during an inbound request) extend this. `PeerMixin` wraps an `Outbound` to - provide typed MCP request/notification methods. + during an inbound request) extend this. The MCP type layer (`PeerMixin`, + `Connection`, `Context`) builds typed ``send_request`` / convenience methods + on top of this raw channel. """ - async def send_request( + async def send_raw_request( self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, ) -> dict[str, Any]: - """Send a request and await its result. + """Send a request and await its raw result dict. Raises: MCPError: If the peer responded with an error, or the handler @@ -96,7 +97,7 @@ class DispatchContext(Outbound, Protocol[TransportT_co]): Carries the transport metadata for the inbound message and provides the back-channel for sending requests/notifications to the peer while handling - it. `send_request` raises `NoBackChannelError` if + it. `send_raw_request` raises `NoBackChannelError` if ``transport.can_send_request`` is ``False``. """ diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index e9dd2c843..b62629b6c 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -46,7 +46,7 @@ class NoBackChannelError(MCPError): Stateless HTTP and JSON-response-mode HTTP have no channel for the server to push requests (sampling, elicitation, roots/list) to the client. This is - raised by `DispatchContext.send_request` when `transport.can_send_request` + raised by `DispatchContext.send_raw_request` when `transport.can_send_request` is ``False``, and serializes to an ``INVALID_REQUEST`` error response. """ diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py index 31230fda9..832cead51 100644 --- a/src/mcp/shared/transport_context.py +++ b/src/mcp/shared/transport_context.py @@ -26,5 +26,5 @@ class TransportContext: ``False`` for stateless HTTP and HTTP with JSON response mode; ``True`` for stdio, SSE, and stateful streamable HTTP. When ``False``, - `DispatchContext.send_request` raises `NoBackChannelError`. + `DispatchContext.send_raw_request` raises `NoBackChannelError`. """ diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 44ab622ad..784ef6698 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -67,16 +67,16 @@ async def running_pair( @pytest.mark.anyio -async def test_send_request_returns_result_from_peer_on_request(): +async def test_send_raw_request_returns_result_from_peer_on_request(): async with running_pair() as (client, _server, _crec, srec): with anyio.fail_after(5): - result = await client.send_request("tools/list", {"cursor": "abc"}) + result = await client.send_raw_request("tools/list", {"cursor": "abc"}) assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} assert srec.requests == [("tools/list", {"cursor": "abc"})] @pytest.mark.anyio -async def test_send_request_reraises_mcperror_from_handler_unchanged(): +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -84,13 +84,13 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", {}) + await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INVALID_PARAMS assert exc.value.error.message == "bad cursor" @pytest.mark.anyio -async def test_send_request_wraps_non_mcperror_exception_as_internal_error(): +async def test_send_raw_request_wraps_non_mcperror_exception_as_internal_error(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -98,13 +98,13 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("tools/list", {}) + await client.send_raw_request("tools/list", {}) assert exc.value.error.code == INTERNAL_ERROR assert isinstance(exc.value.__cause__, ValueError) @pytest.mark.anyio -async def test_send_request_with_timeout_raises_mcperror_request_timeout(): +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(): async def on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: @@ -113,7 +113,7 @@ async def on_request( async with running_pair(server_on_request=on_request) as (client, *_): with anyio.fail_after(5), pytest.raises(MCPError) as exc: - await client.send_request("slow", None, {"timeout": 0}) + await client.send_raw_request("slow", None, {"timeout": 0}) assert exc.value.error.code == REQUEST_TIMEOUT @@ -127,32 +127,32 @@ async def test_notify_invokes_peer_on_notify(): @pytest.mark.anyio -async def test_ctx_send_request_round_trips_to_calling_side(): - """A handler's ctx.send_request reaches the side that made the inbound request.""" +async def test_ctx_send_raw_request_round_trips_to_calling_side(): + """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - sample = await ctx.send_request("sampling/createMessage", {"prompt": "hi"}) + sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) return {"sampled": sample} async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} @pytest.mark.anyio -async def test_ctx_send_request_raises_nobackchannelerror_when_transport_disallows(): +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(): async def server_on_request( ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - return await ctx.send_request("sampling/createMessage", None) + return await ctx.send_raw_request("sampling/createMessage", None) async with running_pair(server_on_request=server_on_request, can_send_request=False) as (client, *_): with anyio.fail_after(5), pytest.raises(NoBackChannelError) as exc: - await client.send_request("tools/call", None) + await client.send_raw_request("tools/call", None) assert exc.value.method == "sampling/createMessage" assert exc.value.error.code == INVALID_REQUEST @@ -167,7 +167,7 @@ async def server_on_request( async with running_pair(server_on_request=server_on_request) as (client, _server, crec, _srec): with anyio.fail_after(5): - await client.send_request("tools/call", None) + await client.send_raw_request("tools/call", None) await crec.notified.wait() assert crec.notifications == [("notifications/message", {"level": "info"})] @@ -187,12 +187,12 @@ async def on_progress(progress: float, total: float | None, message: str | None) async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - await client.send_request("tools/call", None, {"on_progress": on_progress}) + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) assert received == [(0.5, 1.0, "halfway")] @pytest.mark.anyio -async def test_send_request_issued_before_peer_run_blocks_until_peer_ready(): +async def test_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): client, server = create_direct_dispatcher_pair() s_req, s_notify = echo_handlers(Recorder()) c_req, c_notify = echo_handlers(Recorder()) @@ -205,7 +205,7 @@ async def late_start(): tg.start_soon(client.run, c_req, c_notify) tg.start_soon(late_start) with anyio.fail_after(5): - result = await client.send_request("ping", None) + result = await client.send_raw_request("ping", None) assert result == {"echoed": "ping", "params": {}} client.close() server.close() @@ -221,15 +221,15 @@ async def server_on_request( async with running_pair(server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): - result = await client.send_request("tools/call", None) + result = await client.send_raw_request("tools/call", None) assert result == {"ok": True} @pytest.mark.anyio -async def test_send_request_and_notify_raise_runtimeerror_when_no_peer_connected(): +async def test_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) with pytest.raises(RuntimeError, match="no peer"): - await d.send_request("ping", None) + await d.send_raw_request("ping", None) with pytest.raises(RuntimeError, match="no peer"): await d.notify("ping", None)