diff --git a/Cargo.toml b/Cargo.toml index c1740b3..99b1f47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,10 @@ hex = "0.4" # Utilities chrono = { version = "0.4", features = ["serde"] } url = "2" -uuid = { version = "1", features = ["v4"] } +uuid = { version = "1", features = ["serde", "v4", "v7"] } +log = "0.4.29" +regex = "1.12.3" +parking_lot = "0.12" [dev-dependencies] tokio-test = "0.4" diff --git a/src/client.rs b/src/client.rs index 989c2b6..a9c2372 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,8 +1,8 @@ //! HTTP client for Bybit REST API. -use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; -use serde::de::DeserializeOwned; +use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue}; use serde::Serialize; +use serde::de::DeserializeOwned; use tracing::{debug, warn}; use crate::auth::{generate_signature, get_timestamp}; diff --git a/src/lib.rs b/src/lib.rs index b3a47bf..0ad151c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,7 @@ mod config; mod constants; mod error; mod models; +mod utils; // API modules pub mod api; diff --git a/src/models/trade.rs b/src/models/trade.rs index 174ea41..b468d15 100644 --- a/src/models/trade.rs +++ b/src/models/trade.rs @@ -149,7 +149,7 @@ impl PlaceOrderParams { None => { return Err(BybitError::InvalidParam( "price is required for limit orders".into(), - )) + )); } Some(p) => { let price: Decimal = p.parse().map_err(|_| { diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..fa2984d --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,67 @@ +use chrono::{DateTime, NaiveDate, TimeDelta, Utc}; +use log::{debug}; + +// use regex::Captures; +use regex::Regex; +use std::str::FromStr; +use std::sync::OnceLock; + +#[derive(Debug)] +pub enum OptionType { + Put, + Call, +} + +#[derive(Debug)] +pub struct BybitInfo { + pub base: String, + pub expire: DateTime, + pub strike_price: f32, + pub side: OptionType, + pub quote: Option, +} + +pub fn parse_expiration_date(date: &str) -> DateTime { + let naive_date = NaiveDate::parse_from_str(date, "%d%b%y") + .expect("error parsing expire date from bybit symbol"); + + return naive_date + .and_hms_opt(8, 0, 0) + .expect("error creating utc datetime object from bybit symbol") + .and_utc(); +} + +pub fn calculate_years_to_maturity(expire: DateTime) -> f32 { + debug!("expire date time obj: {}", expire); + let time_to_expiration: TimeDelta = expire - Utc::now(); + debug!("time_to_expiration: {}", time_to_expiration); + let seconds_to_expiration = time_to_expiration.num_seconds(); + debug!("seconds_to_expiration: {}", seconds_to_expiration); + + + let years_to_expiration = (seconds_to_expiration) as f32 / (60 * 60 * 24 * 365) as f32; + debug!("years_to_expiration: {}", years_to_expiration); + return years_to_expiration; +} + +pub fn extract_bybit_info(symbol: &str) -> Option { + static RE: OnceLock = OnceLock::new(); + + let re = RE.get_or_init(|| { + Regex::new(r"(?\w+)-(?\d+\w+\d+)-(?\d+\.?\d*)-(?C|P)(?:-(?USDT))?") + .expect("invalid regex extracting bybit infos from symbol!") + }); + + re.captures(symbol).map(|caps| BybitInfo { + base: caps["base"].to_string(), + expire: parse_expiration_date(&caps["expire"]), + strike_price: f32::from_str(&caps["strike_price"]).unwrap(), + side: match &caps["side"] { + "C" => OptionType::Call, + "P" => OptionType::Put, + _ => unreachable!(), + }, + quote: caps.name("quote").map(|m| m.as_str().to_string()), + }) +} + diff --git a/src/websocket/client.rs b/src/websocket/client.rs index ef8e7bc..b17eb0d 100644 --- a/src/websocket/client.rs +++ b/src/websocket/client.rs @@ -1,6 +1,7 @@ //! WebSocket client implementation. use futures_util::{SinkExt, StreamExt}; +use parking_lot::Mutex; use std::collections::HashMap; use std::sync::Arc; use tokio::net::TcpStream; @@ -13,6 +14,7 @@ use crate::auth::{generate_ws_signature, get_timestamp}; use crate::config::WsConfig; use crate::error::{BybitError, Result}; use crate::websocket::models::*; +use crate::{MAINNET_WS_TRADE, TESTNET_WS_TRADE}; type WsStream = WebSocketStream>; type Callback = Arc; @@ -24,6 +26,7 @@ pub struct BybitWebSocket { callbacks: Arc>>, tx: Option>, is_connected: Arc>, + is_trade: bool, } impl BybitWebSocket { @@ -35,6 +38,7 @@ impl BybitWebSocket { callbacks: Arc::new(RwLock::new(HashMap::new())), tx: None, is_connected: Arc::new(RwLock::new(false)), + is_trade: false, } } @@ -46,6 +50,7 @@ impl BybitWebSocket { callbacks: Arc::new(RwLock::new(HashMap::new())), tx: None, is_connected: Arc::new(RwLock::new(false)), + is_trade: url == MAINNET_WS_TRADE || url == TESTNET_WS_TRADE, } } @@ -160,6 +165,9 @@ impl BybitWebSocket { .get("success") .and_then(|v| v.as_bool()) .unwrap_or(false) + || json.get("retCode").and_then(|v| v.as_i64()) == Some(0) + // ^^^ this is for *_WS_TRADE ^^^ + // https://bybit-exchange.github.io/docs/v5/websocket/trade/guideline#response-parameters { info!("Authentication successful"); } else { @@ -201,6 +209,8 @@ impl BybitWebSocket { } } } + + debug!("{:#?}", text); } Ok(Message::Ping(_)) => { debug!("Received ping frame"); @@ -237,16 +247,27 @@ impl BybitWebSocket { let expires = get_timestamp() + 10000; let signature = generate_ws_signature(api_secret, expires); - let auth_msg = WsAuthRequest { - req_id: uuid::Uuid::new_v4().to_string(), - op: "auth".to_string(), - args: vec![ - serde_json::Value::String(api_key.clone()), - serde_json::Value::Number(expires.into()), - serde_json::Value::String(signature), - ], + let auth_msg= if self.is_trade { + AuthRequest::Trade(WsTradeAuthRequest { + req_id: uuid::Uuid::new_v4().to_string(), + op: "auth".to_string(), + args: vec![ + serde_json::Value::String(api_key.clone()), + serde_json::Value::Number(expires.into()), + serde_json::Value::String(signature), + ], + }) + } else { + AuthRequest::Public(WsAuthRequest { + req_id: uuid::Uuid::new_v4().to_string(), + op: "auth".to_string(), + args: vec![ + serde_json::Value::String(api_key.clone()), + serde_json::Value::Number(expires.into()), + serde_json::Value::String(signature), + ], + }) }; - let msg = serde_json::to_string(&auth_msg).map_err(|e| BybitError::Parse(e.to_string()))?; self.send(msg).await?; @@ -291,6 +312,50 @@ impl BybitWebSocket { self.send(msg).await } + pub async fn subscribe_mut(&mut self, topics: Vec, callback: F) -> Result<()> +where + F: FnMut(WsMessage) + Send + Sync + 'static, +{ + // 1. Wrap the FnMut in a Mutex to "convert" it to an Fn closure + let callback_mutable = Mutex::new(callback); + + // 2. Create an Fn closure that locks the mutex and calls the inner FnMut + let wrapped_callback = move |msg: WsMessage| { + let mut cb = callback_mutable.lock(); + (&mut *cb)(msg); + }; + + // 3. Wrap in Arc and cast to your existing Callback type + let callback_arc = Arc::new(wrapped_callback) as Callback; + + // --- The rest of the logic remains the same as your original function --- + + // Register callbacks + { + let mut cbs = self.callbacks.write().await; + for topic in &topics { + cbs.insert(topic.clone(), callback_arc.clone()); + } + } + + // Store subscriptions + { + let mut subs = self.subscriptions.write().await; + subs.extend(topics.clone()); + } + + // Send subscription request + let sub_msg = WsRequest { + req_id: uuid::Uuid::new_v4().to_string(), + op: "subscribe".to_string(), + args: topics, + }; + + let msg = serde_json::to_string(&sub_msg).map_err(|e| BybitError::Parse(e.to_string()))? ; + + self.send(msg).await +} + /// Unsubscribe from topics. pub async fn unsubscribe(&mut self, topics: Vec) -> Result<()> { // Remove callbacks @@ -314,8 +379,7 @@ impl BybitWebSocket { args: topics, }; - let msg = - serde_json::to_string(&unsub_msg).map_err(|e| BybitError::Parse(e.to_string()))?; + let msg = serde_json::to_string(&unsub_msg).map_err(|e| BybitError::Parse(e.to_string()))?; self.send(msg).await } @@ -332,6 +396,19 @@ impl BybitWebSocket { Ok(()) } + + + pub async fn send_order(&self, order: WsTradeOrder) -> Result<()> { + debug!("{:#?}", order); + if !self.is_trade { + error!("can t execute a trade on a non trade socket"); + return Err(BybitError::Parse(("can t execute a trade on a non trade socket".to_string()))); + } + let msg = serde_json::to_string(&order).map_err(|e| BybitError::Parse(e.to_string()))?; + self.send(msg).await + + } + /// Check if connected. pub async fn is_connected(&self) -> bool { *self.is_connected.read().await diff --git a/src/websocket/models.rs b/src/websocket/models.rs index 3140ffb..06ad322 100644 --- a/src/websocket/models.rs +++ b/src/websocket/models.rs @@ -24,6 +24,31 @@ pub struct WsAuthRequest { pub args: Vec, } +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WsTradeAuthRequest { + pub req_id: String, + pub op: String, + pub args: Vec, +} + +#[derive(Serialize)] +#[serde(untagged)] // Fondamentale per mantenere il formato JSON originale +pub enum AuthRequest { + Trade(WsTradeAuthRequest), + Public(WsAuthRequest), +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WsTradeAuthResponse { + pub req_id: String, + pub ret_code: i32, + pub ret_msg: String, + pub op: String, + pub conn_id: String, +} + /// WebSocket response. #[derive(Debug, Clone, Deserialize)] pub struct WsResponse { @@ -103,6 +128,140 @@ pub struct WsPong { pub op: Option, } + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WsTradeOrderHeader { + #[serde(rename = "X-BAPI-TIMESTAMP")] + pub x_bapi_timestamp: String, + // #[serde(rename = "X-BAPI-RECV-WINDOW")] + // pub x_bapi_recv_window: String, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum WsTradeOrderCategory { + Spot, + Linear, + Inverse, + Option, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum WsTradeOrderOp { + Create, + Amend, + Delete +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WsTradeOrderArgs { + pub category: String, // linear, inverse, spot, option + pub symbol: String, + pub side: String, // Buy, Sell + pub order_type: String, // Market, Limit + pub qty: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub price: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub is_leverage: Option, // 0: false, 1: true + + #[serde(skip_serializing_if = "Option::is_none")] + pub market_unit: Option, // baseCoin, quoteCoin + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub time_in_force: Option, // GTC, IOC, FOK, PostOnly + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub order_link_id: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub take_profit: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub stop_loss: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub tp_trigger_by: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub sl_trigger_by: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub reduce_only: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub close_on_trigger: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub position_idx: Option, // 0, 1, 2 + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub trigger_price: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub trigger_by: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub tp_limit_price: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub sl_limit_price: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub tp_order_type: Option, + + // #[serde(skip_serializing_if = "Option::is_none")] + // pub sl_order_type: Option, + +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct OptionTickerData { + pub symbol: String, + pub ask_iv: String, + pub ask_price: String, + pub ask_size: String, + pub bid_iv: String, + pub bid_price: String, + + pub bid_size: String, + + pub delta: String, + + pub gamma: String, + + pub theta: String, + + pub vega: String, + + pub mark_price: String, + + pub index_price: String, + + pub underlying_price: String, + + pub open_interest: String, + + pub volume24h: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WsTradeOrder { + pub req_id: uuid::Uuid, + pub header: WsTradeOrderHeader, + /// order.create order.amend order.cancel + pub op: String, + pub args: Vec + + +} + /// Check if message is a pong response. pub fn is_pong(msg: &serde_json::Value) -> bool { if let Some(op) = msg.get("op").and_then(|v| v.as_str()) {