diff --git a/Cargo.lock b/Cargo.lock index 090501c..412a0da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,7 @@ dependencies = [ "ratatui 0.30.0", "ratatui-kit", "reqwest", + "retry", "rmcp", "rusqlite", "rust-i18n", @@ -2590,6 +2591,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "retry" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cab9bd343c737660e523ee69f788018f3db686d537d2fd0f99c9f747c1bda4f" +dependencies = [ + "rand 0.9.4", +] + [[package]] name = "ring" version = "0.17.14" diff --git a/Cargo.toml b/Cargo.toml index b1c3450..9573ba8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,4 +41,5 @@ thiserror = "2.0.18" hex = "0.4.3" dotenv = "0.15.0" bytes = "1.11.1" -rand = "0.10.1" \ No newline at end of file +rand = "0.10.1" +retry = "2.2.0" \ No newline at end of file diff --git a/src/core/client/base_client.rs b/src/core/client/base_client.rs new file mode 100644 index 0000000..ac1caf3 --- /dev/null +++ b/src/core/client/base_client.rs @@ -0,0 +1,37 @@ +use reqwest::{Client, Request, Response}; +use crate::core::client::config::ClientConfig; + +#[derive(Clone)] +pub struct BaseClient { + inner: Client, +} + +impl BaseClient { + pub fn new(config: ClientConfig) -> anyhow::Result { + let mut builder = Client::builder() + .timeout(config.timeout) + .connect_timeout(config.connect_timeout) + .user_agent(config.user_agent); + + if let Some(proxy) = config.proxy { + builder = builder.proxy(reqwest::Proxy::all(proxy)?); + } + + let client = builder.build()?; + + Ok(Self { + inner: client, + }) + } + + pub async fn execute( + &self, + request: Request, + ) -> Result { + self.inner.execute(request).await + } + + pub fn inner(&self) -> &Client { + &self.inner + } +} \ No newline at end of file diff --git a/src/core/client/config.rs b/src/core/client/config.rs new file mode 100644 index 0000000..7339e4f --- /dev/null +++ b/src/core/client/config.rs @@ -0,0 +1,20 @@ +use std::time::Duration; + +#[derive(Clone)] +pub struct ClientConfig { + pub timeout: Duration, + pub connect_timeout: Duration, + pub proxy: Option, + pub user_agent: String, +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + timeout: Duration::from_secs(300), + connect_timeout: Duration::from_secs(10), + proxy: None, + user_agent: "agent-runtime/0.1.0".into(), + } + } +} \ No newline at end of file diff --git a/src/core/client/mod.rs b/src/core/client/mod.rs new file mode 100644 index 0000000..432ef50 --- /dev/null +++ b/src/core/client/mod.rs @@ -0,0 +1,657 @@ +mod config; +pub mod base_client; + +use anyhow::{anyhow, Result}; +use bytes::Bytes; +use futures_util::stream::Stream; +use rand::{Rng, RngExt}; +use reqwest::{ + multipart, + Client, + Method, + Request, + RequestBuilder, + Response, + StatusCode, +}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; +use std::time::Duration; +use log::{info, warn}; +use tokio::fs; + +// ============================================================ +// Retry Config +// ============================================================ + +#[derive(Debug, Clone)] +pub struct RetryConfig { + pub max_retries: u32, + + pub retry_interval: u64, + + pub use_exponential_backoff: bool, + + pub backoff_factor: f64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + retry_interval: 1000, + use_exponential_backoff: true, + backoff_factor: 2.0, + } + } +} + +// ============================================================ +// Request Context +// ============================================================ + +#[derive(Debug, Clone)] +pub struct RequestContext { + pub method: Method, + pub url: String, +} + +// ============================================================ +// Request Result +// ============================================================ + +pub enum RequestResult { + Response(Response), + Error(reqwest::Error), +} + +// ============================================================ +// Retry Predicate +// ============================================================ + +pub type RetryPredicate = +Box< + dyn Fn(&RequestContext, &RequestResult) -> bool + + Send + + Sync + + 'static, +>; + +// ============================================================ +// Request Factory +// ============================================================ + +type BoxFuture = +Pin + Send>>; + +pub type RequestFactory = +Box< + dyn Fn() -> BoxFuture> + + Send + + Sync, +>; + +// ============================================================ +// Retry Middleware +// ============================================================ + +pub struct RetryMiddleware { + pub config: RetryConfig, + + pub predicate: Option, +} + +impl Default for RetryMiddleware { + fn default() -> Self { + Self { + config: RetryConfig::default(), + predicate: None, + } + } +} + +impl RetryMiddleware { + pub fn new(config: RetryConfig) -> Self { + Self { + config, + predicate: None, + } + } + + pub fn with_retry_predicate( + mut self, + f: F, + ) -> Self + where + F: Fn(&RequestContext, &RequestResult) -> bool + + Send + + Sync + + 'static, + { + self.predicate = Some(Box::new(f)); + + self + } + + // ======================================================== + // Full Jitter Backoff + // ======================================================== + + fn calculate_wait_time( + &self, + attempts: u32, + ) -> Duration { + let base = self.config.retry_interval as f64; + + let max_wait = if self + .config + .use_exponential_backoff + { + base + * self + .config + .backoff_factor + .powi(attempts as i32) + } else { + base + }; + + let mut rng = rand::rng(); + + let wait_ms = + rng.random_range(0.0..=max_wait); + + Duration::from_millis(wait_ms as u64) + } + + // ======================================================== + // Retry-After + // ======================================================== + + fn retry_after( + response: &Response, + ) -> Option { + let header = response + .headers() + .get("retry-after")?; + + let value = header.to_str().ok()?; + + let secs = value.parse::().ok()?; + + Some(Duration::from_secs(secs)) + } + + // ======================================================== + // Default Retry Rule + // ======================================================== + + fn should_retry( + &self, + ctx: &RequestContext, + result: &RequestResult, + ) -> bool { + if let Some(predicate) = &self.predicate { + return predicate(ctx, result); + } + + // 默认只 retry 幂等请求 + let retryable_method = matches!( + ctx.method, + Method::GET + | Method::HEAD + | Method::PUT + | Method::DELETE + ); + + if !retryable_method { + return false; + } + + match result { + RequestResult::Response(resp) => { + resp.status().is_server_error() + || resp.status() + == StatusCode::TOO_MANY_REQUESTS + } + + RequestResult::Error(err) => { + err.is_timeout() + || err.is_connect() + || err + .status() + .map(|s| s.is_server_error()) + .unwrap_or(false) + } + } + } + + // ======================================================== + // Execute + // ======================================================== + + pub async fn execute( + &self, + client: &Client, + factory: RequestFactory, + ) -> Result { + let mut attempt = 0; + + loop { + let (ctx, request) = factory().await?; + + info!( + method = %ctx.method, + url = %ctx.url, + attempt, + "sending request" + ); + + let result = + match client.execute(request).await { + Ok(resp) => { + RequestResult::Response(resp) + } + + Err(err) => { + RequestResult::Error(err) + } + }; + + let should_retry = + self.should_retry(&ctx, &result); + + if !should_retry + || attempt >= self.config.max_retries + { + return match result { + RequestResult::Response(resp) => { + Ok(resp) + } + + RequestResult::Error(err) => { + Err(anyhow!(err)) + } + }; + } + + attempt += 1; + + let wait = match &result { + RequestResult::Response(resp) => { + Self::retry_after(resp) + .unwrap_or_else(|| { + self.calculate_wait_time( + attempt, + ) + }) + } + + _ => self.calculate_wait_time(attempt), + }; + + warn!( + method = %ctx.method, + url = %ctx.url, + attempt, + wait_ms = wait.as_millis(), + "retrying request" + ); + + tokio::time::sleep(wait).await; + } + } +} + +// ============================================================ +// Http Client +// ============================================================ + +pub struct HttpClient { + client: Client, + + retry_middleware: RetryMiddleware, +} + +impl HttpClient { + pub fn new( + retry_config: Option, + ) -> Result { + let client = Client::builder() + .connect_timeout(Duration::from_secs( + 10, + )) + .timeout(Duration::from_secs(300)) + .pool_idle_timeout(Duration::from_secs( + 90, + )) + .tcp_keepalive(Duration::from_secs(60)) + .build()?; + + Ok(Self { + client, + + retry_middleware: retry_config + .map(RetryMiddleware::new) + .unwrap_or_default(), + }) + } + + // ======================================================== + // Core Send + // ======================================================== + + pub async fn send( + &self, + builder: RequestBuilder, + ) -> Result { + let builder = builder + .try_clone() + .ok_or_else(|| { + anyhow!( + "request builder cannot be cloned" + ) + })?; + + let factory: RequestFactory = + Box::new(move || { + let builder = builder + .try_clone() + .ok_or_else(|| { + anyhow!( + "request builder clone failed" + ) + }); + + Box::pin(async move { + let builder = builder?; + + let request = builder.build()?; + + let ctx = RequestContext { + method: request.method().clone(), + + url: request.url().to_string(), + }; + + Ok((ctx, request)) + }) + }); + + self.retry_middleware + .execute(&self.client, factory) + .await + } + + // ======================================================== + // GET JSON + // ======================================================== + + pub async fn get_json( + &self, + url: &str, + query: &impl Serialize, + ) -> Result + where + T: DeserializeOwned, + { + let response = self + .send(self.client.get(url).query(query)) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // POST JSON + // ======================================================== + + pub async fn post_json( + &self, + url: &str, + body: &impl Serialize, + ) -> Result + where + T: DeserializeOwned, + { + let response = self + .send(self.client.post(url).json(body)) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // GET TEXT + // ======================================================== + + pub async fn get_text( + &self, + url: &str, + ) -> Result { + let response = self + .send(self.client.get(url)) + .await?; + + Ok(response.text().await?) + } + + // ======================================================== + // GET BYTES + // ======================================================== + + pub async fn get_bytes( + &self, + url: &str, + ) -> Result { + let response = self + .send(self.client.get(url)) + .await?; + + Ok(response.bytes().await?) + } + + // ======================================================== + // SSE / STREAM + // ======================================================== + + pub async fn get_stream( + &self, + url: &str, + ) -> Result< + impl Stream< + Item = Result, + >, + > { + let response = self + .send(self.client.get(url)) + .await?; + + Ok(response.bytes_stream()) + } + + // ======================================================== + // Multipart Upload + // ======================================================== + + pub async fn upload_file( + &self, + url: &str, + field_name: &str, + file_path: impl Into, + ) -> Result + where + T: DeserializeOwned, + { + let path = file_path.into(); + + let client = self.client.clone(); + + let url = url.to_string(); + + let field_name = field_name.to_string(); + + let factory: RequestFactory = + Box::new(move || { + let client = client.clone(); + + let path = path.clone(); + + let url = url.clone(); + + let field_name = + field_name.clone(); + + Box::pin(async move { + let bytes = + fs::read(&path).await?; + + let filename = path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = + multipart::Part::bytes( + bytes, + ) + .file_name(filename); + + let form = + multipart::Form::new() + .part( + field_name, + part, + ); + + let request = client + .post(url) + .multipart(form) + .build()?; + + let ctx = RequestContext { + method: request + .method() + .clone(), + + url: request + .url() + .to_string(), + }; + + Ok((ctx, request)) + }) + }); + + let response = self + .retry_middleware + .execute(&self.client, factory) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // Stream Upload + // ======================================================== + + pub async fn upload_stream( + &self, + url: &str, + stream_factory: impl Fn() -> S + + Send + + Sync + + 'static, + ) -> Result + where + T: DeserializeOwned, + + S: Stream< + Item = Result< + Bytes, + std::io::Error, + >, + > + Send + + 'static, + { + let client = self.client.clone(); + + let url = url.to_string(); + + let factory: RequestFactory = + Box::new(move || { + let client = client.clone(); + + let url = url.clone(); + + let stream = stream_factory(); + + Box::pin(async move { + let body = + reqwest::Body::wrap_stream( + stream, + ); + + let request = client + .post(url) + .body(body) + .build()?; + + let ctx = RequestContext { + method: request + .method() + .clone(), + + url: request + .url() + .to_string(), + }; + + Ok((ctx, request)) + }) + }); + + let response = self + .retry_middleware + .execute(&self.client, factory) + .await?; + + Ok(response.json().await?) + } + + // ======================================================== + // Retry Predicate + // ======================================================== + + pub fn with_retry_predicate( + mut self, + predicate: F, + ) -> Self + where + F: Fn( + &RequestContext, + &RequestResult, + ) -> bool + + Send + + Sync + + 'static, + { + self.retry_middleware = self + .retry_middleware + .with_retry_predicate(predicate); + + self + } + + // ======================================================== + // Inner Client + // ======================================================== + + pub fn inner(&self) -> &Client { + &self.client + } +} \ No newline at end of file diff --git a/src/core/mod.rs b/src/core/mod.rs index 1807b9b..2f611c3 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -5,4 +5,5 @@ pub mod runtime; pub mod db; pub mod agent; mod providers; -mod errors; \ No newline at end of file +mod errors; +mod client; \ No newline at end of file diff --git a/src/core/transport/client.rs b/src/core/transport/client.rs index e0c4eea..e69de29 100644 --- a/src/core/transport/client.rs +++ b/src/core/transport/client.rs @@ -1,654 +0,0 @@ -use anyhow::{anyhow, Result}; -use bytes::Bytes; -use futures_util::stream::Stream; -use rand::{Rng, RngExt}; -use reqwest::{ - multipart, - Client, - Method, - Request, - RequestBuilder, - Response, - StatusCode, -}; -use serde::de::DeserializeOwned; -use serde::Serialize; -use std::future::Future; -use std::path::PathBuf; -use std::pin::Pin; -use std::time::Duration; -use log::{info, warn}; -use tokio::fs; - -// ============================================================ -// Retry Config -// ============================================================ - -#[derive(Debug, Clone)] -pub struct RetryConfig { - pub max_retries: u32, - - pub retry_interval: u64, - - pub use_exponential_backoff: bool, - - pub backoff_factor: f64, -} - -impl Default for RetryConfig { - fn default() -> Self { - Self { - max_retries: 3, - retry_interval: 1000, - use_exponential_backoff: true, - backoff_factor: 2.0, - } - } -} - -// ============================================================ -// Request Context -// ============================================================ - -#[derive(Debug, Clone)] -pub struct RequestContext { - pub method: Method, - pub url: String, -} - -// ============================================================ -// Request Result -// ============================================================ - -pub enum RequestResult { - Response(Response), - Error(reqwest::Error), -} - -// ============================================================ -// Retry Predicate -// ============================================================ - -pub type RetryPredicate = -Box< - dyn Fn(&RequestContext, &RequestResult) -> bool - + Send - + Sync - + 'static, ->; - -// ============================================================ -// Request Factory -// ============================================================ - -type BoxFuture = -Pin + Send>>; - -pub type RequestFactory = -Box< - dyn Fn() -> BoxFuture> - + Send - + Sync, ->; - -// ============================================================ -// Retry Middleware -// ============================================================ - -pub struct RetryMiddleware { - pub config: RetryConfig, - - pub predicate: Option, -} - -impl Default for RetryMiddleware { - fn default() -> Self { - Self { - config: RetryConfig::default(), - predicate: None, - } - } -} - -impl RetryMiddleware { - pub fn new(config: RetryConfig) -> Self { - Self { - config, - predicate: None, - } - } - - pub fn with_retry_predicate( - mut self, - f: F, - ) -> Self - where - F: Fn(&RequestContext, &RequestResult) -> bool - + Send - + Sync - + 'static, - { - self.predicate = Some(Box::new(f)); - - self - } - - // ======================================================== - // Full Jitter Backoff - // ======================================================== - - fn calculate_wait_time( - &self, - attempts: u32, - ) -> Duration { - let base = self.config.retry_interval as f64; - - let max_wait = if self - .config - .use_exponential_backoff - { - base - * self - .config - .backoff_factor - .powi(attempts as i32) - } else { - base - }; - - let mut rng = rand::rng(); - - let wait_ms = - rng.random_range(0.0..=max_wait); - - Duration::from_millis(wait_ms as u64) - } - - // ======================================================== - // Retry-After - // ======================================================== - - fn retry_after( - response: &Response, - ) -> Option { - let header = response - .headers() - .get("retry-after")?; - - let value = header.to_str().ok()?; - - let secs = value.parse::().ok()?; - - Some(Duration::from_secs(secs)) - } - - // ======================================================== - // Default Retry Rule - // ======================================================== - - fn should_retry( - &self, - ctx: &RequestContext, - result: &RequestResult, - ) -> bool { - if let Some(predicate) = &self.predicate { - return predicate(ctx, result); - } - - // 默认只 retry 幂等请求 - let retryable_method = matches!( - ctx.method, - Method::GET - | Method::HEAD - | Method::PUT - | Method::DELETE - ); - - if !retryable_method { - return false; - } - - match result { - RequestResult::Response(resp) => { - resp.status().is_server_error() - || resp.status() - == StatusCode::TOO_MANY_REQUESTS - } - - RequestResult::Error(err) => { - err.is_timeout() - || err.is_connect() - || err - .status() - .map(|s| s.is_server_error()) - .unwrap_or(false) - } - } - } - - // ======================================================== - // Execute - // ======================================================== - - pub async fn execute( - &self, - client: &Client, - factory: RequestFactory, - ) -> Result { - let mut attempt = 0; - - loop { - let (ctx, request) = factory().await?; - - info!( - method = %ctx.method, - url = %ctx.url, - attempt, - "sending request" - ); - - let result = - match client.execute(request).await { - Ok(resp) => { - RequestResult::Response(resp) - } - - Err(err) => { - RequestResult::Error(err) - } - }; - - let should_retry = - self.should_retry(&ctx, &result); - - if !should_retry - || attempt >= self.config.max_retries - { - return match result { - RequestResult::Response(resp) => { - Ok(resp) - } - - RequestResult::Error(err) => { - Err(anyhow!(err)) - } - }; - } - - attempt += 1; - - let wait = match &result { - RequestResult::Response(resp) => { - Self::retry_after(resp) - .unwrap_or_else(|| { - self.calculate_wait_time( - attempt, - ) - }) - } - - _ => self.calculate_wait_time(attempt), - }; - - warn!( - method = %ctx.method, - url = %ctx.url, - attempt, - wait_ms = wait.as_millis(), - "retrying request" - ); - - tokio::time::sleep(wait).await; - } - } -} - -// ============================================================ -// Http Client -// ============================================================ - -pub struct HttpClient { - client: Client, - - retry_middleware: RetryMiddleware, -} - -impl HttpClient { - pub fn new( - retry_config: Option, - ) -> Result { - let client = Client::builder() - .connect_timeout(Duration::from_secs( - 10, - )) - .timeout(Duration::from_secs(300)) - .pool_idle_timeout(Duration::from_secs( - 90, - )) - .tcp_keepalive(Duration::from_secs(60)) - .build()?; - - Ok(Self { - client, - - retry_middleware: retry_config - .map(RetryMiddleware::new) - .unwrap_or_default(), - }) - } - - // ======================================================== - // Core Send - // ======================================================== - - pub async fn send( - &self, - builder: RequestBuilder, - ) -> Result { - let builder = builder - .try_clone() - .ok_or_else(|| { - anyhow!( - "request builder cannot be cloned" - ) - })?; - - let factory: RequestFactory = - Box::new(move || { - let builder = builder - .try_clone() - .ok_or_else(|| { - anyhow!( - "request builder clone failed" - ) - }); - - Box::pin(async move { - let builder = builder?; - - let request = builder.build()?; - - let ctx = RequestContext { - method: request.method().clone(), - - url: request.url().to_string(), - }; - - Ok((ctx, request)) - }) - }); - - self.retry_middleware - .execute(&self.client, factory) - .await - } - - // ======================================================== - // GET JSON - // ======================================================== - - pub async fn get_json( - &self, - url: &str, - query: &impl Serialize, - ) -> Result - where - T: DeserializeOwned, - { - let response = self - .send(self.client.get(url).query(query)) - .await?; - - Ok(response.json().await?) - } - - // ======================================================== - // POST JSON - // ======================================================== - - pub async fn post_json( - &self, - url: &str, - body: &impl Serialize, - ) -> Result - where - T: DeserializeOwned, - { - let response = self - .send(self.client.post(url).json(body)) - .await?; - - Ok(response.json().await?) - } - - // ======================================================== - // GET TEXT - // ======================================================== - - pub async fn get_text( - &self, - url: &str, - ) -> Result { - let response = self - .send(self.client.get(url)) - .await?; - - Ok(response.text().await?) - } - - // ======================================================== - // GET BYTES - // ======================================================== - - pub async fn get_bytes( - &self, - url: &str, - ) -> Result { - let response = self - .send(self.client.get(url)) - .await?; - - Ok(response.bytes().await?) - } - - // ======================================================== - // SSE / STREAM - // ======================================================== - - pub async fn get_stream( - &self, - url: &str, - ) -> Result< - impl Stream< - Item = Result, - >, - > { - let response = self - .send(self.client.get(url)) - .await?; - - Ok(response.bytes_stream()) - } - - // ======================================================== - // Multipart Upload - // ======================================================== - - pub async fn upload_file( - &self, - url: &str, - field_name: &str, - file_path: impl Into, - ) -> Result - where - T: DeserializeOwned, - { - let path = file_path.into(); - - let client = self.client.clone(); - - let url = url.to_string(); - - let field_name = field_name.to_string(); - - let factory: RequestFactory = - Box::new(move || { - let client = client.clone(); - - let path = path.clone(); - - let url = url.clone(); - - let field_name = - field_name.clone(); - - Box::pin(async move { - let bytes = - fs::read(&path).await?; - - let filename = path - .file_name() - .unwrap_or_default() - .to_string_lossy() - .to_string(); - - let part = - multipart::Part::bytes( - bytes, - ) - .file_name(filename); - - let form = - multipart::Form::new() - .part( - field_name, - part, - ); - - let request = client - .post(url) - .multipart(form) - .build()?; - - let ctx = RequestContext { - method: request - .method() - .clone(), - - url: request - .url() - .to_string(), - }; - - Ok((ctx, request)) - }) - }); - - let response = self - .retry_middleware - .execute(&self.client, factory) - .await?; - - Ok(response.json().await?) - } - - // ======================================================== - // Stream Upload - // ======================================================== - - pub async fn upload_stream( - &self, - url: &str, - stream_factory: impl Fn() -> S - + Send - + Sync - + 'static, - ) -> Result - where - T: DeserializeOwned, - - S: Stream< - Item = Result< - Bytes, - std::io::Error, - >, - > + Send - + 'static, - { - let client = self.client.clone(); - - let url = url.to_string(); - - let factory: RequestFactory = - Box::new(move || { - let client = client.clone(); - - let url = url.clone(); - - let stream = stream_factory(); - - Box::pin(async move { - let body = - reqwest::Body::wrap_stream( - stream, - ); - - let request = client - .post(url) - .body(body) - .build()?; - - let ctx = RequestContext { - method: request - .method() - .clone(), - - url: request - .url() - .to_string(), - }; - - Ok((ctx, request)) - }) - }); - - let response = self - .retry_middleware - .execute(&self.client, factory) - .await?; - - Ok(response.json().await?) - } - - // ======================================================== - // Retry Predicate - // ======================================================== - - pub fn with_retry_predicate( - mut self, - predicate: F, - ) -> Self - where - F: Fn( - &RequestContext, - &RequestResult, - ) -> bool - + Send - + Sync - + 'static, - { - self.retry_middleware = self - .retry_middleware - .with_retry_predicate(predicate); - - self - } - - // ======================================================== - // Inner Client - // ======================================================== - - pub fn inner(&self) -> &Client { - &self.client - } -} \ No newline at end of file diff --git a/src/core/transport/errors.rs b/src/core/transport/errors.rs new file mode 100644 index 0000000..04f9a02 --- /dev/null +++ b/src/core/transport/errors.rs @@ -0,0 +1,19 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum TransportError { + #[error("http error: {0}")] + Http(#[from] reqwest::Error), + + #[error("request timeout")] + Timeout, + + #[error("rate limited")] + RateLimited, + + #[error("server error: {0}")] + Server(String), + + #[error("invalid response: {0}")] + InvalidResponse(String), +} \ No newline at end of file diff --git a/src/core/transport/http_transoprt.rs b/src/core/transport/http_transoprt.rs new file mode 100644 index 0000000..116136b --- /dev/null +++ b/src/core/transport/http_transoprt.rs @@ -0,0 +1,103 @@ +use std::time::Duration; + +use tokio::time::sleep; +use crate::core::transport::errors::TransportError; +use crate::core::transport::request::TransportRequest; +use crate::core::transport::response::TransportResponse; +use crate::core::transport::retry::RetryPolicy; +use crate::core::client::base_client::BaseClient; +#[derive(Clone)] +pub struct HttpTransport { + client: BaseClient, + retry: RetryPolicy, +} + +impl HttpTransport { + pub fn new( + client: BaseClient, + retry: RetryPolicy, + ) -> Self { + Self { client, retry } + } + + pub async fn execute( + &self, + req: TransportRequest, + ) -> Result { + let mut retries = 0; + + loop { + let request = self.build_request(req.clone())?; + + match self.client.execute(request).await { + Ok(resp) => { + let status = resp.status(); + + if status.as_u16() == 429 { + if retries >= self.retry.max_retries { + return Err(TransportError::RateLimited); + } + + retries += 1; + + sleep(self.backoff(retries)).await; + continue; + } + + if status.is_server_error() { + if retries >= self.retry.max_retries { + return Err( + TransportError::Server(status.to_string()) + ); + } + + retries += 1; + + sleep(self.backoff(retries)).await; + continue; + } + + let body = resp.bytes().await?; + + return Ok(TransportResponse { + status, + body, + }); + } + + Err(err) => { + if retries >= self.retry.max_retries { + return Err(TransportError::Http(err)); + } + + retries += 1; + + sleep(self.backoff(retries)).await; + } + } + } + } + + fn build_request( + &self, + req: TransportRequest, + ) -> Result { + let mut builder = self.client + .inner() + .request(req.method, req.url); + + for (k, v) in req.headers { + builder = builder.header(k, v); + } + + if let Some(body) = req.body { + builder = builder.body(body); + } + + Ok(builder.build()?) + } + + fn backoff(&self, retry: usize) -> Duration { + self.retry.base_delay * retry as u32 + } +} \ No newline at end of file diff --git a/src/core/transport/mod.rs b/src/core/transport/mod.rs index 2322d1e..d7162ee 100644 --- a/src/core/transport/mod.rs +++ b/src/core/transport/mod.rs @@ -1 +1,7 @@ -mod client; \ No newline at end of file +mod client; +mod errors; +mod request; +mod response; +mod retry; +mod http_transoprt; +mod stream; \ No newline at end of file diff --git a/src/core/transport/request.rs b/src/core/transport/request.rs new file mode 100644 index 0000000..769344c --- /dev/null +++ b/src/core/transport/request.rs @@ -0,0 +1,11 @@ +use bytes::Bytes; +use reqwest::Method; +use std::collections::HashMap; + +#[derive(Clone)] +pub struct TransportRequest { + pub method: Method, + pub url: String, + pub headers: HashMap, + pub body: Option, +} \ No newline at end of file diff --git a/src/core/transport/response.rs b/src/core/transport/response.rs new file mode 100644 index 0000000..25a74e6 --- /dev/null +++ b/src/core/transport/response.rs @@ -0,0 +1,7 @@ +use bytes::Bytes; +use reqwest::StatusCode; + +pub struct TransportResponse { + pub status: StatusCode, + pub body: Bytes, +} \ No newline at end of file diff --git a/src/core/transport/retry.rs b/src/core/transport/retry.rs new file mode 100644 index 0000000..8357848 --- /dev/null +++ b/src/core/transport/retry.rs @@ -0,0 +1,16 @@ +use std::time::Duration; + +#[derive(Clone)] +pub struct RetryPolicy { + pub max_retries: usize, + pub base_delay: Duration, +} + +impl Default for RetryPolicy { + fn default() -> Self { + Self { + max_retries: 3, + base_delay: Duration::from_millis(500), + } + } +} \ No newline at end of file diff --git a/src/core/transport/stream.rs b/src/core/transport/stream.rs new file mode 100644 index 0000000..694ba11 --- /dev/null +++ b/src/core/transport/stream.rs @@ -0,0 +1,17 @@ +use futures_util::StreamExt; + +pub async fn read_sse( + response: reqwest::Response, +) -> anyhow::Result<()> { + let mut stream = response.bytes_stream(); + + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + + let text = String::from_utf8_lossy(&chunk); + + println!("{}", text); + } + + Ok(()) +} \ No newline at end of file