diff --git a/crates/rmcp/src/handler/server/common.rs b/crates/rmcp/src/handler/server/common.rs index aa996b5b..4e699718 100644 --- a/crates/rmcp/src/handler/server/common.rs +++ b/crates/rmcp/src/handler/server/common.rs @@ -50,9 +50,9 @@ pub fn schema_for_type() -> Arc { }) } -/// Validate that the schema root is `type: "object"` (per MCP spec) and strip top-level -/// `title`/`description` (the wrapper type name and doc, which are noise to the LLM). -fn validate_and_strip(raw: &Arc, purpose: &str) -> Result, String> { +/// Validate that the schema root is `type: "object"` (per MCP spec for inputSchema) and +/// strip top-level `title`/`description` (the wrapper type name and doc, which are noise to the LLM). +fn validate_and_strip_input(raw: &Arc) -> Result, String> { match raw.get("type") { Some(serde_json::Value::String(t)) if t == "object" => { let mut object = raw.as_ref().clone(); @@ -61,17 +61,27 @@ fn validate_and_strip(raw: &Arc, purpose: &str) -> Result Err(format!( - "MCP specification requires tool {purpose} to have root type 'object', but found '{t}'." - )), - None => Err(format!( - "Schema is missing 'type' field. MCP specification requires {purpose} to have root type 'object'." + "MCP specification requires tool inputSchema to have root type 'object', but found '{t}'." )), + None => Err( + "Schema is missing 'type' field. MCP specification requires inputSchema to have root type 'object'.".to_string() + ), Some(other) => Err(format!( "Schema 'type' field has unexpected format: {other:?}. Expected \"object\"." )), } } +/// Strip top-level `title`/`description` from a JSON schema for outputSchema. +/// Unlike inputSchema, outputSchema may have any JSON Schema 2020-12 root type +/// (objects, arrays, primitives, compositions) per SEP-2106. +fn validate_and_strip_output(raw: &Arc) -> Arc { + let mut object = raw.as_ref().clone(); + object.remove("title"); + object.remove("description"); + Arc::new(object) +} + /// Generate, validate, and strip a JSON schema for inputSchema (must have root type "object"; /// top-level "title" and "description" are removed). pub fn schema_for_input() -> Result, String> { @@ -86,7 +96,7 @@ pub fn schema_for_input() -> Result(), "inputSchema"); + let result = validate_and_strip_input(&schema_for_type::()); cache .write() .expect("input schema cache lock poisoned") @@ -106,7 +116,10 @@ pub fn schema_for_empty_input() -> Arc { EMPTY.clone() } -/// Generate a JSON schema for outputSchema (must have root type "object"; top-level "title" and "description" are removed) +/// Generate and strip a JSON schema for outputSchema. +/// Unlike inputSchema, outputSchema accepts any JSON Schema 2020-12 root type +/// (objects, arrays, primitives, compositions) per SEP-2106. +/// Top-level "title" and "description" are always removed. pub fn schema_for_output() -> Result, String> { thread_local! { static CACHE_FOR_OUTPUT: std::sync::RwLock, String>>> = Default::default(); @@ -122,10 +135,8 @@ pub fn schema_for_output() -> Result(), "outputSchema"); + let result = Ok(validate_and_strip_output(&schema_for_type::())); - // Cache the result (both success and error cases) cache .write() .expect("output schema cache lock poisoned") @@ -306,9 +317,63 @@ mod tests { } #[test] - fn test_schema_for_output_rejects_primitive() { + fn test_schema_for_output_accepts_primitive() { let result = schema_for_output::(); - assert!(result.is_err(),); + assert!(result.is_ok()); + } + + #[test] + fn test_schema_for_output_accepts_array() { + let result = schema_for_output::>(); + assert!(result.is_ok()); + let schema = result.unwrap(); + assert_eq!(schema.get("type"), Some(&serde_json::json!("array"))); + assert!(schema.contains_key("items")); + } + + #[test] + fn test_schema_for_output_strips_title_for_primitive() { + let schema = schema_for_output::().unwrap(); + assert!(!schema.contains_key("title")); + } + + #[test] + fn test_schema_for_output_strips_description_for_primitive() { + let schema = schema_for_output::().unwrap(); + assert!(!schema.contains_key("description")); + } + + #[test] + fn test_schema_for_output_accepts_composition() { + let result = schema_for_output::>(); + assert!(result.is_ok()); + let schema = result.unwrap(); + let schema_str = serde_json::to_string(&schema).unwrap(); + assert!( + schema_str.contains("anyOf") || schema_str.contains("oneOf") || schema_str.contains("null"), + "Expected composition schema for Option, got: {schema_str}" + ); + } + + #[test] + fn test_schema_for_output_caches_result() { + let result1 = schema_for_output::(); + let result2 = schema_for_output::(); + assert!(result1.is_ok()); + assert!(result2.is_ok()); + assert!(Arc::ptr_eq(result1.as_ref().unwrap(), result2.as_ref().unwrap())); + } + + #[test] + fn test_schema_for_input_rejects_array() { + let result = schema_for_input::>(); + assert!(result.is_err()); + } + + #[test] + fn test_schema_for_output_accepts_unit() { + let result = schema_for_output::<()>(); + assert!(result.is_ok()); } #[test] diff --git a/crates/rmcp/src/handler/server/router/tool/tool_traits.rs b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs index b0bf9e2d..69e50d0d 100644 --- a/crates/rmcp/src/handler/server/router/tool/tool_traits.rs +++ b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs @@ -346,4 +346,43 @@ mod tests { assert_eq!(result, ErrorData::invalid_params("invalid params", None)); } } + + struct ArrayTool; + impl ToolBase for ArrayTool { + type Parameter = AddParameter; + type Output = Vec; + type Error = ErrorData; + + fn name() -> Cow<'static, str> { + "array-tool".into() + } + } + impl SyncTool for ArrayTool { + fn invoke( + _service: &TraitBasedToolServer, + _param: Self::Parameter, + ) -> Result { + Ok(vec![]) + } + } + impl AsyncTool for ArrayTool { + async fn invoke( + _service: &TraitBasedToolServer, + _param: Self::Parameter, + ) -> Result { + Ok(vec![]) + } + } + + #[test] + fn test_toolbase_output_schema_with_array_output() { + let schema = ArrayTool::output_schema(); + assert!(schema.is_some()); + let schema = schema.unwrap(); + let schema_value: serde_json::Value = serde_json::from_str( + &serde_json::to_string(&*schema).expect("failed to serialize schema"), + ) + .expect("failed to parse schema JSON"); + assert_eq!(schema_value["type"], "array"); + } } diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index 2ed89d6a..847736ec 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -318,7 +318,7 @@ impl Tool { /// /// # Panics /// - /// Panics if the generated schema does not have root type "object" as required by MCP specification. + /// Panics if output schema generation fails. #[cfg(feature = "server")] pub fn with_output_schema(mut self) -> Self { let schema = crate::handler::server::tool::schema_for_output::() diff --git a/crates/rmcp/tests/test_json_schema_detection.rs b/crates/rmcp/tests/test_json_schema_detection.rs index 5d982cd6..9cf73cb8 100644 --- a/crates/rmcp/tests/test_json_schema_detection.rs +++ b/crates/rmcp/tests/test_json_schema_detection.rs @@ -60,6 +60,28 @@ impl TestServer { pub async fn explicit_schema(&self) -> Result { Ok("test".to_string()) } + + /// Tool that returns Json> - array output schema + #[tool(name = "with-json-array")] + pub async fn with_json_array(&self) -> Result>, String> { + Ok(Json(vec![TestData { + value: "test".to_string(), + }])) + } + + /// Tool that returns Result>, ErrorData> - array output schema + #[tool(name = "result-with-json-array")] + pub async fn result_with_json_array(&self) -> Result>, rmcp::ErrorData> { + Ok(Json(vec![TestData { + value: "test".to_string(), + }])) + } + + /// Tool that returns Json - string output schema + #[tool(name = "with-json-string")] + pub async fn with_json_string(&self) -> Result, String> { + Ok(Json("test".to_string())) + } } #[tokio::test] @@ -113,3 +135,57 @@ async fn test_explicit_schema_override() { "Explicit output_schema attribute should work" ); } + +#[tokio::test] +async fn test_json_array_type_generates_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + let array_tool = tools.iter().find(|t| t.name == "with-json-array").unwrap(); + assert!( + array_tool.output_schema.is_some(), + "Json> return type should generate output schema" + ); + let schema = array_tool.output_schema.as_ref().unwrap(); + assert_eq!( + schema.get("type").and_then(|v| v.as_str()), + Some("array"), + "Json> should produce an array schema" + ); +} + +#[tokio::test] +async fn test_result_with_json_array_generates_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + let result_array_tool = tools + .iter() + .find(|t| t.name == "result-with-json-array") + .unwrap(); + assert!( + result_array_tool.output_schema.is_some(), + "Result>, ErrorData> return type should generate output schema" + ); +} + +#[tokio::test] +async fn test_json_string_type_generates_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + let string_tool = tools + .iter() + .find(|t| t.name == "with-json-string") + .unwrap(); + assert!( + string_tool.output_schema.is_some(), + "Json return type should generate output schema" + ); + let schema = string_tool.output_schema.as_ref().unwrap(); + assert_eq!( + schema.get("type").and_then(|v| v.as_str()), + Some("string"), + "Json should produce a string schema" + ); +} diff --git a/crates/rmcp/tests/test_structured_output.rs b/crates/rmcp/tests/test_structured_output.rs index adbdfec5..b1959b64 100644 --- a/crates/rmcp/tests/test_structured_output.rs +++ b/crates/rmcp/tests/test_structured_output.rs @@ -93,6 +93,24 @@ impl TestServer { Err("User not found".to_string()) } } + + /// Tool that returns a list of calculation results + #[tool(name = "calculate-list", description = "Return a list of calculation results")] + pub async fn calculate_list( + &self, + params: Parameters, + ) -> Result>, String> { + Ok(Json(vec![CalculationResult { + sum: params.0.a + params.0.b, + product: params.0.a * params.0.b, + }])) + } + + /// Tool that returns a count + #[tool(name = "get-count", description = "Return a count")] + pub async fn get_count(&self) -> Result, String> { + Ok(Json(42)) + } } #[tokio::test] @@ -359,3 +377,39 @@ fn test_call_tool_result_deserialize_without_content() { assert!(result.content.is_empty()); assert!(result.structured_content.is_some()); } + +#[tokio::test] +async fn test_tool_with_array_output_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the calculate-list tool + let calculate_list_tool = tools.iter().find(|t| t.name == "calculate-list").unwrap(); + + // Verify it has an output schema + assert!(calculate_list_tool.output_schema.is_some()); + + let schema = calculate_list_tool.output_schema.as_ref().unwrap(); + + // Check that the schema contains array type + let schema_str = serde_json::to_string(schema).unwrap(); + assert!(schema_str.contains("array")); +} + +#[tokio::test] +async fn test_tool_with_primitive_output_schema() { + let server = TestServer::new(); + let tools = server.tool_router.list_all(); + + // Find the get-count tool + let get_count_tool = tools.iter().find(|t| t.name == "get-count").unwrap(); + + // Verify it has an output schema + assert!(get_count_tool.output_schema.is_some()); + + let schema = get_count_tool.output_schema.as_ref().unwrap(); + + // Check that the schema contains integer type + let schema_str = serde_json::to_string(schema).unwrap(); + assert!(schema_str.contains("integer")); +} diff --git a/crates/rmcp/tests/test_tool_builder_methods.rs b/crates/rmcp/tests/test_tool_builder_methods.rs index 8be7e5c3..9b17b5eb 100644 --- a/crates/rmcp/tests/test_tool_builder_methods.rs +++ b/crates/rmcp/tests/test_tool_builder_methods.rs @@ -61,3 +61,45 @@ fn test_chained_builder_methods() { assert!(output_schema_str.contains("greeting")); assert!(output_schema_str.contains("is_adult")); } + +#[test] +fn test_with_output_schema_primitive() { + let tool = Tool::new("test", "Test tool", JsonObject::new()).with_output_schema::(); + + assert!(tool.output_schema.is_some()); + + let schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap(); + assert!(schema_str.contains("\"type\":\"integer\"")); + // title should be stripped from output schema + assert!(!schema_str.contains("title")); +} + +#[test] +fn test_with_output_schema_array() { + let tool = Tool::new("test", "Test tool", JsonObject::new()) + .with_output_schema::>(); + + assert!(tool.output_schema.is_some()); + + let schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap(); + assert!(schema_str.contains("\"type\":\"array\"")); + assert!(schema_str.contains("items")); + // title should be stripped from output schema + assert!(!schema_str.contains("title")); +} + +#[test] +fn test_with_output_schema_option() { + let tool = Tool::new("test", "Test tool", JsonObject::new()) + .with_output_schema::>(); + + assert!(tool.output_schema.is_some()); + + let schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap(); + // Option generates a composition schema (anyOf/oneOf/type array with null) + assert!( + schema_str.contains("anyOf") || schema_str.contains("oneOf") || schema_str.contains("null"), + "Expected composition schema for Option, got: {schema_str}" + ); + assert!(!schema_str.contains("title")); +}