diff --git a/Cargo.lock b/Cargo.lock index cc5cdae..52b6bbe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,6 +7,7 @@ name = "OmegaCode" version = "0.1.0-alpha" dependencies = [ "anyhow", + "async-trait", "bytes", "chrono", "clap", @@ -22,6 +23,7 @@ dependencies = [ "log", "md5", "once_cell", + "rand 0.10.1", "ratatui 0.30.0", "ratatui-kit", "reqwest", @@ -35,6 +37,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-stream", + "tracing", "zip", ] @@ -342,6 +345,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.1", +] + [[package]] name = "chrono" version = "0.4.44" @@ -1116,6 +1130,7 @@ dependencies = [ "js-sys", "libc", "r-efi 6.0.0", + "rand_core 0.10.1", "wasip2", "wasip3", "wasm-bindgen", @@ -1840,6 +1855,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2286,6 +2311,17 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "chacha20", + "getrandom 0.4.2", + "rand_core 0.10.1", +] + [[package]] name = "rand_chacha" version = "0.9.0" @@ -2311,6 +2347,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" + [[package]] name = "ratatui" version = "0.29.0" @@ -2526,6 +2568,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "percent-encoding", "pin-project-lite", "quinn", @@ -2534,6 +2577,7 @@ dependencies = [ "rustls-platform-verifier", "serde", "serde_json", + "serde_urlencoded", "sync_wrapper", "tokio", "tokio-rustls", @@ -2943,6 +2987,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_yaml" version = "0.9.34+deprecated" @@ -3624,6 +3680,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-ident" version = "1.0.24" diff --git a/Cargo.toml b/Cargo.toml index 3eab6d6..a25a366 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ anyhow = "1.0.102" tokio = { version = "1.52.3", features = ["full"] } colored = "3.1.1" indicatif = "0.18.4" -reqwest = { version = "0.13.3", features = ["stream", "json"] } +reqwest = { version = "0.13.3", features = ["stream", "json", "multipart", "query"] } tokio-stream = "0.1.18" futures = "0.3.32" futures-util = "0.3.32" @@ -40,4 +40,7 @@ tar = "0.4.46" thiserror = "2.0.18" hex = "0.4.3" dotenv = "0.15.0" -bytes = "1.11.1" \ No newline at end of file +bytes = "1.11.1" +rand = "0.10.1" +tracing = "0.1.44" +async-trait = "0.1.89" \ No newline at end of file diff --git a/src/core/chat/common/call.rs b/src/core/chat/common/call.rs new file mode 100644 index 0000000..39f5b99 --- /dev/null +++ b/src/core/chat/common/call.rs @@ -0,0 +1,355 @@ +// ============================================================ +// ai/call/mod.rs +// ============================================================ + +use anyhow::Result; +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::stream::BoxStream; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +// ============================================================ +// Role +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, + Tool, +} + +// ============================================================ +// Message +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + + pub content: Vec, +} + +// ============================================================ +// Content Part +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { + text: String, + }, + + #[serde(rename = "image")] + Image { + url: String, + }, + + #[serde(rename = "thinking")] + Thinking { + text: String, + }, + + #[serde(rename = "tool_call")] + ToolCall { + id: String, + + name: String, + + arguments: Value, + }, + + #[serde(rename = "tool_result")] + ToolResult { + tool_call_id: String, + + content: String, + }, +} + +// ============================================================ +// Tool +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolDefinition { + pub name: String, + + pub description: Option, + + pub parameters: Value, +} + +// ============================================================ +// Reasoning Config +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReasoningConfig { + pub enabled: bool, + + pub budget_tokens: Option, +} + +// ============================================================ +// Request +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmRequest { + pub model: String, + + pub messages: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + #[serde(default)] + pub stream: bool, + + #[serde(default)] + pub tools: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + + #[serde(default)] + pub metadata: Value, +} + +// ============================================================ +// Usage +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenUsage { + pub prompt_tokens: u32, + + pub completion_tokens: u32, + + pub total_tokens: u32, + + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, +} + +// ============================================================ +// Response +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmResponse { + pub model: String, + + pub message: Message, + + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + + #[serde(default)] + pub metadata: Value, +} + +// ============================================================ +// Tool Call Event +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallEvent { + pub id: String, + + pub name: String, + + pub arguments: Value, +} + +// ============================================================ +// Tool Result Event +// ============================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResultEvent { + pub tool_call_id: String, + + pub content: String, +} + +// ============================================================ +// Stream Event +// ============================================================ + +#[derive(Debug, Clone)] +pub enum LlmEvent { + Text(String), + + Reasoning(String), + + ToolCall(ToolCallEvent), + + ToolResult(ToolResultEvent), + + Usage(TokenUsage), + + Metadata(Value), + + Binary(Bytes), + + Done, +} + +// ============================================================ +// Stream Type +// ============================================================ + +pub type LlmStream = +BoxStream<'static, Result>; + +// ============================================================ +// Provider Trait +// ============================================================ + +#[async_trait] +pub trait LlmProvider: +Send + Sync + 'static +{ + /// provider name + fn name(&self) -> &'static str; + + /// supported models + async fn models(&self) -> Result>; + + /// supports reasoning + fn supports_reasoning(&self) -> bool { + false + } + + /// supports tools + fn supports_tools(&self) -> bool { + false + } + + /// supports stream + fn supports_stream(&self) -> bool { + true + } + + /// normal call + async fn call( + &self, + request: LlmRequest, + ) -> Result; + + /// stream call + async fn stream( + &self, + request: LlmRequest, + ) -> Result; +} + +// ============================================================ +// Provider Registry +// ============================================================ + +use std::collections::HashMap; +use std::sync::Arc; + +pub struct ProviderRegistry { + providers: + HashMap>, +} + +impl ProviderRegistry { + pub fn new() -> Self { + Self { + providers: HashMap::new(), + } + } + + pub fn register

( + mut self, + provider: P, + ) -> Self + where + P: LlmProvider, + { + self.providers.insert( + provider.name().to_string(), + Arc::new(provider), + ); + + self + } + + pub fn get( + &self, + name: &str, + ) -> Option> { + self.providers.get(name).cloned() + } +} + +// ============================================================ +// Call Manager +// ============================================================ + +pub struct CallManager { + registry: ProviderRegistry, +} + +impl CallManager { + pub fn new( + registry: ProviderRegistry, + ) -> Self { + Self { registry } + } + + pub async fn call( + &self, + provider: &str, + request: LlmRequest, + ) -> Result { + let provider = self + .registry + .get(provider) + .ok_or_else(|| { + anyhow::anyhow!( + "provider not found: {}", + provider + ) + })?; + + provider.call(request).await + } + + pub async fn stream( + &self, + provider: &str, + request: LlmRequest, + ) -> Result { + let provider = self + .registry + .get(provider) + .ok_or_else(|| { + anyhow::anyhow!( + "provider not found: {}", + provider + ) + })?; + + provider.stream(request).await + } +} \ No newline at end of file diff --git a/src/core/chat/common/mod.rs b/src/core/chat/common/mod.rs new file mode 100644 index 0000000..7d2e8c6 --- /dev/null +++ b/src/core/chat/common/mod.rs @@ -0,0 +1,2 @@ +mod call; +mod request; \ No newline at end of file diff --git a/src/core/chat/common/request.rs b/src/core/chat/common/request.rs new file mode 100644 index 0000000..620c499 --- /dev/null +++ b/src/core/chat/common/request.rs @@ -0,0 +1,654 @@ +use anyhow::{anyhow, Result}; +use bytes::Bytes; +use futures_util::stream::Stream; +use rand::{Rng, RngExt}; +use reqwest::{ + multipart, + Client, + Method, + Request, + RequestBuilder, + Response, + StatusCode, +}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; +use std::time::Duration; +use tokio::fs; +use tracing::{info, warn}; + +// ============================================================ +// Retry Config +// ============================================================ + +#[derive(Debug, Clone)] +pub struct RetryConfig { + pub max_retries: u32, + + pub retry_interval: u64, + + pub use_exponential_backoff: bool, + + pub backoff_factor: f64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + retry_interval: 1000, + use_exponential_backoff: true, + backoff_factor: 2.0, + } + } +} + +// ============================================================ +// Request Context +// ============================================================ + +#[derive(Debug, Clone)] +pub struct RequestContext { + pub method: Method, + pub url: String, +} + +// ============================================================ +// Request Result +// ============================================================ + +pub enum RequestResult { + Response(Response), + Error(reqwest::Error), +} + +// ============================================================ +// Retry Predicate +// ============================================================ + +pub type RetryPredicate = +Box< + dyn Fn(&RequestContext, &RequestResult) -> bool + + Send + + Sync + + 'static, +>; + +// ============================================================ +// Request Factory +// ============================================================ + +type BoxFuture = +Pin + Send>>; + +pub type RequestFactory = +Box< + dyn Fn() -> BoxFuture> + + Send + + Sync, +>; + +// ============================================================ +// Retry Middleware +// ============================================================ + +pub struct RetryMiddleware { + pub config: RetryConfig, + + pub predicate: Option, +} + +impl Default for RetryMiddleware { + fn default() -> Self { + Self { + config: RetryConfig::default(), + predicate: None, + } + } +} + +impl RetryMiddleware { + pub fn new(config: RetryConfig) -> Self { + Self { + config, + predicate: None, + } + } + + pub fn with_retry_predicate( + mut self, + f: F, + ) -> Self + where + F: Fn(&RequestContext, &RequestResult) -> bool + + Send + + Sync + + 'static, + { + self.predicate = Some(Box::new(f)); + + self + } + + // ======================================================== + // Full Jitter Backoff + // ======================================================== + + fn calculate_wait_time( + &self, + attempts: u32, + ) -> Duration { + let base = self.config.retry_interval as f64; + + let max_wait = if self + .config + .use_exponential_backoff + { + base + * self + .config + .backoff_factor + .powi(attempts as i32) + } else { + base + }; + + let mut rng = rand::rng(); + + let wait_ms = + rng.random_range(0.0..=max_wait); + + Duration::from_millis(wait_ms as u64) + } + + // ======================================================== + // Retry-After + // ======================================================== + + fn retry_after( + response: &Response, + ) -> Option { + let header = response + .headers() + .get("retry-after")?; + + let value = header.to_str().ok()?; + + let secs = value.parse::().ok()?; + + Some(Duration::from_secs(secs)) + } + + // ======================================================== + // Default Retry Rule + // ======================================================== + + fn should_retry( + &self, + ctx: &RequestContext, + result: &RequestResult, + ) -> bool { + if let Some(predicate) = &self.predicate { + return predicate(ctx, result); + } + + // 默认只 retry 幂等请求 + let retryable_method = matches!( + ctx.method, + Method::GET + | Method::HEAD + | Method::PUT + | Method::DELETE + ); + + if !retryable_method { + return false; + } + + match result { + RequestResult::Response(resp) => { + resp.status().is_server_error() + || resp.status() + == StatusCode::TOO_MANY_REQUESTS + } + + RequestResult::Error(err) => { + err.is_timeout() + || err.is_connect() + || err + .status() + .map(|s| s.is_server_error()) + .unwrap_or(false) + } + } + } + + // ======================================================== + // Execute + // ======================================================== + + pub async fn execute( + &self, + client: &Client, + factory: RequestFactory, + ) -> Result { + let mut attempt = 0; + + loop { + let (ctx, request) = factory().await?; + + info!( + method = %ctx.method, + url = %ctx.url, + attempt, + "sending request" + ); + + let result = + match client.execute(request).await { + Ok(resp) => { + RequestResult::Response(resp) + } + + Err(err) => { + RequestResult::Error(err) + } + }; + + let should_retry = + self.should_retry(&ctx, &result); + + if !should_retry + || attempt >= self.config.max_retries + { + return match result { + RequestResult::Response(resp) => { + Ok(resp) + } + + RequestResult::Error(err) => { + Err(anyhow!(err)) + } + }; + } + + attempt += 1; + + let wait = match &result { + RequestResult::Response(resp) => { + Self::retry_after(resp) + .unwrap_or_else(|| { + self.calculate_wait_time( + attempt, + ) + }) + } + + _ => self.calculate_wait_time(attempt), + }; + + warn!( + method = %ctx.method, + url = %ctx.url, + attempt, + wait_ms = wait.as_millis(), + "retrying request" + ); + + tokio::time::sleep(wait).await; + } + } +} + +// ============================================================ +// Http Client +// ============================================================ + +pub struct HttpClient { + client: Client, + + retry_middleware: RetryMiddleware, +} + +impl HttpClient { + pub fn new( + retry_config: Option, + ) -> Result { + let client = Client::builder() + .connect_timeout(Duration::from_secs( + 10, + )) + .timeout(Duration::from_secs(300)) + .pool_idle_timeout(Duration::from_secs( + 90, + )) + .tcp_keepalive(Duration::from_secs(60)) + .build()?; + + Ok(Self { + client, + + retry_middleware: retry_config + .map(RetryMiddleware::new) + .unwrap_or_default(), + }) + } + + // ======================================================== + // Core Send + // ======================================================== + + pub async fn send( + &self, + builder: RequestBuilder, + ) -> Result { + let builder = builder + .try_clone() + .ok_or_else(|| { + anyhow!( + "request builder cannot be cloned" + ) + })?; + + let factory: RequestFactory = + Box::new(move || { + let builder = builder + .try_clone() + .ok_or_else(|| { + anyhow!( + "request builder clone failed" + ) + }); + + Box::pin(async move { + let builder = builder?; + + let request = builder.build()?; + + let ctx = RequestContext { + method: request.method().clone(), + + url: request.url().to_string(), + }; + + Ok((ctx, request)) + }) + }); + + self.retry_middleware + .execute(&self.client, factory) + .await + } + + // ======================================================== + // GET JSON + // ======================================================== + + pub async fn get_json( + &self, + url: &str, + query: &impl Serialize, + ) -> Result + where + T: DeserializeOwned, + { + let response = self + .send(self.client.get(url).query(query)) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // POST JSON + // ======================================================== + + pub async fn post_json( + &self, + url: &str, + body: &impl Serialize, + ) -> Result + where + T: DeserializeOwned, + { + let response = self + .send(self.client.post(url).json(body)) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // GET TEXT + // ======================================================== + + pub async fn get_text( + &self, + url: &str, + ) -> Result { + let response = self + .send(self.client.get(url)) + .await?; + + Ok(response.text().await?) + } + + // ======================================================== + // GET BYTES + // ======================================================== + + pub async fn get_bytes( + &self, + url: &str, + ) -> Result { + let response = self + .send(self.client.get(url)) + .await?; + + Ok(response.bytes().await?) + } + + // ======================================================== + // SSE / STREAM + // ======================================================== + + pub async fn get_stream( + &self, + url: &str, + ) -> Result< + impl Stream< + Item = Result, + >, + > { + let response = self + .send(self.client.get(url)) + .await?; + + Ok(response.bytes_stream()) + } + + // ======================================================== + // Multipart Upload + // ======================================================== + + pub async fn upload_file( + &self, + url: &str, + field_name: &str, + file_path: impl Into, + ) -> Result + where + T: DeserializeOwned, + { + let path = file_path.into(); + + let client = self.client.clone(); + + let url = url.to_string(); + + let field_name = field_name.to_string(); + + let factory: RequestFactory = + Box::new(move || { + let client = client.clone(); + + let path = path.clone(); + + let url = url.clone(); + + let field_name = + field_name.clone(); + + Box::pin(async move { + let bytes = + fs::read(&path).await?; + + let filename = path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = + multipart::Part::bytes( + bytes, + ) + .file_name(filename); + + let form = + multipart::Form::new() + .part( + field_name, + part, + ); + + let request = client + .post(url) + .multipart(form) + .build()?; + + let ctx = RequestContext { + method: request + .method() + .clone(), + + url: request + .url() + .to_string(), + }; + + Ok((ctx, request)) + }) + }); + + let response = self + .retry_middleware + .execute(&self.client, factory) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // Stream Upload + // ======================================================== + + pub async fn upload_stream( + &self, + url: &str, + stream_factory: impl Fn() -> S + + Send + + Sync + + 'static, + ) -> Result + where + T: DeserializeOwned, + + S: Stream< + Item = Result< + Bytes, + std::io::Error, + >, + > + Send + + 'static, + { + let client = self.client.clone(); + + let url = url.to_string(); + + let factory: RequestFactory = + Box::new(move || { + let client = client.clone(); + + let url = url.clone(); + + let stream = stream_factory(); + + Box::pin(async move { + let body = + reqwest::Body::wrap_stream( + stream, + ); + + let request = client + .post(url) + .body(body) + .build()?; + + let ctx = RequestContext { + method: request + .method() + .clone(), + + url: request + .url() + .to_string(), + }; + + Ok((ctx, request)) + }) + }); + + let response = self + .retry_middleware + .execute(&self.client, factory) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // Retry Predicate + // ======================================================== + + pub fn with_retry_predicate( + mut self, + predicate: F, + ) -> Self + where + F: Fn( + &RequestContext, + &RequestResult, + ) -> bool + + Send + + Sync + + 'static, + { + self.retry_middleware = self + .retry_middleware + .with_retry_predicate(predicate); + + self + } + + // ======================================================== + // Inner Client + // ======================================================== + + pub fn inner(&self) -> &Client { + &self.client + } +} \ No newline at end of file diff --git a/src/core/chat/mod.rs b/src/core/chat/mod.rs index 0278328..18f4556 100644 --- a/src/core/chat/mod.rs +++ b/src/core/chat/mod.rs @@ -1,2 +1,3 @@ mod client; -mod provider; \ No newline at end of file +mod provider; +mod common; \ No newline at end of file diff --git a/src/core/chat/provider/common/mod.rs b/src/core/chat/provider/common/mod.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/core/chat/provider/mod.rs b/src/core/chat/provider/mod.rs index 556dda3..199127c 100644 --- a/src/core/chat/provider/mod.rs +++ b/src/core/chat/provider/mod.rs @@ -1,4 +1,3 @@ mod anthropic; -mod common; mod openai; mod deepseek; \ No newline at end of file