diff --git a/src/core/client/mod.rs b/src/core/client/mod.rs index 432ef50..98e8f40 100644 --- a/src/core/client/mod.rs +++ b/src/core/client/mod.rs @@ -1,5 +1,10 @@ mod config; pub mod base_client; +pub mod ai_client; + +pub use config::ClientConfig; +pub use base_client::BaseClient; +pub use ai_client::AiClient; use anyhow::{anyhow, Result}; use bytes::Bytes; @@ -244,10 +249,10 @@ impl RetryMiddleware { let (ctx, request) = factory().await?; info!( - method = %ctx.method, - url = %ctx.url, - attempt, - "sending request" + "sending request: method={}, url={}, attempt={}", + ctx.method, + ctx.url, + attempt ); let result = @@ -294,11 +299,11 @@ impl RetryMiddleware { }; warn!( - method = %ctx.method, - url = %ctx.url, + "retrying request: method={}, url={}, attempt={}, wait_ms={}", + ctx.method, + ctx.url, attempt, - wait_ms = wait.as_millis(), - "retrying request" + wait.as_millis() ); tokio::time::sleep(wait).await; diff --git a/src/core/providers/claude_messages/adapter.rs b/src/core/providers/claude_messages/adapter.rs deleted file mode 100644 index 70c18c7..0000000 --- a/src/core/providers/claude_messages/adapter.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::core::providers::claude_messages::request::ClaudeMessage; -use crate::core::providers::types::{Message, Role}; - -impl From for ClaudeMessage { - fn from(value: Message) -> Self { - Self { - role: match value.role { - Role::User => "user", - Role::Assistant => "assistant", - - Role::System => "user", - } - .into(), - - content: value.content, - } - } -} \ No newline at end of file diff --git a/src/core/providers/claude_messages/mod.rs b/src/core/providers/claude_messages/mod.rs index 5611b82..0654bea 100644 --- a/src/core/providers/claude_messages/mod.rs +++ b/src/core/providers/claude_messages/mod.rs @@ -1,4 +1,7 @@ mod request; mod response; -mod adapter; -mod provider; \ No newline at end of file +mod provider; + +pub use provider::ClaudeProvider; +pub use request::ClaudeRequest; +pub use response::ClaudeResponse; \ No newline at end of file diff --git a/src/core/providers/claude_messages/provider.rs b/src/core/providers/claude_messages/provider.rs index f30a331..d5d4f9a 100644 --- a/src/core/providers/claude_messages/provider.rs +++ b/src/core/providers/claude_messages/provider.rs @@ -1,57 +1,52 @@ +use anyhow::Result; use async_trait::async_trait; -use reqwest::Client; +use std::collections::HashMap; +use crate::core::client::{AiClient, ClientConfig}; +use crate::core::transport::retry::RetryPolicy; use crate::core::providers::claude_messages::request::ClaudeRequest; use crate::core::providers::claude_messages::response::ClaudeResponse; use crate::core::providers::provider::Provider; use crate::core::providers::types::{CompletionRequest, CompletionResponse}; pub struct ClaudeProvider { - client: Client, - + client: AiClient, api_key: String, } impl ClaudeProvider { - pub fn new( - api_key: String, - ) -> Self { - Self { - client: Client::new(), + pub fn new(api_key: String) -> Result { + let config = ClientConfig::default(); + let retry = RetryPolicy::default(); + let client = AiClient::new(config, retry)?; + + Ok(Self { + client, api_key, - } + }) } } #[async_trait] impl Provider for ClaudeProvider { - async fn complete( - &self, - request: CompletionRequest, - ) -> Result { - - let body = - ClaudeRequest::from(request); - - let resp = self - .client - .post( - "https://api.anthropic.com/v1/messages" - ) - .header( - "x-api-key", - &self.api_key, - ) - .header( - "anthropic-version", - "2023-06-01", - ) - .json(&body) - .send() - .await?; - - let response: ClaudeResponse = - resp.json().await?; - - Ok(response.into()) + fn client(&self) -> &AiClient { + &self.client + } + + async fn complete(&self, request: CompletionRequest) -> Result { + let claude_request = ClaudeRequest::from_completion_request(request); + + let mut headers = HashMap::new(); + headers.insert("x-api-key".to_string(), self.api_key.clone()); + headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + let url = "https://api.anthropic.com/v1/messages"; + + let response: ClaudeResponse = self.client + .post_json(url, headers, &claude_request) + .await + .map_err(|e| anyhow::anyhow!("Claude API error: {}", e))?; + + Ok(CompletionResponse::from(response)) } } \ No newline at end of file diff --git a/src/core/providers/claude_messages/request.rs b/src/core/providers/claude_messages/request.rs index 1115e89..ed74908 100644 --- a/src/core/providers/claude_messages/request.rs +++ b/src/core/providers/claude_messages/request.rs @@ -1,22 +1,57 @@ use serde::Serialize; +use crate::core::providers::types::{CompletionRequest, Message, Role}; #[derive(Serialize)] pub struct ClaudeRequest { pub model: String, - pub max_tokens: u32, - pub messages: Vec, - + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, - pub stream: bool, } #[derive(Serialize)] pub struct ClaudeMessage { pub role: String, - pub content: String, +} + +impl ClaudeRequest { + pub fn from_completion_request(request: CompletionRequest) -> Self { + let mut system_prompt = None; + let mut messages = Vec::new(); + + for msg in request.messages { + if msg.role == Role::System { + system_prompt = Some(msg.content); + } else { + messages.push(ClaudeMessage::from(msg)); + } + } + + Self { + model: request.model, + max_tokens: request.max_tokens.unwrap_or(4096), + messages, + system: system_prompt, + temperature: request.temperature, + stream: request.stream, + } + } +} + +impl From for ClaudeMessage { + fn from(value: Message) -> Self { + Self { + role: match value.role { + Role::User => "user".to_string(), + Role::Assistant => "assistant".to_string(), + Role::System => "user".to_string(), + }, + content: value.content, + } + } } \ No newline at end of file diff --git a/src/core/providers/claude_messages/response.rs b/src/core/providers/claude_messages/response.rs index 183b521..bcaac3e 100644 --- a/src/core/providers/claude_messages/response.rs +++ b/src/core/providers/claude_messages/response.rs @@ -1,11 +1,10 @@ use serde::Deserialize; +use crate::core::providers::types::{CompletionResponse, FinishReason, Usage}; #[derive(Deserialize)] pub struct ClaudeResponse { pub content: Vec, - pub stop_reason: Option, - pub usage: Option, } @@ -17,6 +16,33 @@ pub struct ClaudeContent { #[derive(Deserialize)] pub struct ClaudeUsage { pub input_tokens: Option, - pub output_tokens: Option, +} + +impl From for CompletionResponse { + fn from(value: ClaudeResponse) -> Self { + let content = value + .content + .first() + .map(|c| c.text.clone()) + .unwrap_or_default(); + + let finish_reason = value.stop_reason.and_then(|r| match r.as_str() { + "end_turn" => Some(FinishReason::Stop), + "max_tokens" => Some(FinishReason::Length), + _ => Some(FinishReason::Error), + }); + + let usage = value.usage.map(|u| Usage { + prompt_tokens: u.input_tokens, + completion_tokens: u.output_tokens, + total_tokens: None, + }); + + CompletionResponse { + content, + finish_reason, + usage, + } + } } \ No newline at end of file diff --git a/src/core/providers/config.rs b/src/core/providers/config.rs index 8bd32ac..8f47572 100644 --- a/src/core/providers/config.rs +++ b/src/core/providers/config.rs @@ -2,18 +2,17 @@ #[derive(Debug, Clone)] pub enum ProviderKind { - OpenAiCompatible, + Compatible, Claude, Gemini, + Deepseek, + Ollama, } #[derive(Debug, Clone)] pub struct ProviderConfig { pub kind: ProviderKind, - pub api_key: String, - pub base_url: String, - pub model: String, } \ No newline at end of file diff --git a/src/core/providers/gemini/adapter.rs b/src/core/providers/gemini/adapter.rs deleted file mode 100644 index 5104425..0000000 --- a/src/core/providers/gemini/adapter.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::core::providers::gemini::request::{GeminiContent, GeminiPart}; -use crate::core::providers::types::{Message, Role}; - -impl From for GeminiContent { - fn from(value: Message) -> Self { - Self { - role: match value.role { - Role::System => "user", - Role::User => "user", - Role::Assistant => "model", - } - .into(), - - parts: vec![GeminiPart { - text: value.content, - }], - } - } -} \ No newline at end of file diff --git a/src/core/providers/gemini/mod.rs b/src/core/providers/gemini/mod.rs deleted file mode 100644 index 5611b82..0000000 --- a/src/core/providers/gemini/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod request; -mod response; -mod adapter; -mod provider; \ No newline at end of file diff --git a/src/core/providers/gemini/provider.rs b/src/core/providers/gemini/provider.rs deleted file mode 100644 index 38eaf71..0000000 --- a/src/core/providers/gemini/provider.rs +++ /dev/null @@ -1,43 +0,0 @@ -use async_trait::async_trait; -use reqwest::Client; -use crate::core::providers::gemini::request::GeminiRequest; -use crate::core::providers::gemini::response::GeminiResponse; -use crate::core::providers::provider::Provider; -use crate::core::providers::types::{CompletionRequest, CompletionResponse}; - -pub struct GeminiProvider { - client: Client, - - api_key: String, -} - -#[async_trait] -impl Provider for GeminiProvider { - async fn complete( - &self, - request: CompletionRequest, - ) -> Result { - - let model = - request.model.clone(); - - let body = - GeminiRequest::from(request); - - let resp = self - .client - .post(format!( - "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", - model, - self.api_key - )) - .json(&body) - .send() - .await?; - - let response: GeminiResponse = - resp.json().await?; - - Ok(response.into()) - } -} \ No newline at end of file diff --git a/src/core/providers/gemini/request.rs b/src/core/providers/gemini/request.rs deleted file mode 100644 index 12043a4..0000000 --- a/src/core/providers/gemini/request.rs +++ /dev/null @@ -1,28 +0,0 @@ -use serde::Serialize; - -#[derive(Serialize)] -pub struct GeminiRequest { - pub contents: Vec, - - #[serde(skip_serializing_if = "Option::is_none")] - pub generation_config: Option, -} - -#[derive(Serialize)] -pub struct GeminiContent { - pub role: String, - - pub parts: Vec, -} - -#[derive(Serialize)] -pub struct GeminiPart { - pub text: String, -} - -#[derive(Serialize)] -pub struct GeminiGenerationConfig { - pub temperature: Option, - - pub max_output_tokens: Option, -} \ No newline at end of file diff --git a/src/core/providers/gemini/response.rs b/src/core/providers/gemini/response.rs deleted file mode 100644 index 926a4d9..0000000 --- a/src/core/providers/gemini/response.rs +++ /dev/null @@ -1,23 +0,0 @@ -use serde::Deserialize; - -#[derive(Deserialize)] -pub struct GeminiResponse { - pub candidates: Vec, -} - -#[derive(Deserialize)] -pub struct GeminiCandidate { - pub content: GeminiContentResponse, - - pub finish_reason: Option, -} - -#[derive(Deserialize)] -pub struct GeminiContentResponse { - pub parts: Vec, -} - -#[derive(Deserialize)] -pub struct GeminiPartResponse { - pub text: String, -} \ No newline at end of file diff --git a/src/core/providers/mod.rs b/src/core/providers/mod.rs index e76cff0..4b01473 100644 --- a/src/core/providers/mod.rs +++ b/src/core/providers/mod.rs @@ -1,5 +1,6 @@ use crate::core::providers::config::{ProviderConfig, ProviderKind}; use crate::core::providers::provider::Provider; +pub use types::{CompletionRequest, CompletionResponse, Message, Role, FinishReason, Usage}; mod types; mod provider; @@ -7,36 +8,35 @@ mod config; mod openai_compatible; mod claude_messages; mod gemini; +mod ollama; -pub fn create_provider( - config: ProviderConfig, -) -> Box { +use openai_compatible::OpenAiCompatibleProvider; +use claude_messages::ClaudeProvider; +use gemini::GeminiProvider; +use ollama::OllamaProvider; +pub fn create_provider(config: ProviderConfig) -> anyhow::Result> { match config.kind { - - ProviderKind::OpenAiCompatible => { - Box::new( - OpenAiCompatibleProvider::new( - config.api_key, - config.base_url, - ) - ) + ProviderKind::Compatible => { + let provider = OpenAiCompatibleProvider::new(config.api_key, config.base_url)?; + Ok(Box::new(provider)) } - ProviderKind::Claude => { - Box::new( - ClaudeProvider::new( - config.api_key, - ) - ) + let provider = ClaudeProvider::new(config.api_key)?; + Ok(Box::new(provider)) } - ProviderKind::Gemini => { - Box::new( - GeminiProvider::new( - config.api_key, - ) - ) + let provider = GeminiProvider::new(config.api_key, config.model)?; + Ok(Box::new(provider)) + } + ProviderKind::Deepseek => { + let provider = OpenAiCompatibleProvider::new(config.api_key, config.base_url)?; + Ok(Box::new(provider)) + } + ProviderKind::Ollama => { + let provider = OllamaProvider::new(config.base_url)?; + Ok(Box::new(provider)) } } -} \ No newline at end of file +} + diff --git a/src/core/providers/openai_compatible/adapter.rs b/src/core/providers/openai_compatible/adapter.rs deleted file mode 100644 index d6b1364..0000000 --- a/src/core/providers/openai_compatible/adapter.rs +++ /dev/null @@ -1,16 +0,0 @@ -use crate::core::providers::openai_compatible::request::OpenAiMessage; -use crate::core::providers::types::{Message, Role}; - -impl From for OpenAiMessage { - fn from(value: Message) -> Self { - Self { - role: match value.role { - Role::System => "system", - Role::User => "user", - Role::Assistant => "assistant", - } - .into(), - content: value.content, - } - } -} \ No newline at end of file diff --git a/src/core/providers/openai_compatible/mod.rs b/src/core/providers/openai_compatible/mod.rs index 5611b82..71baa27 100644 --- a/src/core/providers/openai_compatible/mod.rs +++ b/src/core/providers/openai_compatible/mod.rs @@ -1,4 +1,7 @@ mod request; mod response; -mod adapter; -mod provider; \ No newline at end of file +mod provider; + +pub use provider::OpenAiCompatibleProvider; +pub use request::OpenAiRequest; +pub use response::OpenAiResponse; \ No newline at end of file diff --git a/src/core/providers/openai_compatible/provider.rs b/src/core/providers/openai_compatible/provider.rs index 0aa4b4e..2a04174 100644 --- a/src/core/providers/openai_compatible/provider.rs +++ b/src/core/providers/openai_compatible/provider.rs @@ -1,54 +1,53 @@ use anyhow::Result; use async_trait::async_trait; -use reqwest::Client; +use std::collections::HashMap; +use crate::core::client::{AiClient, ClientConfig}; +use crate::core::transport::retry::RetryPolicy; use crate::core::providers::openai_compatible::request::OpenAiRequest; use crate::core::providers::openai_compatible::response::OpenAiResponse; use crate::core::providers::provider::Provider; use crate::core::providers::types::{CompletionRequest, CompletionResponse}; pub struct OpenAiCompatibleProvider { - client: Client, - + client: AiClient, api_key: String, - base_url: String, } impl OpenAiCompatibleProvider { - pub fn new( - api_key: String, - base_url: String, - ) -> Self { - Self { - client: Client::new(), + pub fn new(api_key: String, base_url: String) -> Result { + let config = ClientConfig::default(); + let retry = RetryPolicy::default(); + let client = AiClient::new(config, retry)?; + + Ok(Self { + client, api_key, base_url, - } + }) } } #[async_trait] impl Provider for OpenAiCompatibleProvider { - async fn complete( - &self, - request: CompletionRequest, - ) -> Result { - let body = OpenAiRequest::from(request); - - let resp = self - .client - .post(format!( - "{}/chat/completions", - self.base_url - )) - .bearer_auth(&self.api_key) - .json(&body) - .send() - .await?; - - let response: OpenAiResponse = - resp.json().await?; - - Ok(response.into()) + fn client(&self) -> &AiClient { + &self.client + } + + async fn complete(&self, request: CompletionRequest) -> Result { + let openai_request = OpenAiRequest::from_completion_request(request); + + let mut headers = HashMap::new(); + headers.insert("Authorization".to_string(), format!("Bearer {}", self.api_key)); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + let url = format!("{}/chat/completions", self.base_url); + + let response: OpenAiResponse = self.client + .post_json(&url, headers, &openai_request) + .await + .map_err(|e| anyhow::anyhow!("OpenAI API error: {}", e))?; + + Ok(CompletionResponse::from(response)) } } \ No newline at end of file diff --git a/src/core/providers/openai_compatible/request.rs b/src/core/providers/openai_compatible/request.rs index 652dd4d..4fca86e 100644 --- a/src/core/providers/openai_compatible/request.rs +++ b/src/core/providers/openai_compatible/request.rs @@ -1,23 +1,45 @@ use serde::Serialize; +use crate::core::providers::Role; +use crate::core::providers::types::{CompletionRequest, Message}; #[derive(Serialize)] pub struct OpenAiRequest { pub model: String, - pub messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, - pub stream: bool, } #[derive(Serialize)] pub struct OpenAiMessage { pub role: String, - pub content: String, +} + +impl OpenAiRequest { + pub fn from_completion_request(request: CompletionRequest) -> Self { + Self { + model: request.model, + messages: request.messages.into_iter().map(OpenAiMessage::from).collect(), + temperature: request.temperature, + max_tokens: request.max_tokens, + stream: request.stream, + } + } +} + +impl From for OpenAiMessage { + fn from(value: Message) -> Self { + Self { + role: match value.role { + Role::System => "system".to_string(), + Role::User => "user".to_string(), + Role::Assistant => "assistant".to_string(), + }, + content: value.content, + } + } } \ No newline at end of file diff --git a/src/core/providers/openai_compatible/response.rs b/src/core/providers/openai_compatible/response.rs index db8dd53..a56aa84 100644 --- a/src/core/providers/openai_compatible/response.rs +++ b/src/core/providers/openai_compatible/response.rs @@ -1,16 +1,15 @@ use serde::Deserialize; +use crate::core::providers::types::{CompletionResponse, FinishReason, Usage}; #[derive(Deserialize)] pub struct OpenAiResponse { pub choices: Vec, - pub usage: Option, } #[derive(Deserialize)] pub struct OpenAiChoice { pub message: OpenAiAssistantMessage, - pub finish_reason: Option, } @@ -22,8 +21,38 @@ pub struct OpenAiAssistantMessage { #[derive(Deserialize)] pub struct OpenAiUsage { pub prompt_tokens: Option, - pub completion_tokens: Option, - pub total_tokens: Option, +} + +impl From for CompletionResponse { + fn from(value: OpenAiResponse) -> Self { + let content = value + .choices + .first() + .map(|c| c.message.content.clone()) + .unwrap_or_default(); + + let finish_reason = value + .choices + .first() + .and_then(|c| c.finish_reason.as_ref()) + .and_then(|r| match r.as_str() { + "stop" => Some(FinishReason::Stop), + "length" => Some(FinishReason::Length), + _ => Some(FinishReason::Error), + }); + + let usage = value.usage.map(|u| Usage { + prompt_tokens: u.prompt_tokens, + completion_tokens: u.completion_tokens, + total_tokens: u.total_tokens, + }); + + CompletionResponse { + content, + finish_reason, + usage, + } + } } \ No newline at end of file diff --git a/src/core/providers/provider.rs b/src/core/providers/provider.rs index 0386a18..263a179 100644 --- a/src/core/providers/provider.rs +++ b/src/core/providers/provider.rs @@ -1,13 +1,11 @@ -// core/provider.rs - use anyhow::Result; use async_trait::async_trait; +use crate::core::client::AiClient; use crate::core::providers::types::{CompletionRequest, CompletionResponse}; #[async_trait] pub trait Provider: Send + Sync { - async fn complete( - &self, - request: CompletionRequest, - ) -> Result; + fn client(&self) -> &AiClient; + + async fn complete(&self, request: CompletionRequest) -> Result; } \ No newline at end of file diff --git a/src/core/providers/types.rs b/src/core/providers/types.rs index b16a8de..dc1bc1a 100644 --- a/src/core/providers/types.rs +++ b/src/core/providers/types.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum Role { System, diff --git a/src/core/transport/mod.rs b/src/core/transport/mod.rs index d7162ee..984686a 100644 --- a/src/core/transport/mod.rs +++ b/src/core/transport/mod.rs @@ -1,7 +1,14 @@ mod client; -mod errors; -mod request; -mod response; -mod retry; -mod http_transoprt; -mod stream; \ No newline at end of file +pub(crate) mod errors; +pub(crate) mod request; +pub(crate) mod response; +pub mod retry; +pub(crate) mod http_transoprt; +mod stream; +pub mod sse; + +pub use errors::TransportError; +pub use request::TransportRequest; +pub use response::TransportResponse; +pub use http_transoprt::HttpTransport; +pub use sse::{SseEvent, SseError, SseStream, parse_sse_chunk, parse_sse_raw}; \ No newline at end of file