Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 104 additions & 9 deletions crates/goat-agent/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::HashSet,
collections::{HashMap, HashSet},
fmt::Write as _,
path::Path,
sync::{
Expand All @@ -14,9 +14,9 @@ use goat_auth::{
};
use goat_core::Engine;
use goat_protocol::{
AccountChoice, AccountEntry, AccountInfo, AuthMethod, Effort, Event, LoginCredential,
LoginProvider, ModelEntry, ModelTarget, NotifyKind, Op, SkillInfo, TaskId, ThreadSummary,
ToolCall, ToolCallId, ToolOutcome, TranscriptEntry,
AccountChoice, AccountEntry, AccountInfo, AskQuestion, AuthMethod, Effort, Event,
LoginCredential, LoginProvider, ModelEntry, ModelTarget, NotifyKind, Op, SkillInfo, TaskId,
ThreadSummary, ToolCall, ToolCallId, ToolOutcome, TranscriptEntry,
};
use goat_provider::{
ContentBlock, Message, MessageRole, Provider, Request, StreamEvent, ToolDefinition,
Expand All @@ -26,7 +26,7 @@ use goat_store::{NewMessage, NewThread, NewToolCall, NewTurn, Store};
use goat_tool::ToolContext;
use goat_tools::ToolRegistry;
use tokio::{
sync::{Semaphore, mpsc},
sync::{Mutex, Semaphore, mpsc, oneshot},
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
Expand All @@ -40,6 +40,7 @@ pub use agent::{AgentRegistry, AgentSpec, ToolSelection};
const MAX_TOOL_ROUNDS: usize = 20;
const MAX_CONCURRENT_AGENTS: usize = 8;
const AGENT_TOOL_NAME: &str = "Agent";
const ASK_TOOL_NAME: &str = "Ask";
const CHILD_ID_BASE: u64 = 1 << 32;
const SYSTEM_PROMPT: &str = "You are Goat, an expert software engineering assistant. You help users understand, build, and improve software by reading code, running tools, and providing accurate, actionable guidance. When using tools, prefer targeted reads and searches over broad exploration. Always verify your understanding before making changes.";

Expand Down Expand Up @@ -123,6 +124,7 @@ struct Ctx<'a> {
instructions: Option<&'a str>,
semaphore: &'a Arc<Semaphore>,
child_ids: &'a AtomicU64,
asks: &'a Mutex<HashMap<ToolCallId, oneshot::Sender<Vec<String>>>>,
rl_cache: &'a std::sync::Mutex<rate_limit_cache::RateLimitCache>,
rl_path: Option<&'a std::path::Path>,
}
Expand Down Expand Up @@ -331,6 +333,7 @@ async fn run(agent: GoatAgent, mut ops: mpsc::Receiver<Op>, events: mpsc::Sender
let project_instructions = instructions::load_project_instructions(&cwd);
let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_AGENTS));
let child_ids = AtomicU64::new(CHILD_ID_BASE);
let asks: Mutex<HashMap<ToolCallId, oneshot::Sender<Vec<String>>>> = Mutex::new(HashMap::new());
let _ = events
.send(Event::SkillsChanged {
skills: skills.clone(),
Expand Down Expand Up @@ -370,6 +373,7 @@ async fn run(agent: GoatAgent, mut ops: mpsc::Receiver<Op>, events: mpsc::Sender
instructions: project_instructions.as_deref(),
semaphore: &semaphore,
child_ids: &child_ids,
asks: &asks,
rl_cache: &rl_cache,
rl_path: rl_path.as_deref(),
};
Expand All @@ -387,7 +391,7 @@ async fn run(agent: GoatAgent, mut ops: mpsc::Receiver<Op>, events: mpsc::Sender
break;
}
}
Op::Interrupt { .. } | Op::SetTheme { .. } => {}
Op::Interrupt { .. } | Op::SetTheme { .. } | Op::Answer { .. } => {}
Op::Clear => {
history.clear();
thread_id = None;
Expand Down Expand Up @@ -973,7 +977,9 @@ async fn execute_tool(
tool_ctx: &ToolContext,
token: &CancellationToken,
) -> ToolExecResult {
let step = if prep.name == AGENT_TOOL_NAME && env.allow_delegate {
let step = if prep.name == ASK_TOOL_NAME && env.allow_delegate {
Some(run_ask(ctx, run, prep.input_json, ToolCallId(prep.tui_id), token).await)
} else if prep.name == AGENT_TOOL_NAME && env.allow_delegate {
match ctx.semaphore.acquire().await {
Ok(_permit) if !token.is_cancelled() => {
Some(run_delegation(ctx, env, prep.input_json, run.id, token).await)
Expand Down Expand Up @@ -1438,12 +1444,50 @@ fn build_tool_defs(
input_schema: spec.parameters,
})
.collect();
if allow_delegate && !ctx.agents.is_empty() {
defs.push(agent_tool_def(ctx));
if allow_delegate {
if !ctx.agents.is_empty() {
defs.push(agent_tool_def(ctx));
}
defs.push(ask_tool_def());
}
defs
}

fn ask_tool_def() -> ToolDefinition {
ToolDefinition {
name: ASK_TOOL_NAME.to_owned(),
description: "Pause execution and ask the user one or more questions, each with optional choice options. Returns the user's answers as a JSON array of strings in the same order as the questions. Use when you need the user's input or a decision before proceeding.".to_owned(),
input_schema: serde_json::json!({
"type": "object",
"properties": {
"questions": {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"properties": {
"question": { "type": "string" },
"options": {
"type": "array",
"items": {
"type": "object",
"properties": {
"label": { "type": "string" },
"description": { "type": "string" }
},
"required": ["label"]
}
}
},
"required": ["question"]
}
}
},
"required": ["questions"]
}),
}
}

fn agent_tool_def(ctx: &Ctx<'_>) -> ToolDefinition {
let names: Vec<String> = ctx.agents.names();
let mut description = String::from(
Expand Down Expand Up @@ -1621,6 +1665,52 @@ async fn run_delegation(
result
}

async fn run_ask(
ctx: &Ctx<'_>,
run: &Run<'_>,
input_json: &str,
call_id: ToolCallId,
token: &CancellationToken,
) -> Result<String, String> {
#[derive(serde::Deserialize)]
struct Input {
questions: Vec<AskQuestion>,
}
let args: Input =
serde_json::from_str(input_json).map_err(|err| format!("invalid Ask input: {err}"))?;
if args.questions.is_empty() {
return Err("questions must not be empty".to_owned());
}
let (tx, rx) = oneshot::channel::<Vec<String>>();
ctx.asks.lock().await.insert(call_id, tx);
let _ = ctx
.events
.send(Event::AskStarted {
id: run.id,
call: call_id,
questions: args.questions,
})
.await;
let result = tokio::select! {
biased;
() = token.cancelled() => {
ctx.asks.lock().await.remove(&call_id);
let _ = ctx
.events
.send(Event::AskDismissed { id: run.id, call: call_id })
.await;
return Err("interrupted".to_owned());
}
res = rx => res,
};
match result {
Ok(answers) => {
serde_json::to_string(&answers).map_err(|err| format!("serialize error: {err}"))
}
Err(_) => Err("answer channel closed".to_owned()),
}
}

fn delegation_label(prompt: &str) -> String {
let line = prompt.lines().next().unwrap_or("").trim();
if line.chars().count() > 50 {
Expand Down Expand Up @@ -1929,6 +2019,11 @@ async fn handle_turn(
biased;
result = &mut core => break result,
maybe_op = ops.recv() => match maybe_op {
Some(Op::Answer { call, answers, .. }) => {
if let Some(tx) = ctx.asks.lock().await.remove(&call) {
let _ = tx.send(answers);
}
}
Some(Op::Interrupt { id: target_id }) if target_id == id => token.cancel(),
Some(Op::Shutdown) | None => {
shutdown = true;
Expand Down
28 changes: 28 additions & 0 deletions crates/goat-protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ pub enum Op {
RenameThread {
title: String,
},
Answer {
id: TaskId,
call: ToolCallId,
answers: Vec<String>,
},
Shutdown,
}

Expand Down Expand Up @@ -306,6 +311,15 @@ pub enum Event {
kind: NotifyKind,
message: String,
},
AskStarted {
id: TaskId,
call: ToolCallId,
questions: Vec<AskQuestion>,
},
AskDismissed {
id: TaskId,
call: ToolCallId,
},
Usage {
id: TaskId,
usage: Usage,
Expand All @@ -318,3 +332,17 @@ pub enum Event {
cached_at: i64,
},
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AskOption {
pub label: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AskQuestion {
pub question: String,
#[serde(default)]
pub options: Vec<AskOption>,
}
14 changes: 13 additions & 1 deletion crates/goat-tui/src/app/engine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use goat_protocol::{Event as EngineEvent, Op, TaskId, TranscriptEntry};

use super::{App, Overlay, ResumeIntent};
use crate::picker::ThreadPicker;
use crate::picker::{AskPicker, ThreadPicker};

impl App {
#[allow(clippy::too_many_lines)]
Expand Down Expand Up @@ -142,6 +142,18 @@ impl App {
self.toasts.push(crate::toast::Toast::new(kind, message));
self.dirty = true;
}
EngineEvent::AskStarted {
call, questions, ..
} => {
self.overlay = Overlay::Ask(AskPicker::new(questions), call);
self.dirty = true;
}
EngineEvent::AskDismissed { .. } => {
if matches!(self.overlay, Overlay::Ask(..)) {
self.overlay = Overlay::None;
self.dirty = true;
}
}
EngineEvent::Usage {
id: _,
usage,
Expand Down
Loading
Loading