From 9869aad322d754da120f21fd4e79c9c233fbf15f Mon Sep 17 00:00:00 2001 From: Gabriel Date: Mon, 1 Jun 2026 17:44:50 +0800 Subject: [PATCH] feat(provider): provider support client --- Cargo.lock | 1 + Cargo.toml | 1 + src/core/client/ai_client.rs | 109 +++++++++++++ src/core/providers/gemini/mod.rs | 2 + src/core/providers/gemini/provider.rs | 72 +++++++++ src/core/providers/ollama/mod.rs | 2 + src/core/providers/ollama/provider.rs | 62 ++++++++ src/core/transport/http_transoprt.rs | 27 ++++ src/core/transport/sse.rs | 221 ++++++++++++++++++++++++++ 9 files changed, 497 insertions(+) create mode 100644 src/core/client/ai_client.rs create mode 100644 src/core/providers/gemini/mod.rs create mode 100644 src/core/providers/gemini/provider.rs create mode 100644 src/core/providers/ollama/mod.rs create mode 100644 src/core/providers/ollama/provider.rs create mode 100644 src/core/transport/sse.rs diff --git a/Cargo.lock b/Cargo.lock index c657879..1c4721d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -33,6 +33,7 @@ dependencies = [ "rust-i18n", "serde", "serde_json", + "serde_urlencoded", "sha2 0.11.0", "tar", "thiserror 2.0.18", diff --git a/Cargo.toml b/Cargo.toml index 46a4097..5bef96a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ futures = "0.3.32" futures-util = "0.3.32" serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.150" +serde_urlencoded = "0.7.1" rmcp = "1.7.0" rusqlite = { version = "0.40.0", features = ["bundled"] } ratatui = "0.30.0" diff --git a/src/core/client/ai_client.rs b/src/core/client/ai_client.rs new file mode 100644 index 0000000..daf2e6c --- /dev/null +++ b/src/core/client/ai_client.rs @@ -0,0 +1,109 @@ +use anyhow::{anyhow, Result}; +use bytes::Bytes; +use futures_util::Stream; +use reqwest::Method; +use serde::{de::DeserializeOwned, Serialize}; +use std::collections::HashMap; +use std::pin::Pin; + +use crate::core::client::base_client::BaseClient; +use crate::core::client::config::ClientConfig; +use crate::core::transport::http_transoprt::HttpTransport; +use crate::core::transport::retry::RetryPolicy; +use crate::core::transport::TransportRequest; + +#[derive(Clone)] +pub struct AiClient { + transport: HttpTransport, +} + +impl AiClient { + pub fn new(config: ClientConfig, retry: RetryPolicy) -> Result { + let base_client = BaseClient::new(config)?; + let transport = HttpTransport::new(base_client, retry); + + Ok(Self { transport }) + } + + pub async fn post_json( + &self, + url: &str, + headers: HashMap, + body: &impl Serialize, + ) -> Result + where + T: DeserializeOwned, + { + let body_bytes = serde_json::to_vec(body) + .map_err(|e| anyhow!("Failed to serialize request body: {}", e))?; + + let request = TransportRequest { + method: Method::POST, + url: url.to_string(), + headers, + body: Some(Bytes::from(body_bytes)), + }; + + let response = self.transport.execute(request).await?; + + serde_json::from_slice(&response.body) + .map_err(|e| anyhow!("Failed to deserialize response: {}", e)) + } + + pub async fn get_json( + &self, + url: &str, + headers: HashMap, + query: &impl Serialize, + ) -> Result + where + T: DeserializeOwned, + { + let query_string = serde_urlencoded::to_string(query) + .map_err(|e| anyhow!("Failed to serialize query: {}", e))?; + + let url_with_query = if query_string.is_empty() { + url.to_string() + } else { + format!("{}?{}", url, query_string) + }; + + let request = TransportRequest { + method: Method::GET, + url: url_with_query, + headers, + body: None, + }; + + let response = self.transport.execute(request).await?; + + serde_json::from_slice(&response.body) + .map_err(|e| anyhow!("Failed to deserialize response: {}", e)) + } + + pub async fn post_stream( + &self, + url: &str, + headers: HashMap, + body: &impl Serialize, + ) -> Result> + Send + 'static>>> + where + T: DeserializeOwned + Send + 'static, + { + let body_bytes = serde_json::to_vec(body) + .map_err(|e| anyhow!("Failed to serialize request body: {}", e))?; + + let request = TransportRequest { + method: Method::POST, + url: url.to_string(), + headers, + body: Some(Bytes::from(body_bytes)), + }; + + self.transport.execute_stream(request).await.map_err(|e| anyhow!(e)) + } + + pub fn transport(&self) -> &HttpTransport { + &self.transport + } +} \ No newline at end of file diff --git a/src/core/providers/gemini/mod.rs b/src/core/providers/gemini/mod.rs new file mode 100644 index 0000000..8b9cf10 --- /dev/null +++ b/src/core/providers/gemini/mod.rs @@ -0,0 +1,2 @@ +pub mod provider; +pub use provider::GeminiProvider; \ No newline at end of file diff --git a/src/core/providers/gemini/provider.rs b/src/core/providers/gemini/provider.rs new file mode 100644 index 0000000..46e4ecf --- /dev/null +++ b/src/core/providers/gemini/provider.rs @@ -0,0 +1,72 @@ +use anyhow::Result; +use async_trait::async_trait; +use std::collections::HashMap; +use crate::core::client::{AiClient, ClientConfig}; +use crate::core::transport::retry::RetryPolicy; +use crate::core::providers::provider::Provider; +use crate::core::providers::types::{CompletionRequest, CompletionResponse}; + +pub struct GeminiProvider { + client: AiClient, + api_key: String, + model: String, +} + +impl GeminiProvider { + pub fn new(api_key: String, model: String) -> Result { + let config = ClientConfig::default(); + let retry = RetryPolicy::default(); + let client = AiClient::new(config, retry)?; + + Ok(Self { + client, + api_key, + model, + }) + } +} + +#[async_trait] +impl Provider for GeminiProvider { + fn client(&self) -> &AiClient { + &self.client + } + + async fn complete(&self, request: CompletionRequest) -> Result { + let mut headers = HashMap::new(); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + let url = format!("https://generativelanguage.googleapis.com/v1/models/{model}:generateContent?key={api_key}", + model = self.model, + api_key = self.api_key + ); + + let response: serde_json::Value = self.client + .post_json(&url, headers, &serde_json::json!({ + "contents": request.messages.iter().map(|m| { + serde_json::json!({ + "role": match m.role { + crate::core::providers::types::Role::System => "system", + crate::core::providers::types::Role::User => "user", + crate::core::providers::types::Role::Assistant => "model", + }, + "parts": [serde_json::json!({"text": m.content})] + }) + }).collect::>(), + "generationConfig": { + "temperature": request.temperature.unwrap_or(0.7), + "maxOutputTokens": request.max_tokens, + } + })) + .await + .map_err(|e| anyhow::anyhow!("Gemini API error: {}", e))?; + + let content = response["candidates"][0]["content"]["parts"][0]["text"].as_str().unwrap_or("").to_string(); + + Ok(CompletionResponse { + content, + finish_reason: None, + usage: None, + }) + } +} \ No newline at end of file diff --git a/src/core/providers/ollama/mod.rs b/src/core/providers/ollama/mod.rs new file mode 100644 index 0000000..ff9ae19 --- /dev/null +++ b/src/core/providers/ollama/mod.rs @@ -0,0 +1,2 @@ +pub mod provider; +pub use provider::OllamaProvider; \ No newline at end of file diff --git a/src/core/providers/ollama/provider.rs b/src/core/providers/ollama/provider.rs new file mode 100644 index 0000000..9032181 --- /dev/null +++ b/src/core/providers/ollama/provider.rs @@ -0,0 +1,62 @@ +use anyhow::Result; +use async_trait::async_trait; +use std::collections::HashMap; +use crate::core::client::{AiClient, ClientConfig}; +use crate::core::transport::retry::RetryPolicy; +use crate::core::providers::provider::Provider; +use crate::core::providers::types::{CompletionRequest, CompletionResponse}; + +pub struct OllamaProvider { + client: AiClient, + base_url: String, +} + +impl OllamaProvider { + pub fn new(base_url: String) -> Result { + let config = ClientConfig::default(); + let retry = RetryPolicy::default(); + let client = AiClient::new(config, retry)?; + + Ok(Self { + client, + base_url, + }) + } +} + +#[async_trait] +impl Provider for OllamaProvider { + fn client(&self) -> &AiClient { + &self.client + } + + async fn complete(&self, request: CompletionRequest) -> Result { + let mut headers = HashMap::new(); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + let url = format!("{}/api/chat", self.base_url); + + let response: serde_json::Value = self.client + .post_json(&url, headers, &serde_json::json!({ + "model": request.model, + "messages": request.messages, + "stream": request.stream, + "temperature": request.temperature, + "max_tokens": request.max_tokens, + })) + .await + .map_err(|e| anyhow::anyhow!("Ollama API error: {}", e))?; + + let content = response["message"]["content"].as_str().unwrap_or("").to_string(); + let finish_reason = match response["done"].as_bool() { + Some(true) => Some(crate::core::providers::types::FinishReason::Stop), + _ => None, + }; + + Ok(CompletionResponse { + content, + finish_reason, + usage: None, + }) + } +} \ No newline at end of file diff --git a/src/core/transport/http_transoprt.rs b/src/core/transport/http_transoprt.rs index 116136b..ebd1fa5 100644 --- a/src/core/transport/http_transoprt.rs +++ b/src/core/transport/http_transoprt.rs @@ -1,11 +1,16 @@ use std::time::Duration; +use std::pin::Pin; +use futures_util::Stream; use tokio::time::sleep; use crate::core::transport::errors::TransportError; use crate::core::transport::request::TransportRequest; use crate::core::transport::response::TransportResponse; use crate::core::transport::retry::RetryPolicy; +use crate::core::transport::sse::{SseStream, sse_stream_to_json}; use crate::core::client::base_client::BaseClient; +use serde::de::DeserializeOwned; + #[derive(Clone)] pub struct HttpTransport { client: BaseClient, @@ -100,4 +105,26 @@ impl HttpTransport { fn backoff(&self, retry: usize) -> Duration { self.retry.base_delay * retry as u32 } + + pub async fn execute_stream( + &self, + req: TransportRequest, + ) -> Result> + Send + 'static>>, TransportError> + where + T: DeserializeOwned + Send + 'static, + { + let request = self.build_request(req)?; + + let response = self.client.execute(request).await?; + + let status = response.status(); + if status.is_client_error() || status.is_server_error() { + return Err(TransportError::Server(status.to_string())); + } + + let sse_stream = SseStream::new(response); + let json_stream = sse_stream_to_json(sse_stream).await; + + Ok(json_stream) + } } \ No newline at end of file diff --git a/src/core/transport/sse.rs b/src/core/transport/sse.rs new file mode 100644 index 0000000..6f83f8b --- /dev/null +++ b/src/core/transport/sse.rs @@ -0,0 +1,221 @@ +use anyhow::{anyhow, Result}; +use bytes::Bytes; +use futures_util::{Stream, StreamExt}; +use reqwest::Response; +use serde::de::DeserializeOwned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[derive(Debug, Clone, PartialEq)] +pub struct SseEvent { + pub event: Option, + pub data: String, + pub id: Option, +} + +#[derive(Debug)] +pub enum SseError { + Http(reqwest::Error), + InvalidData(String), + Deserialization(String), +} + +impl std::fmt::Display for SseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SseError::Http(e) => write!(f, "HTTP error: {}", e), + SseError::InvalidData(s) => write!(f, "Invalid SSE data: {}", s), + SseError::Deserialization(s) => write!(f, "Deserialization error: {}", s), + } + } +} + +impl std::error::Error for SseError {} + +pub struct SseStream { + inner: Pin> + Send>>, + buffer: Vec, +} + +impl SseStream { + pub fn new(response: Response) -> Self { + Self { + inner: Box::pin(response.bytes_stream()), + buffer: Vec::new(), + } + } + + fn parse_chunk(&mut self, chunk: &[u8]) -> Vec { + self.buffer.extend_from_slice(chunk); + + let mut events = Vec::new(); + let mut start = 0; + + while let Some(end) = self.buffer[start..].windows(2).position(|w| w == b"\r\n" || w == b"\n\n") { + let end_pos = start + end + 2; + let line = &self.buffer[start..end_pos]; + + let event = self.parse_line(line); + if let Some(e) = event { + events.push(e); + } + + start = end_pos; + } + + if start > 0 { + self.buffer = self.buffer[start..].to_vec(); + } + + events + } + + fn parse_line(&self, line: &[u8]) -> Option { + let line = String::from_utf8_lossy(line).trim().to_string(); + + if line.is_empty() { + return None; + } + + let parts: Vec<&str> = line.splitn(2, ':').collect(); + if parts.is_empty() { + return None; + } + + let field = parts[0].trim(); + let value = parts.get(1).map(|s| s.trim_start_matches(' ')).unwrap_or(""); + + let mut event = SseEvent { + event: None, + data: String::new(), + id: None, + }; + + match field { + "event" => event.event = Some(value.to_string()), + "data" => event.data = value.to_string(), + "id" => event.id = Some(value.to_string()), + _ => {} + } + + Some(event) + } +} + +impl Stream for SseStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + let mut events = self.parse_chunk(&chunk); + if !events.is_empty() { + return Poll::Ready(Some(Ok(events.remove(0)))); + } + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(SseError::Http(e)))); + } + Poll::Ready(None) => { + if !self.buffer.is_empty() { + let mut events = self.parse_chunk(&[]); + if !events.is_empty() { + return Poll::Ready(Some(Ok(events.remove(0)))); + } + } + return Poll::Ready(None); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} + +pub async fn parse_sse_chunk(chunk: &[u8]) -> Result { + let line = String::from_utf8_lossy(chunk).trim().to_string(); + + let parts: Vec<&str> = line.splitn(2, ':').collect(); + if parts.is_empty() { + return Err(anyhow!("Empty SSE line")); + } + + let field = parts[0].trim(); + let value = parts.get(1).map(|s| s.trim_start_matches(' ')).unwrap_or(""); + + Ok(SseEvent { + event: if field == "event" { Some(value.to_string()) } else { None }, + data: if field == "data" { value.to_string() } else { String::new() }, + id: if field == "id" { Some(value.to_string()) } else { None }, + }) +} + +pub async fn parse_sse_raw(raw: &str) -> Result> { + let mut events = Vec::new(); + let mut current_event = SseEvent { + event: None, + data: String::new(), + id: None, + }; + + for line in raw.lines() { + let line = line.trim(); + + if line.is_empty() { + if !current_event.data.is_empty() { + events.push(current_event.clone()); + } + current_event = SseEvent { + event: None, + data: String::new(), + id: None, + }; + continue; + } + + let parts: Vec<&str> = line.splitn(2, ':').collect(); + if parts.len() < 2 { + continue; + } + + let field = parts[0].trim(); + let value = parts[1].trim_start_matches(' '); + + match field { + "event" => current_event.event = Some(value.to_string()), + "data" => current_event.data += value, + "id" => current_event.id = Some(value.to_string()), + _ => {} + } + } + + if !current_event.data.is_empty() { + events.push(current_event); + } + + Ok(events) +} + +pub async fn sse_stream_to_json( + stream: SseStream, +) -> Pin> + Send + 'static>> +where + T: DeserializeOwned + Send + 'static, +{ + let mapped = stream.map(move |event_result| { + match event_result { + Ok(event) => { + if event.data.is_empty() { + return Err(anyhow!("Empty SSE data")); + } + serde_json::from_str::(&event.data) + .map_err(|e| anyhow!("Failed to deserialize SSE event: {}", e)) + } + Err(e) => Err(anyhow!("SSE error: {}", e)), + } + }); + + Box::pin(mapped) +} \ No newline at end of file