diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index c2bf299c..35bd25a9 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -137,10 +137,24 @@ use crate::{ tool::{CallToolHandler, DynCallToolHandler, ToolCallContext}, tool_name_validation::validate_and_warn_tool_name, }, - model::{CallToolResult, Tool, ToolAnnotations}, + model::{CallToolResult, Content, ErrorCode, Tool, ToolAnnotations}, service::{MaybeBoxFuture, MaybeSend}, }; +const TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX: &str = "failed to deserialize parameters:"; + +fn into_tool_argument_error(error: crate::ErrorData) -> Result { + if error.code == ErrorCode::INVALID_PARAMS + && error + .message + .starts_with(TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX) + { + return Ok(CallToolResult::error(vec![Content::text(error.message)])); + } + + Err(error) +} + #[non_exhaustive] pub struct ToolRoute { #[allow(clippy::type_complexity)] @@ -555,7 +569,10 @@ where .get(name) .ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?; - let result = (item.call)(context).await?; + let result = match (item.call)(context).await { + Ok(result) => result, + Err(error) => return into_tool_argument_error(error), + }; Ok(result) } @@ -611,6 +628,7 @@ mod tests { use super::*; use crate::{ RoleServer, + handler::server::wrapper::Parameters, model::{CallToolRequestParams, ErrorCode, NumberOrString}, service::{AtomicU32RequestIdProvider, Peer, RequestContext}, }; @@ -618,6 +636,57 @@ mod tests { struct DummyService; impl crate::handler::server::ServerHandler for DummyService {} + #[derive(serde::Deserialize, schemars::JsonSchema)] + struct RequiredParams { + project: String, + } + + fn requires_params(Parameters(params): Parameters) -> String { + params.project + } + + #[tokio::test] + async fn test_argument_deserialization_error_returns_tool_error_result() { + let service = DummyService; + let router = ToolRouter::new().with_route(ToolRoute::new( + crate::model::Tool::new( + "requires_params", + "requires params", + Arc::new(Default::default()), + ), + requires_params, + )); + + let id_provider: Arc = + Arc::new(AtomicU32RequestIdProvider::default()); + let (peer, _rx) = Peer::::new(id_provider, None); + let ctx = crate::handler::server::tool::ToolCallContext::new( + &service, + CallToolRequestParams { + meta: None, + name: Cow::Borrowed("requires_params"), + arguments: Some(Default::default()), + task: None, + }, + RequestContext::new(NumberOrString::Number(1), peer), + ); + + let result = router + .call(ctx) + .await + .expect("argument validation should be a tool result"); + assert_eq!(result.is_error, Some(true)); + + let text = result + .content + .first() + .and_then(|content| content.raw.as_text()) + .map(|text| text.text.as_str()) + .expect("tool error result should include text"); + assert!(text.contains("failed to deserialize parameters")); + assert!(text.contains("missing field `project`")); + } + #[tokio::test] async fn test_call_disabled_tool_returns_error() { let service = DummyService;