From 1d024a7c632830a9741b4d001e8e5820f0046a3f Mon Sep 17 00:00:00 2001 From: larock22 Date: Sun, 12 Apr 2026 19:59:51 -0500 Subject: [PATCH] Add minimal ask_user tool support --- src/looper.rs | 81 +++++++++++++++---------------------- src/looper_stream.rs | 92 ++++++++++++++++--------------------------- src/tools/ask_user.rs | 85 +++++++++++++++++++++++++++++++++++++++ src/tools/handler.rs | 65 ++++++++++++++++++++++++++++++ src/tools/mod.rs | 6 +++ src/types/ask_user.rs | 27 +++++++++++++ src/types/mod.rs | 3 ++ 7 files changed, 252 insertions(+), 107 deletions(-) create mode 100644 src/tools/ask_user.rs create mode 100644 src/tools/handler.rs create mode 100644 src/types/ask_user.rs diff --git a/src/looper.rs b/src/looper.rs index 38acb6b..0541273 100644 --- a/src/looper.rs +++ b/src/looper.rs @@ -13,8 +13,8 @@ use crate::{ openai_responses_non_streaming::OpenAIResponsesNonStreamingHandler, }, }, - tools::{EmptyToolSet, LooperTools, SubAgentTool}, - types::{Handlers, MessageHistory, turn::TurnResult}, + tools::{AskUserTool, CompositeToolSet, LooperTools, SubAgentTool}, + types::{AskUserSender, Handlers, MessageHistory, turn::TurnResult}, }; pub struct Looper { @@ -29,6 +29,7 @@ pub struct LooperBuilder<'a> { tools: Option>, instructions: Option, sub_agent: Option, + ask_user_channel: Option, } impl<'a> LooperBuilder<'a> { @@ -57,8 +58,26 @@ impl<'a> LooperBuilder<'a> { self } + pub fn ask_user_channel(mut self, channel: AskUserSender) -> Self { + self.ask_user_channel = Some(channel); + self + } + pub async fn build(mut self) -> Result { let sub_agent_enabled = self.sub_agent.is_some(); + let mut tool_set = CompositeToolSet::new(self.tools.take()); + + if let Some(sub_agent) = self.sub_agent.take() { + tool_set + .add_tool(Arc::new(SubAgentTool::new(sub_agent))) + .await; + } + + if let Some(channel) = self.ask_user_channel.take() { + tool_set.add_tool(Arc::new(AskUserTool::new(channel))).await; + } + + let tool_definitions = tool_set.get_tools().await; let handler: Box = match self.handler_type { Handlers::Anthropic(m) => { @@ -66,15 +85,7 @@ impl<'a> LooperBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } Handlers::OpenAICompletions(m) => { @@ -82,15 +93,7 @@ impl<'a> LooperBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } Handlers::OpenAIResponses(m) => { @@ -98,15 +101,7 @@ impl<'a> LooperBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } Handlers::Gemini(m) => { @@ -114,31 +109,16 @@ impl<'a> LooperBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } }; - match self.tools { - Some(t) => Ok(Looper { - handler, - message_history: self.message_history, - tools: Arc::from(t), - }), - None => Ok(Looper { - handler, - message_history: self.message_history, - tools: Arc::new(EmptyToolSet), - }), - } + Ok(Looper { + handler, + message_history: self.message_history, + tools: Arc::new(tool_set), + }) } } @@ -150,6 +130,7 @@ impl Looper { tools: None, sub_agent: None, instructions: None, + ask_user_channel: None, } } diff --git a/src/looper_stream.rs b/src/looper_stream.rs index 5c4f2e9..8f01c8e 100644 --- a/src/looper_stream.rs +++ b/src/looper_stream.rs @@ -8,8 +8,10 @@ use crate::{ StreamingChatHandler, anthropic::AnthropicHandler, gemini::GeminiHandler, openai_completions::OpenAIChatHandler, openai_responses::OpenAIResponsesHandler, }, - tools::{EmptyToolSet, LooperTools, SubAgentTool}, - types::{HandlerToLooperMessage, Handlers, LooperToInterfaceMessage, MessageHistory}, + tools::{AskUserTool, CompositeToolSet, LooperTools, SubAgentTool}, + types::{ + AskUserSender, HandlerToLooperMessage, Handlers, LooperToInterfaceMessage, MessageHistory, + }, }; use anyhow::Result; use tera::{Context, Tera}; @@ -29,6 +31,7 @@ pub struct LooperStreamBuilder<'a> { tools: Option>, instructions: Option, sub_agent: Option, + ask_user_channel: Option, buffered_output: bool, } @@ -58,6 +61,11 @@ impl<'a> LooperStreamBuilder<'a> { self } + pub fn ask_user_channel(mut self, channel: AskUserSender) -> Self { + self.ask_user_channel = Some(channel); + self + } + pub fn buffered_output(mut self) -> Self { self.buffered_output = true; self @@ -67,6 +75,19 @@ impl<'a> LooperStreamBuilder<'a> { let sub_agent_enabled = self.sub_agent.is_some(); let (handler_looper_sender, mut handler_looper_receiver) = mpsc::channel(10000); let (looper_ui_sender, looper_ui_receiver) = mpsc::channel(10000); + let mut tool_set = CompositeToolSet::new(self.tools.take()); + + if let Some(sub_agent) = self.sub_agent.take() { + tool_set + .add_tool(Arc::new(SubAgentTool::new(sub_agent))) + .await; + } + + if let Some(channel) = self.ask_user_channel.take() { + tool_set.add_tool(Arc::new(AskUserTool::new(channel))).await; + } + + let tool_definitions = tool_set.get_tools().await; let handler: Box = match self.handler_type { Handlers::OpenAICompletions(m) => { @@ -75,15 +96,7 @@ impl<'a> LooperStreamBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } Handlers::OpenAIResponses(m) => { @@ -92,15 +105,7 @@ impl<'a> LooperStreamBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } Handlers::Anthropic(m) => { @@ -109,15 +114,7 @@ impl<'a> LooperStreamBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } Handlers::Gemini(m) => { @@ -126,15 +123,7 @@ impl<'a> LooperStreamBuilder<'a> { m, &get_system_message(self.instructions.as_deref(), sub_agent_enabled)?, )?; - - if let Some(t) = self.tools.as_mut() { - if let Some(sa) = self.sub_agent { - let agent_tools = Arc::new(SubAgentTool::new(sa)); - let _ = t.add_tool(agent_tools).await; - } - handler.set_tools(t.get_tools().await); - } - + handler.set_tools(tool_definitions.clone()); Box::new(handler) } }; @@ -196,24 +185,13 @@ impl<'a> LooperStreamBuilder<'a> { } }); - match self.tools { - Some(t) => { - let ls = LooperStream { - handler, - message_history: self.message_history, - tools: Arc::from(t), - }; - Ok((ls, looper_ui_receiver)) - } - None => { - let ls = LooperStream { - handler, - message_history: self.message_history, - tools: Arc::new(EmptyToolSet), - }; - Ok((ls, looper_ui_receiver)) - } - } + let ls = LooperStream { + handler, + message_history: self.message_history, + tools: Arc::new(tool_set), + }; + + Ok((ls, looper_ui_receiver)) } } @@ -225,7 +203,7 @@ impl LooperStream { tools: None, sub_agent: None, instructions: None, - // interface_sender: None, + ask_user_channel: None, buffered_output: false, } } diff --git a/src/tools/ask_user.rs b/src/tools/ask_user.rs new file mode 100644 index 0000000..195700c --- /dev/null +++ b/src/tools/ask_user.rs @@ -0,0 +1,85 @@ +use async_trait::async_trait; +use serde_json::{Value, json}; +use tokio::sync::oneshot; +use uuid::Uuid; + +use crate::{ + tools::LooperTool, + types::{AskUserRequest, LooperToolDefinition}, +}; + +pub struct AskUserTool { + sender: crate::types::AskUserSender, +} + +impl AskUserTool { + pub const NAME: &'static str = "ask_user"; + + pub fn new(sender: crate::types::AskUserSender) -> Self { + AskUserTool { sender } + } +} + +#[async_trait] +impl LooperTool for AskUserTool { + fn get_tool_name(&self) -> String { + Self::NAME.to_string() + } + + fn tool(&self) -> LooperToolDefinition { + LooperToolDefinition::default() + .set_name(Self::NAME) + .set_description( + "Ask the human user a blocking question when you need input to continue. IMPORTANT: call this tool alone, never in the same assistant turn as any other tool.", + ) + .set_paramters(json!({ + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "The exact question to show the user." + }, + "options": { + "type": "array", + "description": "Optional suggested answer choices to show the user.", + "items": { "type": "string" } + } + }, + "required": ["question"] + })) + } + + async fn execute(&mut self, args: &Value) -> Value { + let Some(question) = args.get("question").and_then(Value::as_str) else { + return json!({ "error": "Missing 'question' argument" }); + }; + + let options = args + .get("options") + .and_then(Value::as_array) + .map(|values| { + values + .iter() + .filter_map(|value| value.as_str().map(ToString::to_string)) + .collect::>() + }) + .unwrap_or_default(); + + let (response_tx, response_rx) = oneshot::channel(); + let request = AskUserRequest { + id: Uuid::new_v4().to_string(), + question: question.to_string(), + options, + response_tx, + }; + + if self.sender.send(request).await.is_err() { + return json!({ "error": "ask_user channel is closed" }); + } + + match response_rx.await { + Ok(response) => json!({ "answer": response.answer }), + Err(_) => json!({ "error": "ask_user response channel was dropped" }), + } + } +} diff --git a/src/tools/handler.rs b/src/tools/handler.rs new file mode 100644 index 0000000..3c4edcb --- /dev/null +++ b/src/tools/handler.rs @@ -0,0 +1,65 @@ +use std::{collections::HashMap, sync::Arc}; + +use async_trait::async_trait; +use serde_json::{Value, json}; +use tokio::sync::Mutex; + +use crate::{ + tools::{LooperTool, LooperTools}, + types::LooperToolDefinition, +}; + +pub struct CompositeToolSet { + base: Option>, + injected: HashMap>>, +} + +impl CompositeToolSet { + pub fn new(base: Option>) -> Self { + CompositeToolSet { + base, + injected: HashMap::new(), + } + } +} + +#[async_trait] +impl LooperTools for CompositeToolSet { + async fn get_tools(&self) -> Vec { + let mut tools = if let Some(base) = &self.base { + base.get_tools().await + } else { + Vec::new() + }; + + tools.retain(|tool| !self.injected.contains_key(&tool.name)); + + for registered in self.injected.values() { + let guard = registered.lock().await; + tools.push(guard.tool().clone()); + } + + tools + } + + async fn add_tool(&mut self, tool: Arc) { + let tool_name = tool.get_tool_name(); + self.injected.insert(tool_name, Mutex::new(tool)); + } + + async fn run_tool(&self, name: String, args: Value) -> Value { + if let Some(registered) = self.injected.get(&name) { + let mut guard = registered.lock().await; + let tool = Arc::get_mut(&mut guard) + .expect("tool has multiple references; injected tools must be uniquely owned"); + + return tool.execute(&args).await; + } + + if let Some(base) = &self.base { + return base.run_tool(name, args).await; + } + + json!({ "error": format!("Unknown function: {}", name) }) + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 13e281f..93f0b5b 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,3 +1,9 @@ +pub mod ask_user; +pub use ask_user::*; + +pub mod handler; +pub use handler::*; + pub mod sub_agent; pub use sub_agent::*; diff --git a/src/types/ask_user.rs b/src/types/ask_user.rs new file mode 100644 index 0000000..04f5e97 --- /dev/null +++ b/src/types/ask_user.rs @@ -0,0 +1,27 @@ +use std::fmt; + +use tokio::sync::{mpsc, oneshot}; + +pub type AskUserSender = mpsc::Sender; + +pub struct AskUserRequest { + pub id: String, + pub question: String, + pub options: Vec, + pub response_tx: oneshot::Sender, +} + +impl fmt::Debug for AskUserRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AskUserRequest") + .field("id", &self.id) + .field("question", &self.question) + .field("options", &self.options) + .finish_non_exhaustive() + } +} + +#[derive(Debug)] +pub struct AskUserResponse { + pub answer: String, +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 8556b8a..a559eca 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,3 +1,6 @@ +pub mod ask_user; +pub use ask_user::*; + pub mod messages; pub use messages::*;