Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/mcp/server/_typed_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Shape-2 typed ``send_request`` for server-to-client requests.

`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over
the host's raw `Outbound.send_raw_request`. Spec server-to-client request types
have their result type inferred via per-type overloads; custom requests pass
``result_type=`` explicitly.

A `HasResult[R]` protocol (one generic signature, mapping declared on the
request type) is the cleaner long-term shape — see FOLLOWUPS.md. This per-spec
overload set is used for now to avoid touching `mcp.types`.
"""

from typing import Any, TypeVar, overload

from pydantic import BaseModel

from mcp.shared.dispatcher import CallOptions, Outbound
from mcp.shared.peer import dump_params
from mcp.types import (
CreateMessageRequest,
CreateMessageResult,
ElicitRequest,
ElicitResult,
EmptyResult,
ListRootsRequest,
ListRootsResult,
PingRequest,
Request,
)

__all__ = ["TypedServerRequestMixin"]

ResultT = TypeVar("ResultT", bound=BaseModel)

_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = {
CreateMessageRequest: CreateMessageResult,
ElicitRequest: ElicitResult,
ListRootsRequest: ListRootsResult,
PingRequest: EmptyResult,
}


class TypedServerRequestMixin:
"""Typed ``send_request`` for the server-to-client request set.

Mixed into `Connection` and the server `Context`. Each method constrains
``self`` to `Outbound` so any host with ``send_raw_request`` works.
"""

@overload
async def send_request(
self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None
) -> CreateMessageResult: ...
@overload
async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ...
@overload
async def send_request(
self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None
) -> ListRootsResult: ...
@overload
async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ...
@overload
async def send_request(
self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None
) -> ResultT: ...
async def send_request(
self: Outbound,
req: Request[Any, Any],
*,
result_type: type[BaseModel] | None = None,
opts: CallOptions | None = None,
) -> BaseModel:
"""Send a typed server-to-client request and return its typed result.

For spec request types the result type is inferred. For custom requests
pass ``result_type=`` explicitly.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: No back-channel for server-initiated requests.
KeyError: ``result_type`` omitted for a non-spec request type.
"""
raw = await self.send_raw_request(req.method, dump_params(req.params), opts)
cls = result_type if result_type is not None else _RESULT_FOR[type(req)]
return cls.model_validate(raw)
146 changes: 146 additions & 0 deletions src/mcp/server/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""`Connection` — per-client connection state and the standalone outbound channel.

Always present on `Context` (never ``None``), even in stateless deployments.
Holds peer info populated at ``initialize`` time, the per-connection lifespan
output, and an `Outbound` for the standalone stream (the SSE GET stream in
streamable HTTP, or the single duplex stream in stdio).

`notify` is best-effort: it never raises. If there's no standalone channel
(stateless HTTP) or the stream has been dropped, the notification is
debug-logged and silently discarded — server-initiated notifications are
inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when
there's no channel; `ping` is the only spec-sanctioned standalone request.
"""

import logging
from collections.abc import Mapping
from typing import Any

import anyio

from mcp.server._typed_request import TypedServerRequestMixin
from mcp.shared.dispatcher import CallOptions, Outbound
from mcp.shared.exceptions import NoBackChannelError
from mcp.shared.peer import Meta, dump_params
from mcp.types import ClientCapabilities, Implementation, LoggingLevel

__all__ = ["Connection"]

logger = logging.getLogger(__name__)


def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None:
if not meta:
return payload
out = dict(payload or {})
out["_meta"] = meta
return out


class Connection(TypedServerRequestMixin):
"""Per-client connection state and standalone-stream `Outbound`.

Constructed by `ServerRunner` once per connection. The peer-info fields are
``None`` until ``initialize`` completes; ``initialized`` is set then.
"""

def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None:
self._outbound = outbound
self.has_standalone_channel = has_standalone_channel

self.client_info: Implementation | None = None
self.client_capabilities: ClientCapabilities | None = None
self.protocol_version: str | None = None
self.initialized: anyio.Event = anyio.Event()
# TODO: make this generic (Connection[StateT]) once connection_lifespan
# wiring lands in ServerRunner — see FOLLOWUPS.md.
self.state: Any = None

async def send_raw_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
"""Send a raw request on the standalone stream.

Low-level `Outbound` channel. Prefer the typed ``send_request`` (from
`TypedServerRequestMixin`) or the convenience methods below; use this
directly only for off-spec messages.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``has_standalone_channel`` is ``False``.
"""
if not self.has_standalone_channel:
raise NoBackChannelError(method)
return await self._outbound.send_raw_request(method, params, opts)

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
"""Send a best-effort notification on the standalone stream.

Never raises. If there's no standalone channel or the stream is broken,
the notification is dropped and debug-logged.
"""
if not self.has_standalone_channel:
logger.debug("dropped %s: no standalone channel", method)
return
try:
await self._outbound.notify(method, params)
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
logger.debug("dropped %s: standalone stream closed", method)

async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None:
"""Send a ``ping`` request on the standalone stream.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``has_standalone_channel`` is ``False``.
"""
await self.send_raw_request("ping", dump_params(None, meta), opts)

async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
"""Send a ``notifications/message`` log entry on the standalone stream. Best-effort."""
params: dict[str, Any] = {"level": level, "data": data}
if logger is not None:
params["logger"] = logger
await self.notify("notifications/message", _notification_params(params, meta))

async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/tools/list_changed", _notification_params(None, meta))

async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/prompts/list_changed", _notification_params(None, meta))

async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None:
await self.notify("notifications/resources/list_changed", _notification_params(None, meta))

async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None:
await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta))

def check_capability(self, capability: ClientCapabilities) -> bool:
"""Return whether the connected client declared the given capability.

Returns ``False`` if ``initialize`` hasn't completed yet.
"""
# TODO: redesign — mirrors v1 ServerSession.check_client_capability
# verbatim for parity. See FOLLOWUPS.md.
if self.client_capabilities is None:
return False
have = self.client_capabilities
if capability.roots is not None:
if have.roots is None:
return False
if capability.roots.list_changed and not have.roots.list_changed:
return False
if capability.sampling is not None and have.sampling is None:
return False
if capability.elicitation is not None and have.elicitation is None:
return False
if capability.experimental is not None:
if have.experimental is None:
return False
for k in capability.experimental:
if k not in have.experimental:
return False
return True
60 changes: 60 additions & 0 deletions src/mcp/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@

from typing_extensions import TypeVar

from mcp.server._typed_request import TypedServerRequestMixin
from mcp.server.connection import Connection
from mcp.server.experimental.request_context import Experimental
from mcp.server.session import ServerSession
from mcp.shared._context import RequestContext
from mcp.shared.context import BaseContext
from mcp.shared.dispatcher import DispatchContext
from mcp.shared.message import CloseSSEStreamCallback
from mcp.shared.peer import Meta, PeerMixin
from mcp.shared.transport_context import TransportContext
from mcp.types import LoggingLevel, RequestParamsMeta

LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
RequestT = TypeVar("RequestT", default=Any)
Expand All @@ -21,3 +28,56 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex
request: RequestT | None = None
close_sse_stream: CloseSSEStreamCallback | None = None
close_standalone_sse_stream: CloseSSEStreamCallback | None = None


LifespanT = TypeVar("LifespanT", default=Any)
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext)


class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]):
"""Server-side per-request context.

Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`),
`PeerMixin` (kwarg-style ``sample``/``elicit_*``/``list_roots``/``ping``),
and `TypedServerRequestMixin` (typed ``send_request(req) -> Result``). Adds
``lifespan`` and ``connection``.

