diff --git a/Cargo.lock b/Cargo.lock index 43fccfa..c657879 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7,6 +7,7 @@ name = "OmegaCode" version = "0.1.0-alpha" dependencies = [ "anyhow", + "async-trait", "bytes", "chrono", "clap", diff --git a/Cargo.toml b/Cargo.toml index 7e4032a..46a4097 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,4 +42,5 @@ hex = "0.4.3" dotenv = "0.15.0" bytes = "1.11.1" rand = "0.10.1" -retry = "2.2.0" \ No newline at end of file +retry = "2.2.0" +async-trait = "0.1.89" \ No newline at end of file diff --git a/src/core/providers/anthropic/mod.rs b/src/core/providers/anthropic/mod.rs deleted file mode 100644 index 08c3d84..0000000 --- a/src/core/providers/anthropic/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::core::providers::basic::InterfaceFormat; - -mod model; - -pub struct Anthropic { - base_url: String, - anthropic_version: String, - interface_format: InterfaceFormat, - pub x_api_key: String, -} - -impl Default for Anthropic { - fn default() -> Self { - Anthropic { - base_url: "https://api.anthropic.com".to_string(), - anthropic_version: "2023-06-01".to_string(), - interface_format: InterfaceFormat::Messages, - x_api_key: "sk-ant-...".to_string() - } - } -} \ No newline at end of file diff --git a/src/core/providers/anthropic/model.rs b/src/core/providers/anthropic/model.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/core/providers/basic/completions.rs b/src/core/providers/basic/completions.rs deleted file mode 100644 index 46c292c..0000000 --- a/src/core/providers/basic/completions.rs +++ /dev/null @@ -1,244 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum Role { - System, - User, - Assistant, - Tool, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum ToolType { - Function, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum ResponseFormatType { - Text, - JsonObject, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum FinishReason { - Stop, - Length, - ToolCalls, - ContentFilter, -} - -impl Default for ResponseFormatType { - fn default() -> Self { - Self::Text - } -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct StreamOptions { - #[serde(skip_serializing_if = "Option::is_none")] - pub include_usage: Option, -} - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct ResponseFormat { - pub r#type: ResponseFormatType, -} - -/// 聊天补全请求 -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct CreateChatCompletionRequest { - pub model: String, - pub messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream_options: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_choice: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_logprobs: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, -} - -/// 聊天补全响应 -#[derive(Debug, Serialize, Deserialize)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, - pub created: i64, - pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub system_fingerprint: Option, - pub choices: Vec, - pub usage: Usage, - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt_filter_results: Option>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Choice { - pub index: i32, - pub message: ChatMessage, - #[serde(skip_serializing_if = "Option::is_none")] - pub finish_reason: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub logprobs: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Logprobs { - pub content: Option>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct LogprobContent { - pub token: String, - pub logprob: f32, - #[serde(skip_serializing_if = "Option::is_none")] - pub bytes: Option>, - pub top_logprobs: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TopLogprob { - pub token: String, - pub logprob: f32, - #[serde(skip_serializing_if = "Option::is_none")] - pub bytes: Option>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt_tokens_details: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub completion_tokens_details: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct PromptTokensDetails { - pub cached_tokens: u32, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct CompletionTokensDetails { - pub reasoning_tokens: u32, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum MessageContent { - Text(String), - Parts(Vec), -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ContentPart { - pub r#type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub text: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub image_url: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ImageUrl { - pub url: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ChatMessage { - pub role: Role, - pub content: MessageContent, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, -} - -impl Default for ChatMessage { - fn default() -> Self { - Self { - role: Role::User, - content: MessageContent::Text(String::new()), - name: None, - tool_call_id: None, - tool_calls: None, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct Tool { - pub r#type: ToolType, - pub function: FunctionObject, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct FunctionObject { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub parameters: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolCall { - pub id: String, - pub r#type: ToolType, - pub function: ToolCallFunction, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolCallFunction { - pub name: String, - pub arguments: String, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ToolChoice { - Strategy(String), - Tool(ToolChoiceObject), -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolChoiceObject { - pub r#type: ToolType, - pub function: ToolChoiceFunction, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ToolChoiceFunction { - pub name: String, -} \ No newline at end of file diff --git a/src/core/providers/basic/messages.rs b/src/core/providers/basic/messages.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/core/providers/basic/mod.rs b/src/core/providers/basic/mod.rs deleted file mode 100644 index 0328a11..0000000 --- a/src/core/providers/basic/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod completions; -mod responses; -mod messages; - -pub enum InterfaceFormat { - Completions, - Responses, - Messages -} \ No newline at end of file diff --git a/src/core/providers/basic/responses.rs b/src/core/providers/basic/responses.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/core/providers/claude_messages/adapter.rs b/src/core/providers/claude_messages/adapter.rs new file mode 100644 index 0000000..70c18c7 --- /dev/null +++ b/src/core/providers/claude_messages/adapter.rs @@ -0,0 +1,18 @@ +use crate::core::providers::claude_messages::request::ClaudeMessage; +use crate::core::providers::types::{Message, Role}; + +impl From for ClaudeMessage { + fn from(value: Message) -> Self { + Self { + role: match value.role { + Role::User => "user", + Role::Assistant => "assistant", + + Role::System => "user", + } + .into(), + + content: value.content, + } + } +} \ No newline at end of file diff --git a/src/core/providers/claude_messages/mod.rs b/src/core/providers/claude_messages/mod.rs new file mode 100644 index 0000000..5611b82 --- /dev/null +++ b/src/core/providers/claude_messages/mod.rs @@ -0,0 +1,4 @@ +mod request; +mod response; +mod adapter; +mod provider; \ No newline at end of file diff --git a/src/core/providers/claude_messages/provider.rs b/src/core/providers/claude_messages/provider.rs new file mode 100644 index 0000000..f30a331 --- /dev/null +++ b/src/core/providers/claude_messages/provider.rs @@ -0,0 +1,57 @@ +use async_trait::async_trait; +use reqwest::Client; +use crate::core::providers::claude_messages::request::ClaudeRequest; +use crate::core::providers::claude_messages::response::ClaudeResponse; +use crate::core::providers::provider::Provider; +use crate::core::providers::types::{CompletionRequest, CompletionResponse}; + +pub struct ClaudeProvider { + client: Client, + + api_key: String, +} + +impl ClaudeProvider { + pub fn new( + api_key: String, + ) -> Self { + Self { + client: Client::new(), + api_key, + } + } +} + +#[async_trait] +impl Provider for ClaudeProvider { + async fn complete( + &self, + request: CompletionRequest, + ) -> Result { + + let body = + ClaudeRequest::from(request); + + let resp = self + .client + .post( + "https://api.anthropic.com/v1/messages" + ) + .header( + "x-api-key", + &self.api_key, + ) + .header( + "anthropic-version", + "2023-06-01", + ) + .json(&body) + .send() + .await?; + + let response: ClaudeResponse = + resp.json().await?; + + Ok(response.into()) + } +} \ No newline at end of file diff --git a/src/core/providers/claude_messages/request.rs b/src/core/providers/claude_messages/request.rs new file mode 100644 index 0000000..1115e89 --- /dev/null +++ b/src/core/providers/claude_messages/request.rs @@ -0,0 +1,22 @@ +use serde::Serialize; + +#[derive(Serialize)] +pub struct ClaudeRequest { + pub model: String, + + pub max_tokens: u32, + + pub messages: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + pub stream: bool, +} + +#[derive(Serialize)] +pub struct ClaudeMessage { + pub role: String, + + pub content: String, +} \ No newline at end of file diff --git a/src/core/providers/claude_messages/response.rs b/src/core/providers/claude_messages/response.rs new file mode 100644 index 0000000..183b521 --- /dev/null +++ b/src/core/providers/claude_messages/response.rs @@ -0,0 +1,22 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct ClaudeResponse { + pub content: Vec, + + pub stop_reason: Option, + + pub usage: Option, +} + +#[derive(Deserialize)] +pub struct ClaudeContent { + pub text: String, +} + +#[derive(Deserialize)] +pub struct ClaudeUsage { + pub input_tokens: Option, + + pub output_tokens: Option, +} \ No newline at end of file diff --git a/src/core/providers/config.rs b/src/core/providers/config.rs new file mode 100644 index 0000000..8bd32ac --- /dev/null +++ b/src/core/providers/config.rs @@ -0,0 +1,19 @@ +// providers/config.rs + +#[derive(Debug, Clone)] +pub enum ProviderKind { + OpenAiCompatible, + Claude, + Gemini, +} + +#[derive(Debug, Clone)] +pub struct ProviderConfig { + pub kind: ProviderKind, + + pub api_key: String, + + pub base_url: String, + + pub model: String, +} \ No newline at end of file diff --git a/src/core/providers/deepseek/mod.rs b/src/core/providers/deepseek/mod.rs deleted file mode 100644 index 15b74a0..0000000 --- a/src/core/providers/deepseek/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::core::providers::basic::InterfaceFormat; - -mod model; - -pub struct Deepseek { - base_url: String, - interface_format: InterfaceFormat, - pub authorization: String, -} - -impl Default for Deepseek { - fn default() -> Self { - Deepseek { - base_url: "https://api.deepseek.com".to_string(), - interface_format: InterfaceFormat::Completions, - authorization: "sk-...".to_string() - } - } -} \ No newline at end of file diff --git a/src/core/providers/deepseek/model.rs b/src/core/providers/deepseek/model.rs deleted file mode 100644 index 7cee7ef..0000000 --- a/src/core/providers/deepseek/model.rs +++ /dev/null @@ -1,20 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// GET /models 请求(无参数,空结构体) -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct ListModelsRequest {} - -/// 模型列表响应 -#[derive(Debug, Serialize, Deserialize)] -pub struct ListModelsResponse { - pub object: String, - pub data: Vec, -} - -/// 单个模型信息 -#[derive(Debug, Serialize, Deserialize)] -pub struct ModelInfo { - pub id: String, - pub object: String, - pub owned_by: String, -} \ No newline at end of file diff --git a/src/core/providers/gemini/adapter.rs b/src/core/providers/gemini/adapter.rs new file mode 100644 index 0000000..5104425 --- /dev/null +++ b/src/core/providers/gemini/adapter.rs @@ -0,0 +1,19 @@ +use crate::core::providers::gemini::request::{GeminiContent, GeminiPart}; +use crate::core::providers::types::{Message, Role}; + +impl From for GeminiContent { + fn from(value: Message) -> Self { + Self { + role: match value.role { + Role::System => "user", + Role::User => "user", + Role::Assistant => "model", + } + .into(), + + parts: vec![GeminiPart { + text: value.content, + }], + } + } +} \ No newline at end of file diff --git a/src/core/providers/gemini/mod.rs b/src/core/providers/gemini/mod.rs new file mode 100644 index 0000000..5611b82 --- /dev/null +++ b/src/core/providers/gemini/mod.rs @@ -0,0 +1,4 @@ +mod request; +mod response; +mod adapter; +mod provider; \ No newline at end of file diff --git a/src/core/providers/gemini/provider.rs b/src/core/providers/gemini/provider.rs new file mode 100644 index 0000000..38eaf71 --- /dev/null +++ b/src/core/providers/gemini/provider.rs @@ -0,0 +1,43 @@ +use async_trait::async_trait; +use reqwest::Client; +use crate::core::providers::gemini::request::GeminiRequest; +use crate::core::providers::gemini::response::GeminiResponse; +use crate::core::providers::provider::Provider; +use crate::core::providers::types::{CompletionRequest, CompletionResponse}; + +pub struct GeminiProvider { + client: Client, + + api_key: String, +} + +#[async_trait] +impl Provider for GeminiProvider { + async fn complete( + &self, + request: CompletionRequest, + ) -> Result { + + let model = + request.model.clone(); + + let body = + GeminiRequest::from(request); + + let resp = self + .client + .post(format!( + "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", + model, + self.api_key + )) + .json(&body) + .send() + .await?; + + let response: GeminiResponse = + resp.json().await?; + + Ok(response.into()) + } +} \ No newline at end of file diff --git a/src/core/providers/gemini/request.rs b/src/core/providers/gemini/request.rs new file mode 100644 index 0000000..12043a4 --- /dev/null +++ b/src/core/providers/gemini/request.rs @@ -0,0 +1,28 @@ +use serde::Serialize; + +#[derive(Serialize)] +pub struct GeminiRequest { + pub contents: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub generation_config: Option, +} + +#[derive(Serialize)] +pub struct GeminiContent { + pub role: String, + + pub parts: Vec, +} + +#[derive(Serialize)] +pub struct GeminiPart { + pub text: String, +} + +#[derive(Serialize)] +pub struct GeminiGenerationConfig { + pub temperature: Option, + + pub max_output_tokens: Option, +} \ No newline at end of file diff --git a/src/core/providers/gemini/response.rs b/src/core/providers/gemini/response.rs new file mode 100644 index 0000000..926a4d9 --- /dev/null +++ b/src/core/providers/gemini/response.rs @@ -0,0 +1,23 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct GeminiResponse { + pub candidates: Vec, +} + +#[derive(Deserialize)] +pub struct GeminiCandidate { + pub content: GeminiContentResponse, + + pub finish_reason: Option, +} + +#[derive(Deserialize)] +pub struct GeminiContentResponse { + pub parts: Vec, +} + +#[derive(Deserialize)] +pub struct GeminiPartResponse { + pub text: String, +} \ No newline at end of file diff --git a/src/core/providers/mod.rs b/src/core/providers/mod.rs index 6312b2c..e76cff0 100644 --- a/src/core/providers/mod.rs +++ b/src/core/providers/mod.rs @@ -1,4 +1,42 @@ -mod anthropic; -mod openai; -mod deepseek; -mod basic; \ No newline at end of file +use crate::core::providers::config::{ProviderConfig, ProviderKind}; +use crate::core::providers::provider::Provider; + +mod types; +mod provider; +mod config; +mod openai_compatible; +mod claude_messages; +mod gemini; + +pub fn create_provider( + config: ProviderConfig, +) -> Box { + + match config.kind { + + ProviderKind::OpenAiCompatible => { + Box::new( + OpenAiCompatibleProvider::new( + config.api_key, + config.base_url, + ) + ) + } + + ProviderKind::Claude => { + Box::new( + ClaudeProvider::new( + config.api_key, + ) + ) + } + + ProviderKind::Gemini => { + Box::new( + GeminiProvider::new( + config.api_key, + ) + ) + } + } +} \ No newline at end of file diff --git a/src/core/providers/openai/mod.rs b/src/core/providers/openai/mod.rs deleted file mode 100644 index e457d83..0000000 --- a/src/core/providers/openai/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -use crate::core::providers::basic::InterfaceFormat; - -mod request; -mod model; - -pub struct OpenAI { - base_url: String, - interface_format: InterfaceFormat, - pub authorization: String, -} - -impl Default for OpenAI { - fn default() -> Self { - OpenAI { - base_url: "https://api.openai.com".to_string(), - interface_format: InterfaceFormat::Responses, - authorization: "sk-...".to_string() - } - } -} \ No newline at end of file diff --git a/src/core/providers/openai/model.rs b/src/core/providers/openai/model.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/core/providers/openai/request.rs b/src/core/providers/openai/request.rs deleted file mode 100644 index 02fd852..0000000 --- a/src/core/providers/openai/request.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct Request { - -} \ No newline at end of file diff --git a/src/core/providers/openai_compatible/adapter.rs b/src/core/providers/openai_compatible/adapter.rs new file mode 100644 index 0000000..d6b1364 --- /dev/null +++ b/src/core/providers/openai_compatible/adapter.rs @@ -0,0 +1,16 @@ +use crate::core::providers::openai_compatible::request::OpenAiMessage; +use crate::core::providers::types::{Message, Role}; + +impl From for OpenAiMessage { + fn from(value: Message) -> Self { + Self { + role: match value.role { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + } + .into(), + content: value.content, + } + } +} \ No newline at end of file diff --git a/src/core/providers/openai_compatible/mod.rs b/src/core/providers/openai_compatible/mod.rs new file mode 100644 index 0000000..5611b82 --- /dev/null +++ b/src/core/providers/openai_compatible/mod.rs @@ -0,0 +1,4 @@ +mod request; +mod response; +mod adapter; +mod provider; \ No newline at end of file diff --git a/src/core/providers/openai_compatible/provider.rs b/src/core/providers/openai_compatible/provider.rs new file mode 100644 index 0000000..0aa4b4e --- /dev/null +++ b/src/core/providers/openai_compatible/provider.rs @@ -0,0 +1,54 @@ +use anyhow::Result; +use async_trait::async_trait; +use reqwest::Client; +use crate::core::providers::openai_compatible::request::OpenAiRequest; +use crate::core::providers::openai_compatible::response::OpenAiResponse; +use crate::core::providers::provider::Provider; +use crate::core::providers::types::{CompletionRequest, CompletionResponse}; + +pub struct OpenAiCompatibleProvider { + client: Client, + + api_key: String, + + base_url: String, +} + +impl OpenAiCompatibleProvider { + pub fn new( + api_key: String, + base_url: String, + ) -> Self { + Self { + client: Client::new(), + api_key, + base_url, + } + } +} + +#[async_trait] +impl Provider for OpenAiCompatibleProvider { + async fn complete( + &self, + request: CompletionRequest, + ) -> Result { + let body = OpenAiRequest::from(request); + + let resp = self + .client + .post(format!( + "{}/chat/completions", + self.base_url + )) + .bearer_auth(&self.api_key) + .json(&body) + .send() + .await?; + + let response: OpenAiResponse = + resp.json().await?; + + Ok(response.into()) + } +} \ No newline at end of file diff --git a/src/core/providers/openai_compatible/request.rs b/src/core/providers/openai_compatible/request.rs new file mode 100644 index 0000000..652dd4d --- /dev/null +++ b/src/core/providers/openai_compatible/request.rs @@ -0,0 +1,23 @@ +use serde::Serialize; + +#[derive(Serialize)] +pub struct OpenAiRequest { + pub model: String, + + pub messages: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + pub stream: bool, +} + +#[derive(Serialize)] +pub struct OpenAiMessage { + pub role: String, + + pub content: String, +} \ No newline at end of file diff --git a/src/core/providers/openai_compatible/response.rs b/src/core/providers/openai_compatible/response.rs new file mode 100644 index 0000000..db8dd53 --- /dev/null +++ b/src/core/providers/openai_compatible/response.rs @@ -0,0 +1,29 @@ +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct OpenAiResponse { + pub choices: Vec, + + pub usage: Option, +} + +#[derive(Deserialize)] +pub struct OpenAiChoice { + pub message: OpenAiAssistantMessage, + + pub finish_reason: Option, +} + +#[derive(Deserialize)] +pub struct OpenAiAssistantMessage { + pub content: String, +} + +#[derive(Deserialize)] +pub struct OpenAiUsage { + pub prompt_tokens: Option, + + pub completion_tokens: Option, + + pub total_tokens: Option, +} \ No newline at end of file diff --git a/src/core/providers/provider.rs b/src/core/providers/provider.rs new file mode 100644 index 0000000..0386a18 --- /dev/null +++ b/src/core/providers/provider.rs @@ -0,0 +1,13 @@ +// core/provider.rs + +use anyhow::Result; +use async_trait::async_trait; +use crate::core::providers::types::{CompletionRequest, CompletionResponse}; + +#[async_trait] +pub trait Provider: Send + Sync { + async fn complete( + &self, + request: CompletionRequest, + ) -> Result; +} \ No newline at end of file diff --git a/src/core/providers/types.rs b/src/core/providers/types.rs new file mode 100644 index 0000000..b16a8de --- /dev/null +++ b/src/core/providers/types.rs @@ -0,0 +1,56 @@ +// core/types.rs + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletionRequest { + pub model: String, + + pub messages: Vec, + + pub temperature: Option, + + pub max_tokens: Option, + + pub stream: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + Error, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: Option, + + pub completion_tokens: Option, + + pub total_tokens: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompletionResponse { + pub content: String, + + pub finish_reason: Option, + + pub usage: Option, +} \ No newline at end of file