Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 31 additions & 50 deletions src/looper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use crate::{
openai_responses_non_streaming::OpenAIResponsesNonStreamingHandler,
},
},
tools::{EmptyToolSet, LooperTools, SubAgentTool},
types::{Handlers, MessageHistory, turn::TurnResult},
tools::{AskUserTool, CompositeToolSet, LooperTools, SubAgentTool},
types::{AskUserSender, Handlers, MessageHistory, turn::TurnResult},
};

pub struct Looper {
Expand All @@ -29,6 +29,7 @@ pub struct LooperBuilder<'a> {
tools: Option<Box<dyn LooperTools>>,
instructions: Option<String>,
sub_agent: Option<Looper>,
ask_user_channel: Option<AskUserSender>,
}

impl<'a> LooperBuilder<'a> {
Expand Down Expand Up @@ -57,88 +58,67 @@ impl<'a> LooperBuilder<'a> {
self
}

pub fn ask_user_channel(mut self, channel: AskUserSender) -> Self {
self.ask_user_channel = Some(channel);
self
}

pub async fn build(mut self) -> Result<Looper> {
let sub_agent_enabled = self.sub_agent.is_some();
let mut tool_set = CompositeToolSet::new(self.tools.take());

if let Some(sub_agent) = self.sub_agent.take() {
tool_set
.add_tool(Arc::new(SubAgentTool::new(sub_agent)))
.await;
}

if let Some(channel) = self.ask_user_channel.take() {
tool_set.add_tool(Arc::new(AskUserTool::new(channel))).await;
}

let tool_definitions = tool_set.get_tools().await;

let handler: Box<dyn ChatHandler> = match self.handler_type {
Handlers::Anthropic(m) => {
let mut handler = AnthropicNonStreamingHandler::new(
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
Handlers::OpenAICompletions(m) => {
let mut handler = OpenAINonStreamingChatHandler::new(
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
Handlers::OpenAIResponses(m) => {
let mut handler = OpenAIResponsesNonStreamingHandler::new(
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
Handlers::Gemini(m) => {
let mut handler = GeminiNonStreamingHandler::new(
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
};

match self.tools {
Some(t) => Ok(Looper {
handler,
message_history: self.message_history,
tools: Arc::from(t),
}),
None => Ok(Looper {
handler,
message_history: self.message_history,
tools: Arc::new(EmptyToolSet),
}),
}
Ok(Looper {
handler,
message_history: self.message_history,
tools: Arc::new(tool_set),
})
}
}

Expand All @@ -150,6 +130,7 @@ impl Looper {
tools: None,
sub_agent: None,
instructions: None,
ask_user_channel: None,
}
}

Expand Down
92 changes: 35 additions & 57 deletions src/looper_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use crate::{
StreamingChatHandler, anthropic::AnthropicHandler, gemini::GeminiHandler,
openai_completions::OpenAIChatHandler, openai_responses::OpenAIResponsesHandler,
},
tools::{EmptyToolSet, LooperTools, SubAgentTool},
types::{HandlerToLooperMessage, Handlers, LooperToInterfaceMessage, MessageHistory},
tools::{AskUserTool, CompositeToolSet, LooperTools, SubAgentTool},
types::{
AskUserSender, HandlerToLooperMessage, Handlers, LooperToInterfaceMessage, MessageHistory,
},
};
use anyhow::Result;
use tera::{Context, Tera};
Expand All @@ -29,6 +31,7 @@ pub struct LooperStreamBuilder<'a> {
tools: Option<Box<dyn LooperTools>>,
instructions: Option<String>,
sub_agent: Option<Looper>,
ask_user_channel: Option<AskUserSender>,
buffered_output: bool,
}

Expand Down Expand Up @@ -58,6 +61,11 @@ impl<'a> LooperStreamBuilder<'a> {
self
}

pub fn ask_user_channel(mut self, channel: AskUserSender) -> Self {
self.ask_user_channel = Some(channel);
self
}

pub fn buffered_output(mut self) -> Self {
self.buffered_output = true;
self
Expand All @@ -67,6 +75,19 @@ impl<'a> LooperStreamBuilder<'a> {
let sub_agent_enabled = self.sub_agent.is_some();
let (handler_looper_sender, mut handler_looper_receiver) = mpsc::channel(10000);
let (looper_ui_sender, looper_ui_receiver) = mpsc::channel(10000);
let mut tool_set = CompositeToolSet::new(self.tools.take());

if let Some(sub_agent) = self.sub_agent.take() {
tool_set
.add_tool(Arc::new(SubAgentTool::new(sub_agent)))
.await;
}

if let Some(channel) = self.ask_user_channel.take() {
tool_set.add_tool(Arc::new(AskUserTool::new(channel))).await;
}

let tool_definitions = tool_set.get_tools().await;

let handler: Box<dyn StreamingChatHandler> = match self.handler_type {
Handlers::OpenAICompletions(m) => {
Expand All @@ -75,15 +96,7 @@ impl<'a> LooperStreamBuilder<'a> {
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
Handlers::OpenAIResponses(m) => {
Expand All @@ -92,15 +105,7 @@ impl<'a> LooperStreamBuilder<'a> {
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
Handlers::Anthropic(m) => {
Expand All @@ -109,15 +114,7 @@ impl<'a> LooperStreamBuilder<'a> {
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
Handlers::Gemini(m) => {
Expand All @@ -126,15 +123,7 @@ impl<'a> LooperStreamBuilder<'a> {
m,
&get_system_message(self.instructions.as_deref(), sub_agent_enabled)?,
)?;

if let Some(t) = self.tools.as_mut() {
if let Some(sa) = self.sub_agent {
let agent_tools = Arc::new(SubAgentTool::new(sa));
let _ = t.add_tool(agent_tools).await;
}
handler.set_tools(t.get_tools().await);
}

handler.set_tools(tool_definitions.clone());
Box::new(handler)
}
};
Expand Down Expand Up @@ -196,24 +185,13 @@ impl<'a> LooperStreamBuilder<'a> {
}
});

match self.tools {
Some(t) => {
let ls = LooperStream {
handler,
message_history: self.message_history,
tools: Arc::from(t),
};
Ok((ls, looper_ui_receiver))
}
None => {
let ls = LooperStream {
handler,
message_history: self.message_history,
tools: Arc::new(EmptyToolSet),
};
Ok((ls, looper_ui_receiver))
}
}
let ls = LooperStream {
handler,
message_history: self.message_history,
tools: Arc::new(tool_set),
};

Ok((ls, looper_ui_receiver))
}
}

Expand All @@ -225,7 +203,7 @@ impl LooperStream {
tools: None,
sub_agent: None,
instructions: None,
// interface_sender: None,
ask_user_channel: None,
buffered_output: false,
}
}
Expand Down
85 changes: 85 additions & 0 deletions src/tools/ask_user.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use async_trait::async_trait;
use serde_json::{Value, json};
use tokio::sync::oneshot;
use uuid::Uuid;

use crate::{
tools::LooperTool,
types::{AskUserRequest, LooperToolDefinition},
};

pub struct AskUserTool {
sender: crate::types::AskUserSender,
}

impl AskUserTool {
pub const NAME: &'static str = "ask_user";

pub fn new(sender: crate::types::AskUserSender) -> Self {
AskUserTool { sender }
}
}

#[async_trait]
impl LooperTool for AskUserTool {
fn get_tool_name(&self) -> String {
Self::NAME.to_string()
}

fn tool(&self) -> LooperToolDefinition {
LooperToolDefinition::default()
.set_name(Self::NAME)
.set_description(
"Ask the human user a blocking question when you need input to continue. IMPORTANT: call this tool alone, never in the same assistant turn as any other tool.",
)
.set_paramters(json!({
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "The exact question to show the user."
},
"options": {
"type": "array",
"description": "Optional suggested answer choices to show the user.",
"items": { "type": "string" }
}
},
"required": ["question"]
}))
}

async fn execute(&mut self, args: &Value) -> Value {
let Some(question) = args.get("question").and_then(Value::as_str) else {
return json!({ "error": "Missing 'question' argument" });
};

let options = args
.get("options")
.and_then(Value::as_array)
.map(|values| {
values
.iter()
.filter_map(|value| value.as_str().map(ToString::to_string))
.collect::<Vec<_>>()
})
.unwrap_or_default();

let (response_tx, response_rx) = oneshot::channel();
let request = AskUserRequest {
id: Uuid::new_v4().to_string(),
question: question.to_string(),
options,
response_tx,
};

if self.sender.send(request).await.is_err() {
return json!({ "error": "ask_user channel is closed" });
}

match response_rx.await {
Ok(response) => json!({ "answer": response.answer }),
Err(_) => json!({ "error": "ask_user response channel was dropped" }),
}
}
}
Loading