Constructed by `ServerRunner` (PR4) per inbound request and handed to the
user's handler.
"""

def __init__(
self,
dctx: DispatchContext[TransportT],
*,
lifespan: LifespanT,
connection: Connection,
meta: RequestParamsMeta | None = None,
) -> None:
super().__init__(dctx, meta=meta)
self._lifespan = lifespan
self._connection = connection

@property
def lifespan(self) -> LifespanT:
"""The server-wide lifespan output (what `Server(..., lifespan=...)` yielded)."""
return self._lifespan

@property
def connection(self) -> Connection:
"""The per-client `Connection` for this request's connection."""
return self._connection

async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
"""Send a request-scoped ``notifications/message`` log entry.

Uses this request's back-channel (so the entry rides the request's SSE
stream in streamable HTTP), not the standalone stream — use
``ctx.connection.log(...)`` for that.
"""
params: dict[str, Any] = {"level": level, "data": data}
if logger is not None:
params["logger"] = logger
if meta:
params["_meta"] = meta
await self.notify("notifications/message", params)
82 changes: 82 additions & 0 deletions src/mcp/shared/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""`BaseContext` — the user-facing per-request context.

Composition over a `DispatchContext`: forwards the transport metadata, the
back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel
event. Adds `meta` (the inbound request's `_meta` field).

Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context`
mixes that in directly). Shared between client and server: the server's
`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an
alias.
"""

from collections.abc import Mapping
from typing import Any, Generic

import anyio
from typing_extensions import TypeVar

from mcp.shared.dispatcher import CallOptions, DispatchContext
from mcp.shared.transport_context import TransportContext
from mcp.types import RequestParamsMeta

__all__ = ["BaseContext"]

TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext)


class BaseContext(Generic[TransportT]):
"""Per-request context wrapping a `DispatchContext`.

`ServerRunner` (PR4) constructs one per inbound request and passes it to
the user's handler.
"""

def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None:
self._dctx = dctx
self._meta = meta

@property
def transport(self) -> TransportT:
"""Transport-specific metadata for this inbound request."""
return self._dctx.transport

@property
def cancel_requested(self) -> anyio.Event:
"""Set when the peer sends ``notifications/cancelled`` for this request."""
return self._dctx.cancel_requested

@property
def can_send_request(self) -> bool:
"""Whether the back-channel can deliver server-initiated requests."""
return self._dctx.transport.can_send_request

@property
def meta(self) -> RequestParamsMeta | None:
"""The inbound request's ``_meta`` field, if present."""
return self._meta

async def send_raw_request(
self,
method: str,
params: Mapping[str, Any] | None,
opts: CallOptions | None = None,
) -> dict[str, Any]:
"""Send a request to the peer on the back-channel.

Raises:
MCPError: The peer responded with an error.
NoBackChannelError: ``can_send_request`` is ``False``.
"""
return await self._dctx.send_raw_request(method, params, opts)

async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
"""Send a notification to the peer on the back-channel."""
await self._dctx.notify(method, params)

async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
"""Report progress for this request, if the peer supplied a progress token.

A no-op when no token was supplied.
"""
await self._dctx.progress(progress, total, message)
Loading
Loading