Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 79 additions & 14 deletions crates/rmcp/src/handler/server/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
})
}

/// Validate that the schema root is `type: "object"` (per MCP spec) and strip top-level
/// `title`/`description` (the wrapper type name and doc, which are noise to the LLM).
fn validate_and_strip(raw: &Arc<JsonObject>, purpose: &str) -> Result<Arc<JsonObject>, String> {
/// Validate that the schema root is `type: "object"` (per MCP spec for inputSchema) and
/// strip top-level `title`/`description` (the wrapper type name and doc, which are noise to the LLM).
fn validate_and_strip_input(raw: &Arc<JsonObject>) -> Result<Arc<JsonObject>, String> {
match raw.get("type") {
Some(serde_json::Value::String(t)) if t == "object" => {
let mut object = raw.as_ref().clone();
Expand All @@ -61,17 +61,27 @@ fn validate_and_strip(raw: &Arc<JsonObject>, purpose: &str) -> Result<Arc<JsonOb
Ok(Arc::new(object))
}
Some(serde_json::Value::String(t)) => Err(format!(
"MCP specification requires tool {purpose} to have root type 'object', but found '{t}'."
)),
None => Err(format!(
"Schema is missing 'type' field. MCP specification requires {purpose} to have root type 'object'."
"MCP specification requires tool inputSchema to have root type 'object', but found '{t}'."
)),
None => Err(
"Schema is missing 'type' field. MCP specification requires inputSchema to have root type 'object'.".to_string()
),
Some(other) => Err(format!(
"Schema 'type' field has unexpected format: {other:?}. Expected \"object\"."
)),
}
}

/// Strip top-level `title`/`description` from a JSON schema for outputSchema.
/// Unlike inputSchema, outputSchema may have any JSON Schema 2020-12 root type
/// (objects, arrays, primitives, compositions) per SEP-2106.
fn validate_and_strip_output(raw: &Arc<JsonObject>) -> Arc<JsonObject> {
let mut object = raw.as_ref().clone();
object.remove("title");
object.remove("description");
Arc::new(object)
}

/// Generate, validate, and strip a JSON schema for inputSchema (must have root type "object";
/// top-level "title" and "description" are removed).
pub fn schema_for_input<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
Expand All @@ -86,7 +96,7 @@ pub fn schema_for_input<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObjec
{
return result.clone();
}
let result = validate_and_strip(&schema_for_type::<T>(), "inputSchema");
let result = validate_and_strip_input(&schema_for_type::<T>());
cache
.write()
.expect("input schema cache lock poisoned")
Expand All @@ -106,7 +116,10 @@ pub fn schema_for_empty_input() -> Arc<JsonObject> {
EMPTY.clone()
}

/// Generate a JSON schema for outputSchema (must have root type "object"; top-level "title" and "description" are removed)
/// Generate and strip a JSON schema for outputSchema.
/// Unlike inputSchema, outputSchema accepts any JSON Schema 2020-12 root type
/// (objects, arrays, primitives, compositions) per SEP-2106.
/// Top-level "title" and "description" are always removed.
pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
thread_local! {
static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
Expand All @@ -122,10 +135,8 @@ pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObje
return result.clone();
}

// Generate, validate, and strip unnecessary top-level fields
let result = validate_and_strip(&schema_for_type::<T>(), "outputSchema");
let result = Ok(validate_and_strip_output(&schema_for_type::<T>()));

// Cache the result (both success and error cases)
cache
.write()
.expect("output schema cache lock poisoned")
Expand Down Expand Up @@ -306,9 +317,63 @@ mod tests {
}

#[test]
fn test_schema_for_output_rejects_primitive() {
fn test_schema_for_output_accepts_primitive() {
let result = schema_for_output::<i32>();
assert!(result.is_err(),);
assert!(result.is_ok());
}

#[test]
fn test_schema_for_output_accepts_array() {
let result = schema_for_output::<Vec<i32>>();
assert!(result.is_ok());
let schema = result.unwrap();
assert_eq!(schema.get("type"), Some(&serde_json::json!("array")));
assert!(schema.contains_key("items"));
}

#[test]
fn test_schema_for_output_strips_title_for_primitive() {
let schema = schema_for_output::<i32>().unwrap();
assert!(!schema.contains_key("title"));
}

#[test]
fn test_schema_for_output_strips_description_for_primitive() {
let schema = schema_for_output::<i32>().unwrap();
assert!(!schema.contains_key("description"));
}

#[test]
fn test_schema_for_output_accepts_composition() {
let result = schema_for_output::<Option<String>>();
assert!(result.is_ok());
let schema = result.unwrap();
let schema_str = serde_json::to_string(&schema).unwrap();
assert!(
schema_str.contains("anyOf") || schema_str.contains("oneOf") || schema_str.contains("null"),
"Expected composition schema for Option<String>, got: {schema_str}"
);
}

#[test]
fn test_schema_for_output_caches_result() {
let result1 = schema_for_output::<i32>();
let result2 = schema_for_output::<i32>();
assert!(result1.is_ok());
assert!(result2.is_ok());
assert!(Arc::ptr_eq(result1.as_ref().unwrap(), result2.as_ref().unwrap()));
}

#[test]
fn test_schema_for_input_rejects_array() {
let result = schema_for_input::<Vec<i32>>();
assert!(result.is_err());
}

#[test]
fn test_schema_for_output_accepts_unit() {
let result = schema_for_output::<()>();
assert!(result.is_ok());
}

#[test]
Expand Down
39 changes: 39 additions & 0 deletions crates/rmcp/src/handler/server/router/tool/tool_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,4 +346,43 @@ mod tests {
assert_eq!(result, ErrorData::invalid_params("invalid params", None));
}
}

struct ArrayTool;
impl ToolBase for ArrayTool {
type Parameter = AddParameter;
type Output = Vec<AddOutput>;
type Error = ErrorData;

fn name() -> Cow<'static, str> {
"array-tool".into()
}
}
impl SyncTool<TraitBasedToolServer> for ArrayTool {
fn invoke(
_service: &TraitBasedToolServer,
_param: Self::Parameter,
) -> Result<Self::Output, Self::Error> {
Ok(vec![])
}
}
impl AsyncTool<TraitBasedToolServer> for ArrayTool {
async fn invoke(
_service: &TraitBasedToolServer,
_param: Self::Parameter,
) -> Result<Self::Output, Self::Error> {
Ok(vec![])
}
}

