Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions crates/rmcp/src/handler/server/router/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,24 @@ use crate::{
tool::{CallToolHandler, DynCallToolHandler, ToolCallContext},
tool_name_validation::validate_and_warn_tool_name,
},
model::{CallToolResult, Tool, ToolAnnotations},
model::{CallToolResult, Content, ErrorCode, Tool, ToolAnnotations},
service::{MaybeBoxFuture, MaybeSend},
};

const TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX: &str = "failed to deserialize parameters:";

fn into_tool_argument_error(error: crate::ErrorData) -> Result<CallToolResult, crate::ErrorData> {
if error.code == ErrorCode::INVALID_PARAMS
&& error
.message
.starts_with(TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX)
{
return Ok(CallToolResult::error(vec![Content::text(error.message)]));
}

Err(error)
}

#[non_exhaustive]
pub struct ToolRoute<S> {
#[allow(clippy::type_complexity)]
Expand Down Expand Up @@ -555,7 +569,10 @@ where
.get(name)
.ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?;

let result = (item.call)(context).await?;
let result = match (item.call)(context).await {
Ok(result) => result,
Err(error) => return into_tool_argument_error(error),
};

Ok(result)
}
Expand Down Expand Up @@ -611,13 +628,65 @@ mod tests {
use super::*;
use crate::{
RoleServer,
handler::server::wrapper::Parameters,
model::{CallToolRequestParams, ErrorCode, NumberOrString},
service::{AtomicU32RequestIdProvider, Peer, RequestContext},
};

struct DummyService;
impl crate::handler::server::ServerHandler for DummyService {}

#[derive(serde::Deserialize, schemars::JsonSchema)]
struct RequiredParams {
project: String,
}

fn requires_params(Parameters(params): Parameters<RequiredParams>) -> String {
params.project
}

#[tokio::test]
async fn test_argument_deserialization_error_returns_tool_error_result() {
let service = DummyService;
let router = ToolRouter::new().with_route(ToolRoute::new(
crate::model::Tool::new(
"requires_params",
"requires params",
Arc::new(Default::default()),
),
requires_params,
));

let id_provider: Arc<dyn crate::service::RequestIdProvider> =
Arc::new(AtomicU32RequestIdProvider::default());
let (peer, _rx) = Peer::<RoleServer>::new(id_provider, None);
let ctx = crate::handler::server::tool::ToolCallContext::new(
&service,
CallToolRequestParams {
meta: None,
name: Cow::Borrowed("requires_params"),
arguments: Some(Default::default()),
task: None,
},
RequestContext::new(NumberOrString::Number(1), peer),
);

let result = router
.call(ctx)
.await
.expect("argument validation should be a tool result");
assert_eq!(result.is_error, Some(true));

let text = result
.content
.first()
.and_then(|content| content.raw.as_text())
.map(|text| text.text.as_str())
.expect("tool error result should include text");
assert!(text.contains("failed to deserialize parameters"));
assert!(text.contains("missing field `project`"));
}

#[tokio::test]
async fn test_call_disabled_tool_returns_error() {
let service = DummyService;
Expand Down