diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index aafb112d49..5b2a7bd747 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -260,4 +260,8 @@ pub trait API: Sync + Send { /// Check the OAuth authentication status of an MCP server async fn mcp_auth_status(&self, server_url: &str) -> Result; + + /// Undoes the most recent snapshot for the given file path. + /// Used by the rewind command to revert file changes. + async fn undo_snapshot(&self, path: &str) -> Result<()>; } diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index b56d485bfd..49659e7c74 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -67,7 +67,8 @@ impl< F: CommandInfra + EnvironmentInfra + SkillRepository - + GrpcInfra, + + GrpcInfra + + SnapshotRepository, > API for ForgeAPI { async fn discover(&self) -> Result> { @@ -435,6 +436,11 @@ impl< Ok(forge_infra::mcp_auth_status(server_url, &env).await) } + async fn undo_snapshot(&self, path: &str) -> Result<()> { + self.infra.undo_snapshot(std::path::Path::new(path)).await?; + Ok(()) + } + fn hydrate_channel(&self) -> Result<()> { self.infra.hydrate(); Ok(()) diff --git a/crates/forge_app/src/tool_registry.rs b/crates/forge_app/src/tool_registry.rs index dbfff3da06..1a162df0bb 100644 --- a/crates/forge_app/src/tool_registry.rs +++ b/crates/forge_app/src/tool_registry.rs @@ -218,9 +218,25 @@ impl> ToolReg ) -> ToolResult { let call_id = call.call_id.clone(); let tool_name = call.name.clone(); - let output = self.call_inner(agent, call, context).await; + let output = self.call_inner(agent, call.clone(), context).await; - ToolResult::new(tool_name).call_id(call_id).output(output) + let mut modified_files = Vec::new(); + if let Ok(output) = &output { + if let Some(text) = output.as_str() { + modified_files = forge_domain::extract_modified_files_from_output(text); + } + + // Fallback to extraction from arguments if output didn't yield anything + // This is important for relative paths provided in tool arguments + if modified_files.is_empty() { + modified_files = Self::extract_modified_files(&call); + } + } + + ToolResult::new(tool_name) + .call_id(call_id) + .output(output) + .modified_files(modified_files) } pub async fn list(&self) -> anyhow::Result> { @@ -381,6 +397,26 @@ impl ToolRegistry { IMAGE_EXTENSIONS.iter().any(|ext| path_lower.ends_with(ext)) } + /// Extracts the file path from a tool call's arguments if it's a + /// file-modifying tool (write, patch, remove). Returns an empty vec + /// for non-modifying tools or if the path cannot be parsed. + fn extract_modified_files(call: &ToolCallFull) -> Vec { + match call.name.as_str() { + "write" | "patch" | "multi_patch" | "remove" => { + if let Ok(args) = call.arguments.parse() { + if let Some(file_path) = args.get("file_path").and_then(|v| v.as_str()) { + return vec![file_path.to_string()]; + } + if let Some(path) = args.get("path").and_then(|v| v.as_str()) { + return vec![path.to_string()]; + } + } + vec![] + } + _ => vec![], + } + } + /// Validates if a tool's modality requirements are supported by the current /// model. /// diff --git a/crates/forge_domain/src/compact/summary.rs b/crates/forge_domain/src/compact/summary.rs index 3416dfdba8..6683e39dcc 100644 --- a/crates/forge_domain/src/compact/summary.rs +++ b/crates/forge_domain/src/compact/summary.rs @@ -451,6 +451,7 @@ mod tests { name: ToolName::new(name), call_id: Some(ToolCallId::new(call_id)), output: ToolOutput::text("result").is_error(is_error), + modified_files: vec![], }) } @@ -805,6 +806,7 @@ mod tests { name: ToolName::new("read"), call_id: None, output: ToolOutput::text("result"), + modified_files: vec![], }), ]); diff --git a/crates/forge_domain/src/context.rs b/crates/forge_domain/src/context.rs index c2f0f30fde..0fd9ec6b09 100644 --- a/crates/forge_domain/src/context.rs +++ b/crates/forge_domain/src/context.rs @@ -17,6 +17,7 @@ fn is_false(value: &bool) -> bool { use crate::temperature::Temperature; use crate::top_k::TopK; use crate::top_p::TopP; +use crate::xml::{clean_user_prompt, strip_xml_tags}; use crate::{ Attachment, AttachmentContent, ConversationId, EventValue, Image, MessagePhase, ModelId, ReasoningFull, ToolChoice, ToolDefinition, ToolOutput, ToolValue, Usage, @@ -374,6 +375,17 @@ pub struct MessageEntry { pub usage: Option, } +impl MessageEntry { + /// Returns true if this is a user message that should be visible in history + /// (not droppable). + pub fn is_user_message(&self) -> bool { + match &self.message { + ContextMessage::Text(msg) => msg.role == Role::User && !msg.droppable, + _ => false, + } + } +} + impl From for MessageEntry { fn from(value: ContextMessage) -> Self { MessageEntry { message: value, usage: Default::default() } @@ -432,6 +444,8 @@ pub struct Context { pub response_format: Option, } +const REWIND_PREVIEW_MAX_LEN: usize = 100; + impl Context { pub fn accumulate_usage(&self) -> Option { self.messages @@ -685,14 +699,109 @@ impl Context { self.messages.len() } + /// Truncates the context, keeping messages up to and including the given + /// index, and removing all messages after that point. + /// + /// Also removes tool results that reference tool calls that were + /// truncated, and resets usage data to match the remaining messages. + /// + /// Returns the truncated context. + pub fn truncate(mut self, keep_up_to: usize) -> Self { + if keep_up_to >= self.messages.len() { + return self; + } + self.messages.truncate(keep_up_to + 1); + self + } + + /// Truncates the conversation to keep messages up to (but excluding) the + /// Nth (0-indexed) user message. All subsequent messages are discarded. + pub fn truncate_to_user_message(mut self, nth_user: usize) -> Self { + let cut_at = self + .messages + .iter() + .enumerate() + .filter_map(|(i, entry)| entry.is_user_message().then_some(i)) + .nth(nth_user); + + if let Some(idx) = cut_at { + self.messages.truncate(idx); + } + self + } + + /// Formats user messages with numbered indices for display during + /// interactive rewind. Returns a list of `(full_message_index, + /// display_string)` tuples, where the display string shows a 1-indexed + /// user message number and preview. + pub fn format_messages_for_rewind(&self) -> Vec<(usize, String)> { + self.messages + .iter() + .enumerate() + .filter(|(_, entry)| entry.is_user_message()) + .map(|(full_idx, entry)| { + let content = entry.content().unwrap_or("").trim(); + let cleaned = clean_user_prompt(content); + let preview = strip_xml_tags(&cleaned); + let preview = if preview.len() > REWIND_PREVIEW_MAX_LEN { + format!( + "{}...", + preview + .chars() + .take(REWIND_PREVIEW_MAX_LEN) + .collect::() + ) + } else { + preview.to_string() + }; + (full_idx, preview) + }) + .collect() + } + /// Returns the count of user messages in the context pub fn user_message_count(&self) -> usize { self.messages .iter() - .filter(|msg| msg.has_role(Role::User)) + .filter(|msg| msg.is_user_message()) .count() } + /// Returns the file paths modified by tool results in messages + /// at or after the given index (inclusive). Used by rewind to know which + /// file snapshots to revert. + pub fn modified_files_from(&self, from_index: usize) -> Vec { + let mut files = Vec::new(); + for msg in self.messages.iter().skip(from_index) { + if let ContextMessage::Tool(result) = &msg.message { + // 1. Use the pre-calculated modified files from the tool execution + let mut msg_modified = result.modified_files.clone(); + + // 2. Fallback dynamic extraction for older conversations where modified_files + // might be empty + if let Some(text) = result.output.as_str() { + let extracted_paths = crate::xml::extract_modified_files_from_output(text); + + for p in extracted_paths { + if !msg_modified.contains(&p) { + msg_modified.push(p); + } + } + } + + files.extend(msg_modified); + } + } + files + } + + /// Returns the file paths modified by tool results in messages + /// after the given index (exclusive). Used by rewind to know which + /// file snapshots to revert. + pub fn modified_files_after(&self, keep_up_to: usize) -> Vec { + self.modified_files_from(keep_up_to + 1) + } + /// Returns the count of assistant messages in the context pub fn assistant_message_count(&self) -> usize { self.messages @@ -906,6 +1015,7 @@ mod tests { name: crate::ToolName::new("text_tool"), call_id: Some(crate::ToolCallId::new("call1")), output: crate::ToolOutput::text("Text output".to_string()), + modified_files: vec![], }, ToolResult { name: crate::ToolName::new("empty_tool"), @@ -914,6 +1024,7 @@ mod tests { values: vec![crate::ToolValue::Empty], is_error: false, }, + modified_files: vec![], }, ]); @@ -932,6 +1043,7 @@ mod tests { name: crate::ToolName::new("image_tool"), call_id: Some(crate::ToolCallId::new("call1")), output: crate::ToolOutput::image(image), + modified_files: vec![], }]); let mut transformer = crate::transformer::ImageHandling::new(); @@ -956,6 +1068,7 @@ mod tests { ], is_error: false, }, + modified_files: vec![], }]); let mut transformer = crate::transformer::ImageHandling::new(); @@ -975,16 +1088,19 @@ mod tests { name: crate::ToolName::new("text_tool"), call_id: Some(crate::ToolCallId::new("call1")), output: crate::ToolOutput::text("Text output".to_string()), + modified_files: vec![], }, ToolResult { name: crate::ToolName::new("image_tool1"), call_id: Some(crate::ToolCallId::new("call2")), output: crate::ToolOutput::image(image1), + modified_files: vec![], }, ToolResult { name: crate::ToolName::new("image_tool2"), call_id: Some(crate::ToolCallId::new("call3")), output: crate::ToolOutput::image(image2), + modified_files: vec![], }, ]); @@ -1018,6 +1134,7 @@ mod tests { ], is_error: false, }, + modified_files: vec![], }]); let mut transformer = crate::transformer::ImageHandling::new(); @@ -1036,6 +1153,7 @@ mod tests { values: vec![crate::ToolValue::Image(image)], is_error: true, }, + modified_files: vec![], }]); let mut transformer = crate::transformer::ImageHandling::new(); @@ -1381,11 +1499,13 @@ mod tests { name: crate::ToolName::new("tool1"), call_id: Some(crate::ToolCallId::new("call1")), output: crate::ToolOutput::text("Result 1".to_string()), + modified_files: vec![], }, ToolResult { name: crate::ToolName::new("tool2"), call_id: Some(crate::ToolCallId::new("call2")), output: crate::ToolOutput::text("Result 2".to_string()), + modified_files: vec![], }, ]); @@ -1538,6 +1658,7 @@ mod tests { name: crate::ToolName::new("fs_search"), call_id: Some(crate::ToolCallId::new("call1")), output: crate::ToolOutput::text("Search results: Found 3 items".to_string()), + modified_files: vec![], }); let actual = fixture.token_count_approx(); let expected = 8; // 30 chars / 4 = 8 tokens (rounded up) @@ -1552,6 +1673,7 @@ mod tests { name: crate::ToolName::new("screenshot"), call_id: Some(crate::ToolCallId::new("call1")), output: crate::ToolOutput::image(fixture_image), + modified_files: vec![], }); let actual = fixture.token_count_approx(); let expected = 0; // Images are not counted in token approximation @@ -1695,6 +1817,98 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_truncate_to_user_message_exclusive() { + let context = Context::default() + .add_message(ContextMessage::user("U1", None)) // idx 0 + .add_message(TextMessage::assistant("A1", None, None)) // idx 1 + .add_message(ContextMessage::user("U2", None)) // idx 2 + .add_message(TextMessage::assistant("A2", None, None)); // idx 3 + + // Rewind to U2 (nth_user = 1) -> keeps [U1, A1] + let rewound = context.clone().truncate_to_user_message(1); + assert_eq!(rewound.messages.len(), 2); + assert_eq!(rewound.messages[0].content().unwrap(), "U1"); + assert_eq!(rewound.messages[1].content().unwrap(), "A1"); + + // Rewind to U1 (nth_user = 0) -> keeps [] + let rewound = context.truncate_to_user_message(0); + assert_eq!(rewound.messages.len(), 0); + } + + #[test] + fn test_modified_files_from() { + use crate::{ToolName, ToolOutput, ToolResult}; + let context = Context::default() + .add_message(ContextMessage::user("U1", None)) + .add_message(ContextMessage::Tool(ToolResult { + name: ToolName::new("write"), + call_id: None, + output: ToolOutput::text("ok"), + modified_files: vec!["file1.txt".to_string()], + })) + .add_message(ContextMessage::user("U2", None)) + .add_message(ContextMessage::Tool(ToolResult { + name: ToolName::new("patch"), + call_id: None, + output: ToolOutput::text("ok"), + modified_files: vec!["file2.txt".to_string()], + })); + + // From U2 (idx 2) + let files = context.modified_files_from(2); + assert_eq!(files, vec!["file2.txt".to_string()]); + + // From U1 (idx 0) + let files = context.modified_files_from(0); + assert_eq!( + files, + vec!["file1.txt".to_string(), "file2.txt".to_string()] + ); + } + + #[test] + fn test_modified_files_from_duplicates() { + use crate::{ToolName, ToolOutput, ToolResult}; + let context = Context::default() + .add_message(ContextMessage::Tool(ToolResult { + name: ToolName::new("write"), + call_id: None, + output: ToolOutput::text("ok"), + modified_files: vec!["file1.txt".to_string()], + })) + .add_message(ContextMessage::Tool(ToolResult { + name: ToolName::new("patch"), + call_id: None, + output: ToolOutput::text("ok"), + modified_files: vec!["file1.txt".to_string()], + })); + + let files = context.modified_files_from(0); + // Should contain duplicates as each is a separate modification + assert_eq!( + files, + vec!["file1.txt".to_string(), "file1.txt".to_string()] + ); + } + + #[test] + fn test_modified_files_from_fallback_dedup() { + use crate::{ToolName, ToolOutput, ToolResult}; + let context = Context::default().add_message(ContextMessage::Tool(ToolResult { + name: ToolName::new("write"), + call_id: None, + // XML tag suggests file1.txt was modified + output: ToolOutput::text(""), + // result.modified_files ALSO has file1.txt + modified_files: vec!["file1.txt".to_string()], + })); + + let files = context.modified_files_from(0); + // Should NOT duplicate within the same tool result + assert_eq!(files, vec!["file1.txt".to_string()]); + } + /// Regression test: when both `reasoning` (raw text) and /// `reasoning_details` (structured, with a cryptographic signature) are /// present, `append_message` must NOT create a duplicate thinking block diff --git a/crates/forge_domain/src/tools/result.rs b/crates/forge_domain/src/tools/result.rs index 1f68ca294f..10518fe1c3 100644 --- a/crates/forge_domain/src/tools/result.rs +++ b/crates/forge_domain/src/tools/result.rs @@ -14,6 +14,10 @@ pub struct ToolResult { pub call_id: Option, #[setters(skip)] pub output: ToolOutput, + /// File paths modified by this tool call (for undo/rewind purposes). + /// Populated by file-modifying tools (Write, Patch, Remove). + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub modified_files: Vec, } impl ToolResult { @@ -22,6 +26,7 @@ impl ToolResult { name: name.into(), call_id: Default::default(), output: Default::default(), + modified_files: vec![], } } @@ -75,6 +80,7 @@ impl From for ToolResult { name: value.name, call_id: value.call_id, output: Default::default(), + modified_files: vec![], } } } diff --git a/crates/forge_domain/src/transformer/drop_reasoning_details.rs b/crates/forge_domain/src/transformer/drop_reasoning_details.rs index e6a016feb7..1a2c0ac1c7 100644 --- a/crates/forge_domain/src/transformer/drop_reasoning_details.rs +++ b/crates/forge_domain/src/transformer/drop_reasoning_details.rs @@ -85,6 +85,7 @@ mod tests { name: ToolName::new("test_tool"), call_id: Some(ToolCallId::new("call_123")), output: ToolOutput::text("Tool result".to_string()), + modified_files: vec![], }]) } @@ -172,6 +173,7 @@ mod tests { name: ToolName::new("preserve_tool"), call_id: Some(ToolCallId::new("call_preserve")), output: ToolOutput::text("Tool output".to_string()), + modified_files: vec![], }]); let mut transformer = DropReasoningDetails; diff --git a/crates/forge_domain/src/transformer/image_handling.rs b/crates/forge_domain/src/transformer/image_handling.rs index c301b3778a..bc7cc741f7 100644 --- a/crates/forge_domain/src/transformer/image_handling.rs +++ b/crates/forge_domain/src/transformer/image_handling.rs @@ -100,6 +100,7 @@ mod tests { ], is_error: false, }, + modified_files: vec![], }]) } @@ -114,11 +115,13 @@ mod tests { name: ToolName::new("image_tool_1"), call_id: Some(ToolCallId::new("call_1")), output: ToolOutput::image(image1), + modified_files: vec![], }, ToolResult { name: ToolName::new("image_tool_2"), call_id: Some(ToolCallId::new("call_2")), output: ToolOutput::image(image2), + modified_files: vec![], }, ]) } @@ -141,6 +144,7 @@ mod tests { name: ToolName::new("text_tool"), call_id: Some(ToolCallId::new("call_text")), output: ToolOutput::text("Just text output".to_string()), + modified_files: vec![], }]); let mut transformer = ImageHandling::new(); @@ -178,6 +182,7 @@ mod tests { ], is_error: false, }, + modified_files: vec![], }]); let mut transformer = ImageHandling::new(); @@ -201,6 +206,7 @@ mod tests { ], is_error: true, }, + modified_files: vec![], }]); let mut transformer = ImageHandling::new(); @@ -237,6 +243,7 @@ mod tests { name: ToolName::new("image_tool"), call_id: Some(ToolCallId::new("call_preserve")), output: ToolOutput::image(image), + modified_files: vec![], }]); let mut transformer = ImageHandling::new(); diff --git a/crates/forge_domain/src/transformer/mod.rs b/crates/forge_domain/src/transformer/mod.rs index 1f4ccc91b7..5168d8089f 100644 --- a/crates/forge_domain/src/transformer/mod.rs +++ b/crates/forge_domain/src/transformer/mod.rs @@ -133,6 +133,7 @@ mod tests { name: ToolName::new("test_tool"), call_id: Some(ToolCallId::new("call_123")), output: ToolOutput::text("Tool result text".to_string()), + modified_files: vec![], }]) } diff --git a/crates/forge_domain/src/transformer/transform_tool_calls.rs b/crates/forge_domain/src/transformer/transform_tool_calls.rs index da063a8886..0b98650c53 100644 --- a/crates/forge_domain/src/transformer/transform_tool_calls.rs +++ b/crates/forge_domain/src/transformer/transform_tool_calls.rs @@ -119,6 +119,7 @@ mod tests { name: ToolName::new("test_tool"), call_id: Some(ToolCallId::new("call_123")), output: ToolOutput::text("Tool result text".to_string()), + modified_files: vec![], }]) } @@ -137,6 +138,7 @@ mod tests { ], is_error: false, }, + modified_files: vec![], }]) } @@ -206,6 +208,7 @@ mod tests { name: ToolName::new("empty_tool"), call_id: Some(ToolCallId::new("call_empty")), output: ToolOutput { values: vec![ToolValue::Empty], is_error: false }, + modified_files: vec![], }]); let mut transformer = TransformToolCalls::new(); diff --git a/crates/forge_domain/src/xml.rs b/crates/forge_domain/src/xml.rs index 7ba5fbadfc..31363ed35c 100644 --- a/crates/forge_domain/src/xml.rs +++ b/crates/forge_domain/src/xml.rs @@ -1,3 +1,14 @@ +/// Extracts a full XML tag (including brackets) by its name +pub fn extract_tag<'a>(text: &'a str, tag_name: &str) -> Option<&'a str> { + let opening_pattern = format!(r"<{tag_name}(?:\s[^>]*?)?>"); + if let Ok(regex) = regex::Regex::new(&opening_pattern) + && let Some(mat) = regex.find(text) + { + return Some(mat.as_str()); + } + None +} + /// Extracts content between the specified XML-style tags /// /// # Arguments @@ -58,6 +69,76 @@ pub fn remove_tag_with_prefix(text: &str, prefix: &str) -> String { result } +/// Cleans a user prompt by extracting content from tags if present, +/// or stripping all XML tags and meta-information. +pub fn clean_user_prompt(text: &str) -> String { + // 1. Try to extract content from or tags + if let Some(content) = extract_tag_content(text, "feedback") { + return content.to_string(); + } + if let Some(content) = extract_tag_content(text, "task") { + return content.to_string(); + } + + // 2. Remove known meta tags with their content + let mut cleaned = remove_tag_with_prefix(text, "system_"); + cleaned = remove_tag_with_prefix(&cleaned, "context_"); + cleaned = remove_tag_with_prefix(&cleaned, "terminal_context"); + + // 3. Strip all remaining tags but preserve newlines (unlike strip_xml_tags) + let tag_pattern = regex::Regex::new(r"<[^>]*>").unwrap(); + let result = tag_pattern.replace_all(&cleaned, "").to_string(); + + // Trim while preserving internal structure + result.trim().to_string() +} + +/// Extracts the value of an attribute from an XML tag +pub fn extract_attribute(tag: &str, attr_name: &str) -> Option { + let pattern = format!(r#"{attr_name}="([^"]*)""#, attr_name = attr_name); + if let Ok(regex) = regex::Regex::new(&pattern) + && let Some(captures) = regex.captures(tag) + { + return captures.get(1).map(|m| m.as_str().to_string()); + } + None +} + +/// Removes all XML/HTML tags from the text, keeping only the content between +/// tags. Multiple whitespace characters are collapsed into a single space. +pub fn strip_xml_tags(text: &str) -> String { + let tag_pattern = regex::Regex::new(r"<[^>]*>").unwrap(); + let result = tag_pattern.replace_all(text, "").to_string(); + // Collapse multiple whitespace characters into a single space + let re_whitespace = regex::Regex::new(r"\s+").unwrap(); + re_whitespace.replace_all(&result, " ").trim().to_string() +} + +/// Extracts file paths from XML tags in the given text. +/// Supports tags: plan_created, file_created, file_overwritten, file_diff, +/// file_removed. +pub fn extract_modified_files_from_output(text: &str) -> Vec { + let mut modified_files = Vec::new(); + let tags = [ + "plan_created", + "file_created", + "file_overwritten", + "file_diff", + "file_removed", + ]; + + for tag_name in tags { + if let Some(tag) = extract_tag(text, tag_name) + && let Some(path) = extract_attribute(tag, "path") + { + modified_files.push(path); + // Return only the first matching tag to maintain parity with existing logic + return modified_files; + } + } + modified_files +} + #[cfg(test)] mod tests { use pretty_assertions::assert_eq; @@ -72,6 +153,14 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_extract_tag_content_with_duplicate_closing_tags() { + let fixture = "Some text 123 and more text"; + let actual = extract_tag_content(fixture, "summary"); + let expected = Some("123"); + assert_eq!(actual, expected); + } + #[test] fn test_extract_tag_content_no_tags() { let fixture = "Some text without any tags"; @@ -159,10 +248,60 @@ mod tests { } #[test] - fn test_with_duplicate_closing_tags() { - let fixture = "123"; - let actual = extract_tag_content(fixture, "foo").unwrap(); - let expected = "123"; + fn test_clean_user_prompt_with_tags() { + let fixture = "add feature to determine recipe from images using vision llm models\n::: 2026-05-16\n::: "; + let actual = clean_user_prompt(fixture); + // Should extract ONLY feedback and trim + let expected = "add feature to determine recipe from images using vision llm models"; + assert_eq!(actual, expected); + } + + #[test] + fn test_clean_user_prompt_without_feedback() { + let fixture = "Just plain text 2026"; + let actual = clean_user_prompt(fixture); + // Should strip system_date and its tags + let expected = "Just plain text"; + assert_eq!(actual, expected); + } + + #[test] + fn test_clean_user_prompt_with_terminal_context() { + let fixture = "create a planls"; + let actual = clean_user_prompt(fixture); + let expected = "create a plan"; + assert_eq!(actual, expected); + } + + #[test] + fn test_clean_user_prompt_with_task_only() { + let fixture = "new tasksome info"; + let actual = clean_user_prompt(fixture); + let expected = "new task"; + assert_eq!(actual, expected); + } + + #[test] + fn test_extract_modified_files_from_output() { + let fixture = r#"Some text more text"#; + let actual = extract_modified_files_from_output(fixture); + let expected = vec!["/abs/path/to/file.txt".to_string()]; + assert_eq!(actual, expected); + + let fixture = r#"Plan content"#; + let actual = extract_modified_files_from_output(fixture); + let expected = vec!["plan.md".to_string()]; + assert_eq!(actual, expected); + + let fixture = r#""#; + let actual = extract_modified_files_from_output(fixture); + // Priority: file_created > file_overwritten + let expected = vec!["new.txt".to_string()]; + assert_eq!(actual, expected); + + let fixture = "No tags here"; + let actual = extract_modified_files_from_output(fixture); + let expected: Vec = vec![]; assert_eq!(actual, expected); } } diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index a4c859bcd7..0a031c4164 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -747,6 +747,12 @@ pub enum ConversationCommand { id: ConversationId, }, + /// Rewind a conversation to an earlier message. + Rewind { + /// Conversation ID to rewind. + id: Option, + }, + /// Show last assistant message. Show { /// Conversation ID. diff --git a/crates/forge_main/src/model.rs b/crates/forge_main/src/model.rs index 5ee93c0405..0ac79be6ac 100644 --- a/crates/forge_main/src/model.rs +++ b/crates/forge_main/src/model.rs @@ -143,6 +143,7 @@ impl ForgeCommandManager { | "suggest" | "s" | "clone" + | "rewind" | "conversation-rename" | "copy" | "workspace-sync" @@ -494,6 +495,15 @@ pub enum AppCommand { id: Option, }, + /// Rewind a conversation to an earlier message, discarding subsequent + /// messages. This can be triggered with the '/rewind' command. + #[strum(props(usage = "Rewind a conversation to an earlier message"))] + Rewind { + /// Conversation ID to rewind (optional — prompts interactively if + /// absent) + id: Option, + }, + /// Rename any conversation interactively. /// This can be triggered with the '/conversation-rename' command. #[strum(props(usage = "Rename a conversation interactively"))] @@ -733,6 +743,7 @@ impl AppCommand { AppCommand::CommitPreview => "commit-preview", AppCommand::Suggest { .. } => "suggest", AppCommand::Clone { .. } => "clone", + AppCommand::Rewind { .. } => "rewind", AppCommand::ConversationRename { .. } => "conversation-rename", AppCommand::Copy => "copy", AppCommand::WorkspaceSync => "workspace-sync", diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 7d2034f26d..c9c5b65a0b 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -19,6 +19,7 @@ use forge_config::ForgeConfig; use forge_display::MarkdownFormat; use forge_domain::{ AuthMethod, ChatResponseContent, ConsoleWriter, ContextMessage, Role, TitleFormat, UserCommand, + clean_user_prompt, }; use forge_fs::ForgeFS; use forge_select::{ForgeWidget, SelectRow}; @@ -117,6 +118,49 @@ pub struct UI A> { } impl A + Send + Sync> UI { + /// Resolves a conversation ID from an optional string ID, the current + /// active conversation, or by showing an interactive picker. + async fn resolve_conversation_id( + &mut self, + id: Option, + ) -> anyhow::Result> { + if let Some(id_str) = id { + return Ok(Some(ConversationId::parse(&id_str).map_err(|_| { + anyhow::anyhow!("Invalid conversation ID: {id_str}") + })?)); + } + + if let Some(cid) = self.state.conversation_id { + return Ok(Some(cid)); + } + + if let Some(cid) = self.cli.conversation_id { + return Ok(Some(cid)); + } + + // Show conversation picker + let conversations = self + .api + .get_conversations(Some(self.config.max_conversations)) + .await?; + + if conversations.is_empty() { + self.writeln_title(TitleFormat::error( + "No conversations found. Start a conversation first.", + ))?; + return Ok(None); + } + + let selected = ConversationSelector::select_conversation( + &conversations, + self.state.conversation_id, + None, + ) + .await?; + + Ok(selected.map(|conv| conv.id)) + } + /// Writes a line to the console output /// Takes anything that implements ToString trait fn writeln(&mut self, content: T) -> anyhow::Result<()> { @@ -917,6 +961,9 @@ impl A + Send + Sync> UI self.writeln_title(TitleFormat::info(format!("Resumed conversation: {id}")))?; // Interactive mode will be handled by the main loop } + ConversationCommand::Rewind { id } => { + self.on_slash_rewind(id.map(|i| i.to_string())).await?; + } ConversationCommand::Show { id, md } => { let conversation = self.validate_conversation_exists(&id).await?; @@ -2271,6 +2318,9 @@ impl A + Send + Sync> UI AppCommand::Clone { id } => { self.on_slash_clone(id).await?; } + AppCommand::Rewind { id } => { + self.on_slash_rewind(id).await?; + } AppCommand::ConversationRename { name } => { let args = if name.is_empty() { None @@ -2567,34 +2617,8 @@ impl A + Send + Sync> UI /// conversation is used; if no active conversation, an interactive picker /// is shown. async fn on_slash_clone(&mut self, id: Option) -> anyhow::Result<()> { - let target_id = if let Some(id_str) = id { - ConversationId::parse(&id_str) - .map_err(|_| anyhow::anyhow!("Invalid conversation ID: {id_str}"))? - } else { - // Show conversation picker - let conversations = self - .api - .get_conversations(Some(self.config.max_conversations)) - .await?; - - if conversations.is_empty() { - self.writeln_title(TitleFormat::error( - "No conversations found. Start a conversation first.", - ))?; - return Ok(()); - } - - let selected = ConversationSelector::select_conversation( - &conversations, - self.state.conversation_id, - None, - ) - .await?; - - match selected { - Some(conv) => conv.id, - None => return Ok(()), - } + let Some(target_id) = self.resolve_conversation_id(id).await? else { + return Ok(()); }; // Fetch the conversation to clone @@ -2622,6 +2646,160 @@ impl A + Send + Sync> UI Ok(()) } + /// Rewinds a conversation to an earlier message, discarding all subsequent + /// messages and their associated tool results and usage data. + /// + /// # Arguments + /// * `id` - Optional conversation ID. If `None`, an interactive picker is + /// shown. + async fn on_slash_rewind(&mut self, id: Option) -> anyhow::Result<()> { + let Some(target_id) = self.resolve_conversation_id(id).await? else { + return Ok(()); + }; + + // Fetch the conversation + let mut conversation = self + .api + .conversation(&target_id) + .await? + .ok_or_else(|| anyhow::anyhow!("Conversation '{target_id}' not found"))?; + + let context = conversation + .context + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Conversation has no context to rewind"))?; + + if context.messages.is_empty() { + self.writeln_title(TitleFormat::error( + "Conversation has no messages to rewind.", + ))?; + return Ok(()); + } + + // Build and show the interactive TUI message selector + let lines = context.format_messages_for_rewind(); + let user_count = lines.len(); + if user_count == 0 { + self.writeln_title(TitleFormat::error("No user messages to rewind to."))?; + return Ok(()); + } + + let last_full_idx = lines + .last() + .map(|(fi, _)| fi.to_string()) + .unwrap_or_default(); + let mut rows: Vec = Vec::with_capacity(lines.len()); + for (full_idx, display) in &lines { + rows.push( + SelectRow::new(full_idx.to_string(), display.clone()).search(display.clone()), + ); + } + + let selected = tokio::task::spawn_blocking(move || -> anyhow::Result> { + ForgeWidget::select_rows("Rewind to message", rows) + .initial_raw(last_full_idx) + .prompt() + }) + .await??; + + let full_idx = match selected { + Some(row) => row + .raw + .parse::() + .map_err(|_| anyhow::anyhow!("Invalid message index"))?, + None => { + self.writeln_title(TitleFormat::info("Rewind cancelled."))?; + return Ok(()); + } + }; + + // Find the 0-indexed user message position for the existing truncation method + let keep_nth_user = lines + .iter() + .position(|(fi, _)| *fi == full_idx) + .ok_or_else(|| anyhow::anyhow!("Selected message index {full_idx} not found"))?; + + let total_before = context.messages.len(); + + // Collect file paths that were modified by tool results in messages + // that will be removed. These files' snapshots need to be reverted. + // We use full_idx because it's the index of the user message we are rewinding + // AT. + let modified_files = context.modified_files_from(full_idx); + + // Perform the truncation (keep messages up to but excluding the selected user + // message) + let rewound_message_content = context + .messages + .get(full_idx) + .and_then(|m| m.content()) + .map(clean_user_prompt) + .unwrap_or_default(); + + let truncated_context = context.clone().truncate_to_user_message(keep_nth_user); + let removed = total_before - truncated_context.messages.len(); + let num_messages = truncated_context.messages.len(); + + conversation.context = Some(truncated_context); + conversation.metadata.updated_at = Some(chrono::Utc::now()); + + self.api.upsert_conversation(conversation).await?; + + // Revert file modifications from the removed messages (best-effort) + if !modified_files.is_empty() { + self.writeln_title(TitleFormat::info("Reverting file changes..."))?; + let mut failed = 0u32; + let total_undos = modified_files.len(); + for file_path in modified_files.iter().rev() { + match self.api.undo_snapshot(file_path).await { + Ok(_) => { + tracing::info!(file = %file_path, "Rewind reverted snapshot"); + } + Err(e) => { + tracing::warn!(file = %file_path, error = %e, "Failed to revert snapshot during rewind"); + failed += 1; + } + } + } + let status = if failed == 0 { + format!("Reverted {total_undos} file change(s).") + } else { + format!( + "Reverted {} of {total_undos} file change(s). {failed} failed.", + total_undos - failed as usize + ) + }; + self.writeln(status)?; + } + + let summary = if removed > 0 { + format!("Removed {removed} messages. Now has {num_messages} messages.") + } else { + format!("No messages removed. Still {total_before} messages.") + }; + + self.writeln_title( + TitleFormat::info("Rewound") + .sub_title(format!("[{}] {summary}", target_id.into_string())), + )?; + + // Set the rewound message content in the buffer for review/editing + if !rewound_message_content.is_empty() { + if self.cli.is_interactive() { + self.console.set_buffer(rewound_message_content); + } else { + // Check if we should write to a temporary file for the shell plugin + if let Ok(rewind_file) = std::env::var("FORGE_REWIND_FILE") { + let _ = std::fs::write(rewind_file, &rewound_message_content); + } + // Also print to stdout as fallback + println!("{rewound_message_content}"); + } + } + + Ok(()) + } + /// Renames any conversation interactively or by explicit ID and name. /// /// # Arguments diff --git a/crates/forge_repo/src/conversation/conversation_record.rs b/crates/forge_repo/src/conversation/conversation_record.rs index 7df99bf5a3..8be3654c8c 100644 --- a/crates/forge_repo/src/conversation/conversation_record.rs +++ b/crates/forge_repo/src/conversation/conversation_record.rs @@ -467,6 +467,7 @@ pub(super) struct ToolResultRecord { name: ToolNameRecord, call_id: Option, output: ToolOutputRecord, + modified_files: Vec, } impl From<&forge_domain::ToolResult> for ToolResultRecord { @@ -475,6 +476,7 @@ impl From<&forge_domain::ToolResult> for ToolResultRecord { name: ToolNameRecord::from(&result.name), call_id: result.call_id.as_ref().map(ToolCallIdRecord::from), output: ToolOutputRecord::from(&result.output), + modified_files: result.modified_files.clone(), } } } @@ -487,6 +489,7 @@ impl TryFrom for forge_domain::ToolResult { name: record.name.into(), call_id: record.call_id.map(Into::into), output: record.output.try_into()?, + modified_files: record.modified_files, }) } } diff --git a/crates/forge_repo/src/conversation/conversation_repo.rs b/crates/forge_repo/src/conversation/conversation_repo.rs index eeef25af71..3852611ebb 100644 --- a/crates/forge_repo/src/conversation/conversation_repo.rs +++ b/crates/forge_repo/src/conversation/conversation_repo.rs @@ -716,6 +716,7 @@ mod tests { is_error: false, values: vec![ToolValue::Text("Result text".to_string()), ToolValue::Empty], }, + modified_files: vec![], }) .into(), forge_domain::MessageEntry { diff --git a/crates/forge_repo/src/provider/anthropic.rs b/crates/forge_repo/src/provider/anthropic.rs index c9449d1bca..ae293519be 100644 --- a/crates/forge_repo/src/provider/anthropic.rs +++ b/crates/forge_repo/src/provider/anthropic.rs @@ -576,6 +576,7 @@ mod tests { name: ToolName::new("math"), call_id: Some(ToolCallId::new("math-1")), output: ToolOutput::text(serde_json::json!({"result": 4}).to_string()), + modified_files: vec![], }]) .tool_choice(ToolChoice::Call(ToolName::new("math"))); let request = Request::try_from(context) diff --git a/crates/forge_repo/src/provider/google.rs b/crates/forge_repo/src/provider/google.rs index c70b2b28c8..a4dbe77415 100644 --- a/crates/forge_repo/src/provider/google.rs +++ b/crates/forge_repo/src/provider/google.rs @@ -396,6 +396,7 @@ mod tests { name: ToolName::new("math"), call_id: Some(ToolCallId::new("math-1")), output: ToolOutput::text(serde_json::json!({"result": 4}).to_string()), + modified_files: vec![], }]) .tool_choice(ToolChoice::Call(ToolName::new("math"))); diff --git a/crates/forge_services/src/tool_services/fs_write.rs b/crates/forge_services/src/tool_services/fs_write.rs index 976cbb750b..3b63628f52 100644 --- a/crates/forge_services/src/tool_services/fs_write.rs +++ b/crates/forge_services/src/tool_services/fs_write.rs @@ -95,10 +95,8 @@ impl< (None, default_ending) }; - // SNAPSHOT COORDINATION: Capture snapshot before writing if file exists - if file_exists { - self.infra.insert_snapshot(path).await?; - } + // SNAPSHOT COORDINATION: Capture snapshot before writing + self.infra.insert_snapshot(path).await?; // Normalize line endings to match the target style before writing let normalized_content = content diff --git a/crates/forge_services/src/tool_services/plan_create.rs b/crates/forge_services/src/tool_services/plan_create.rs index 81b9077c00..d9259f7f85 100644 --- a/crates/forge_services/src/tool_services/plan_create.rs +++ b/crates/forge_services/src/tool_services/plan_create.rs @@ -7,6 +7,7 @@ use forge_app::{ EnvironmentInfra, FileDirectoryInfra, FileInfoInfra, FileReaderInfra, FileWriterInfra, PlanCreateOutput, PlanCreateService, }; +use forge_domain::SnapshotRepository; /// Creates a new plan file with the specified name, version, and content. Use /// this tool to create structured project plans, task breakdowns, or @@ -27,6 +28,7 @@ impl< + FileReaderInfra + FileWriterInfra + EnvironmentInfra + + SnapshotRepository + Send + Sync, > PlanCreateService for ForgePlanCreate @@ -65,6 +67,9 @@ impl< )); } + // SNAPSHOT COORDINATION: Capture snapshot before writing + self.0.insert_snapshot(&file_path).await?; + // Write the plan file self.0 .write(&file_path, Bytes::from(content)) diff --git a/crates/forge_snaps/src/service.rs b/crates/forge_snaps/src/service.rs index 4fa11b53dc..c27b426342 100644 --- a/crates/forge_snaps/src/service.rs +++ b/crates/forge_snaps/src/service.rs @@ -28,9 +28,14 @@ impl SnapshotService { ForgeFS::create_dir_all(parent).await?; } - let content = ForgeFS::read(&snapshot.path).await?; - let path = snapshot.snapshot_path(Some(self.snapshots_directory.clone())); - ForgeFS::write(path, content).await?; + if ForgeFS::exists(&snapshot.path) { + let content = ForgeFS::read(&snapshot.path).await?; + ForgeFS::write(snapshot_path, content).await?; + } else { + // Write a special marker for "non-existent file" + let marker_path = snapshot_path.with_extension("none"); + ForgeFS::write(marker_path, b"").await?; + } Ok(snapshot) } @@ -43,7 +48,8 @@ impl SnapshotService { while let Some(entry) = dir.next_entry().await? { let filename = entry.file_name().to_string_lossy().to_string(); - if filename.ends_with(".snap") + // Match both .snap (regular) and .none (file didn't exist) + if (filename.ends_with(".snap") || filename.ends_with(".none")) && (latest_filename.is_none() || filename > latest_filename.clone().unwrap()) { latest_filename = Some(filename); @@ -70,9 +76,16 @@ impl SnapshotService { .await? .context(format!("No valid snapshots found for {path:?}"))?; - // Restore the content - let content = ForgeFS::read(&snapshot_path).await?; - ForgeFS::write(&path, content).await?; + if snapshot_path.extension().and_then(|e| e.to_str()) == Some("none") { + // The file didn't exist when snapshot was taken, so delete it if it exists now + if ForgeFS::exists(&path) { + ForgeFS::remove_file(&path).await?; + } + } else { + // Restore the content + let content = ForgeFS::read(&snapshot_path).await?; + ForgeFS::write(&path, content).await?; + } // Remove the used snapshot ForgeFS::remove_file(&snapshot_path).await?; @@ -208,6 +221,29 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_undo_file_creation() -> Result<()> { + // Arrange + let ctx = TestContext::new().await?; + let content = "New file content"; + + // Act + // 1. Snapshot non-existent file + ctx.service.create_snapshot(ctx.test_file.clone()).await?; + + // 2. Create the file + ctx.write_content(content).await?; + assert!(ForgeFS::exists(&ctx.test_file)); + + // 3. Undo creation + ctx.undo_snapshot().await?; + + // Assert + assert!(!ForgeFS::exists(&ctx.test_file)); + + Ok(()) + } + #[tokio::test] async fn test_multiple_snapshots() -> Result<()> { // Arrange diff --git a/shell-plugin/lib/actions/conversation.zsh b/shell-plugin/lib/actions/conversation.zsh index 4a31c8bbf1..e48483389a 100644 --- a/shell-plugin/lib/actions/conversation.zsh +++ b/shell-plugin/lib/actions/conversation.zsh @@ -235,17 +235,70 @@ function _forge_action_conversation_rename() { fi } +# Action handler: Rewind conversation +# Usage: :rewind [] +function _forge_action_rewind() { + local input_text="$1" + local forge_dir=".forge" + + # Create .forge directory if it doesn't exist + if [[ ! -d "$forge_dir" ]]; then + mkdir -p "$forge_dir" || return 1 + fi + + local rewind_file="${forge_dir}/FORGE_REWIND_MSG" + rm -f "$rewind_file" + + # Use _forge_exec_interactive to allow the Rust TUI message picker + # We export FORGE_REWIND_FILE so the rust process knows where to write the content + local -x FORGE_REWIND_FILE="$rewind_file" + + local target_id="$input_text" + if [[ -z "$target_id" ]]; then + target_id="$_FORGE_CONVERSATION_ID" + fi + + _forge_exec_interactive conversation rewind $target_id + + # Check if a message was rewound and we have content + if [[ -f "$rewind_file" ]]; then + local content + content=$(cat "$rewind_file" | tr -d '\r') + if [[ -n "$content" ]]; then + # Pre-populate buffer with the rewound message + # Prefix with : to indicate it's a forge command if the user wants to resubmit + # Actually, if it's a regular message, they might want to just type it. + # But usually in zsh mode they are typing : + BUFFER=": $content" + CURSOR=${#BUFFER} + fi + rm -f "$rewind_file" + fi + + zle reset-prompt +} + # Helper function to clone and switch to conversation function _forge_clone_and_switch() { local clone_target="$1" + local target_id="$clone_target" + if [[ -z "$target_id" ]]; then + target_id="$_FORGE_CONVERSATION_ID" + fi + # Store original conversation ID to check if we're cloning current conversation local original_conversation_id="$_FORGE_CONVERSATION_ID" # Execute clone command - _forge_log info "Cloning conversation \033[1m${clone_target}\033[0m" + if [[ -n "$target_id" ]]; then + _forge_log info "Cloning conversation \033[1m${target_id}\033[0m" + else + _forge_log info "Cloning conversation" + fi + local clone_output - clone_output=$($_FORGE_BIN conversation clone "$clone_target" 2>&1) + clone_output=$($_FORGE_BIN conversation clone "$target_id" 2>&1) local clone_exit_code=$? if [[ $clone_exit_code -eq 0 ]]; then diff --git a/shell-plugin/lib/dispatcher.zsh b/shell-plugin/lib/dispatcher.zsh index c5f13db0e0..9eee4bc00f 100644 --- a/shell-plugin/lib/dispatcher.zsh +++ b/shell-plugin/lib/dispatcher.zsh @@ -241,6 +241,14 @@ function forge-accept-line() { conversation-rename) _forge_action_conversation_rename "$input_text" ;; + rewind) + _forge_action_rewind "$input_text" + local action_status=$? + _forge_osc133_emit "D;$action_status" + _forge_osc133_emit "A" + # Note: rewind action intentionally modifies BUFFER and handles its own prompt reset + return $action_status + ;; copy) _forge_action_copy ;;