#[test]
fn test_toolbase_output_schema_with_array_output() {
let schema = ArrayTool::output_schema();
assert!(schema.is_some());
let schema = schema.unwrap();
let schema_value: serde_json::Value = serde_json::from_str(
&serde_json::to_string(&*schema).expect("failed to serialize schema"),
)
.expect("failed to parse schema JSON");
assert_eq!(schema_value["type"], "array");
}
}
2 changes: 1 addition & 1 deletion crates/rmcp/src/model/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ impl Tool {
///
/// # Panics
///
/// Panics if the generated schema does not have root type "object" as required by MCP specification.
/// Panics if output schema generation fails.
#[cfg(feature = "server")]
pub fn with_output_schema<T: JsonSchema + 'static>(mut self) -> Self {
let schema = crate::handler::server::tool::schema_for_output::<T>()
Expand Down
76 changes: 76 additions & 0 deletions crates/rmcp/tests/test_json_schema_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ impl TestServer {
pub async fn explicit_schema(&self) -> Result<String, String> {
Ok("test".to_string())
}

/// Tool that returns Json<Vec<T>> - array output schema
#[tool(name = "with-json-array")]
pub async fn with_json_array(&self) -> Result<Json<Vec<TestData>>, String> {
Ok(Json(vec![TestData {
value: "test".to_string(),
}]))
}

/// Tool that returns Result<Json<Vec<T>>, ErrorData> - array output schema
#[tool(name = "result-with-json-array")]
pub async fn result_with_json_array(&self) -> Result<Json<Vec<TestData>>, rmcp::ErrorData> {
Ok(Json(vec![TestData {
value: "test".to_string(),
}]))
}

/// Tool that returns Json<String> - string output schema
#[tool(name = "with-json-string")]
pub async fn with_json_string(&self) -> Result<Json<String>, String> {
Ok(Json("test".to_string()))
}
}

#[tokio::test]
Expand Down Expand Up @@ -113,3 +135,57 @@ async fn test_explicit_schema_override() {
"Explicit output_schema attribute should work"
);
}

#[tokio::test]
async fn test_json_array_type_generates_schema() {
let server = TestServer::new();
let tools = server.tool_router.list_all();

let array_tool = tools.iter().find(|t| t.name == "with-json-array").unwrap();
assert!(
array_tool.output_schema.is_some(),
"Json<Vec<T>> return type should generate output schema"
);
let schema = array_tool.output_schema.as_ref().unwrap();
assert_eq!(
schema.get("type").and_then(|v| v.as_str()),
Some("array"),
"Json<Vec<T>> should produce an array schema"
);
}

#[tokio::test]
async fn test_result_with_json_array_generates_schema() {
let server = TestServer::new();
let tools = server.tool_router.list_all();

let result_array_tool = tools
.iter()
.find(|t| t.name == "result-with-json-array")
.unwrap();
assert!(
result_array_tool.output_schema.is_some(),
"Result<Json<Vec<T>>, ErrorData> return type should generate output schema"
);
}

#[tokio::test]
async fn test_json_string_type_generates_schema() {
let server = TestServer::new();
let tools = server.tool_router.list_all();

let string_tool = tools
.iter()
.find(|t| t.name == "with-json-string")
.unwrap();
assert!(
string_tool.output_schema.is_some(),
"Json<String> return type should generate output schema"
);
let schema = string_tool.output_schema.as_ref().unwrap();
assert_eq!(
schema.get("type").and_then(|v| v.as_str()),
Some("string"),
"Json<String> should produce a string schema"
);
}
54 changes: 54 additions & 0 deletions crates/rmcp/tests/test_structured_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ impl TestServer {
Err("User not found".to_string())
}
}

/// Tool that returns a list of calculation results
#[tool(name = "calculate-list", description = "Return a list of calculation results")]
pub async fn calculate_list(
&self,
params: Parameters<CalculationRequest>,
) -> Result<Json<Vec<CalculationResult>>, String> {
Ok(Json(vec![CalculationResult {
sum: params.0.a + params.0.b,
product: params.0.a * params.0.b,
}]))
}

/// Tool that returns a count
#[tool(name = "get-count", description = "Return a count")]
pub async fn get_count(&self) -> Result<Json<i32>, String> {
Ok(Json(42))
}
}

#[tokio::test]
Expand Down Expand Up @@ -359,3 +377,39 @@ fn test_call_tool_result_deserialize_without_content() {
assert!(result.content.is_empty());
assert!(result.structured_content.is_some());
}

#[tokio::test]
async fn test_tool_with_array_output_schema() {
let server = TestServer::new();
let tools = server.tool_router.list_all();

// Find the calculate-list tool
let calculate_list_tool = tools.iter().find(|t| t.name == "calculate-list").unwrap();

// Verify it has an output schema
assert!(calculate_list_tool.output_schema.is_some());

let schema = calculate_list_tool.output_schema.as_ref().unwrap();

// Check that the schema contains array type
let schema_str = serde_json::to_string(schema).unwrap();
assert!(schema_str.contains("array"));
}

#[tokio::test]
async fn test_tool_with_primitive_output_schema() {
let server = TestServer::new();
let tools = server.tool_router.list_all();

// Find the get-count tool
let get_count_tool = tools.iter().find(|t| t.name == "get-count").unwrap();

// Verify it has an output schema
assert!(get_count_tool.output_schema.is_some());

let schema = get_count_tool.output_schema.as_ref().unwrap();

// Check that the schema contains integer type
let schema_str = serde_json::to_string(schema).unwrap();
assert!(schema_str.contains("integer"));
}
42 changes: 42 additions & 0 deletions crates/rmcp/tests/test_tool_builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,45 @@ fn test_chained_builder_methods() {
assert!(output_schema_str.contains("greeting"));
assert!(output_schema_str.contains("is_adult"));
}

#[test]
fn test_with_output_schema_primitive() {
let tool = Tool::new("test", "Test tool", JsonObject::new()).with_output_schema::<i32>();

assert!(tool.output_schema.is_some());

let schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap();
assert!(schema_str.contains("\"type\":\"integer\""));
// title should be stripped from output schema
assert!(!schema_str.contains("title"));
}

#[test]
fn test_with_output_schema_array() {
let tool = Tool::new("test", "Test tool", JsonObject::new())
.with_output_schema::<Vec<String>>();

assert!(tool.output_schema.is_some());

let schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap();
assert!(schema_str.contains("\"type\":\"array\""));
assert!(schema_str.contains("items"));
// title should be stripped from output schema
assert!(!schema_str.contains("title"));
}

#[test]
fn test_with_output_schema_option() {
let tool = Tool::new("test", "Test tool", JsonObject::new())
.with_output_schema::<Option<String>>();

assert!(tool.output_schema.is_some());

let schema_str = serde_json::to_string(tool.output_schema.as_ref().unwrap()).unwrap();
// Option<String> generates a composition schema (anyOf/oneOf/type array with null)
assert!(
schema_str.contains("anyOf") || schema_str.contains("oneOf") || schema_str.contains("null"),
"Expected composition schema for Option<String>, got: {schema_str}"
);
assert!(!schema_str.contains("title"));
}
Loading