Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ hex = "0.4"
# Utilities
chrono = { version = "0.4", features = ["serde"] }
url = "2"
uuid = { version = "1", features = ["v4"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
log = "0.4.29"
regex = "1.12.3"
parking_lot = "0.12"

[dev-dependencies]
tokio-test = "0.4"
Expand Down
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! HTTP client for Bybit REST API.

use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
use serde::de::DeserializeOwned;
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tracing::{debug, warn};

use crate::auth::{generate_signature, get_timestamp};
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mod config;
mod constants;
mod error;
mod models;
mod utils;

// API modules
pub mod api;
Expand Down
2 changes: 1 addition & 1 deletion src/models/trade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl PlaceOrderParams {
None => {
return Err(BybitError::InvalidParam(
"price is required for limit orders".into(),
))
));
}
Some(p) => {
let price: Decimal = p.parse().map_err(|_| {
Expand Down
67 changes: 67 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use chrono::{DateTime, NaiveDate, TimeDelta, Utc};
use log::{debug};

// use regex::Captures;
use regex::Regex;
use std::str::FromStr;
use std::sync::OnceLock;

#[derive(Debug)]
pub enum OptionType {
Put,
Call,
}

#[derive(Debug)]
pub struct BybitInfo {
pub base: String,
pub expire: DateTime<Utc>,
pub strike_price: f32,
pub side: OptionType,
pub quote: Option<String>,
}

pub fn parse_expiration_date(date: &str) -> DateTime<Utc> {
let naive_date = NaiveDate::parse_from_str(date, "%d%b%y")
.expect("error parsing expire date from bybit symbol");

return naive_date
.and_hms_opt(8, 0, 0)
.expect("error creating utc datetime object from bybit symbol")
.and_utc();
}

pub fn calculate_years_to_maturity(expire: DateTime<Utc>) -> f32 {
debug!("expire date time obj: {}", expire);
let time_to_expiration: TimeDelta = expire - Utc::now();
debug!("time_to_expiration: {}", time_to_expiration);
let seconds_to_expiration = time_to_expiration.num_seconds();
debug!("seconds_to_expiration: {}", seconds_to_expiration);


let years_to_expiration = (seconds_to_expiration) as f32 / (60 * 60 * 24 * 365) as f32;
debug!("years_to_expiration: {}", years_to_expiration);
return years_to_expiration;
}

pub fn extract_bybit_info(symbol: &str) -> Option<BybitInfo> {
static RE: OnceLock<Regex> = OnceLock::new();

let re = RE.get_or_init(|| {
Regex::new(r"(?<base>\w+)-(?<expire>\d+\w+\d+)-(?<strike_price>\d+\.?\d*)-(?<side>C|P)(?:-(?<quote>USDT))?")
.expect("invalid regex extracting bybit infos from symbol!")
});

re.captures(symbol).map(|caps| BybitInfo {
base: caps["base"].to_string(),
expire: parse_expiration_date(&caps["expire"]),
strike_price: f32::from_str(&caps["strike_price"]).unwrap(),
side: match &caps["side"] {
"C" => OptionType::Call,
"P" => OptionType::Put,
_ => unreachable!(),
},
quote: caps.name("quote").map(|m| m.as_str().to_string()),
})
}

99 changes: 88 additions & 11 deletions src/websocket/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! WebSocket client implementation.

use futures_util::{SinkExt, StreamExt};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::net::TcpStream;
Expand All @@ -13,6 +14,7 @@ use crate::auth::{generate_ws_signature, get_timestamp};
use crate::config::WsConfig;
use crate::error::{BybitError, Result};
use crate::websocket::models::*;
use crate::{MAINNET_WS_TRADE, TESTNET_WS_TRADE};

type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
type Callback = Arc<dyn Fn(WsMessage) + Send + Sync>;
Expand All @@ -24,6 +26,7 @@ pub struct BybitWebSocket {
callbacks: Arc<RwLock<HashMap<String, Callback>>>,
tx: Option<mpsc::Sender<Message>>,
is_connected: Arc<RwLock<bool>>,
is_trade: bool,
}

impl BybitWebSocket {
Expand All @@ -35,6 +38,7 @@ impl BybitWebSocket {
callbacks: Arc::new(RwLock::new(HashMap::new())),
tx: None,
is_connected: Arc::new(RwLock::new(false)),
is_trade: false,
}
}

Expand All @@ -46,6 +50,7 @@ impl BybitWebSocket {
callbacks: Arc::new(RwLock::new(HashMap::new())),
tx: None,
is_connected: Arc::new(RwLock::new(false)),
is_trade: url == MAINNET_WS_TRADE || url == TESTNET_WS_TRADE,
}
}

Expand Down Expand Up @@ -160,6 +165,9 @@ impl BybitWebSocket {
.get("success")
.and_then(|v| v.as_bool())
.unwrap_or(false)
|| json.get("retCode").and_then(|v| v.as_i64()) == Some(0)
// ^^^ this is for *_WS_TRADE ^^^
// https://bybit-exchange.github.io/docs/v5/websocket/trade/guideline#response-parameters
{
info!("Authentication successful");
} else {
Expand Down Expand Up @@ -201,6 +209,8 @@ impl BybitWebSocket {
}
}
}

debug!("{:#?}", text);
}
Ok(Message::Ping(_)) => {
debug!("Received ping frame");
Expand Down Expand Up @@ -237,16 +247,27 @@ impl BybitWebSocket {
let expires = get_timestamp() + 10000;
let signature = generate_ws_signature(api_secret, expires);

let auth_msg = WsAuthRequest {
req_id: uuid::Uuid::new_v4().to_string(),
op: "auth".to_string(),
args: vec![
serde_json::Value::String(api_key.clone()),
serde_json::Value::Number(expires.into()),
serde_json::Value::String(signature),
],
let auth_msg= if self.is_trade {
AuthRequest::Trade(WsTradeAuthRequest {
req_id: uuid::Uuid::new_v4().to_string(),
op: "auth".to_string(),
args: vec![
serde_json::Value::String(api_key.clone()),
serde_json::Value::Number(expires.into()),
serde_json::Value::String(signature),
],
})
} else {
AuthRequest::Public(WsAuthRequest {
req_id: uuid::Uuid::new_v4().to_string(),
op: "auth".to_string(),
args: vec![
serde_json::Value::String(api_key.clone()),
serde_json::Value::Number(expires.into()),
serde_json::Value::String(signature),
],
})
};

let msg = serde_json::to_string(&auth_msg).map_err(|e| BybitError::Parse(e.to_string()))?;

self.send(msg).await?;
Expand Down Expand Up @@ -291,6 +312,50 @@ impl BybitWebSocket {
self.send(msg).await
}

pub async fn subscribe_mut<F>(&mut self, topics: Vec<String>, callback: F) -> Result<()>
where
F: FnMut(WsMessage) + Send + Sync + 'static,
{
// 1. Wrap the FnMut in a Mutex to "convert" it to an Fn closure
let callback_mutable = Mutex::new(callback);

// 2. Create an Fn closure that locks the mutex and calls the inner FnMut
let wrapped_callback = move |msg: WsMessage| {
let mut cb = callback_mutable.lock();
(&mut *cb)(msg);
};

// 3. Wrap in Arc and cast to your existing Callback type
let callback_arc = Arc::new(wrapped_callback) as Callback;

// --- The rest of the logic remains the same as your original function ---

// Register callbacks
{
let mut cbs = self.callbacks.write().await;
for topic in &topics {
cbs.insert(topic.clone(), callback_arc.clone());
}
}

// Store subscriptions
{
let mut subs = self.subscriptions.write().await;
subs.extend(topics.clone());
}

// Send subscription request
let sub_msg = WsRequest {
req_id: uuid::Uuid::new_v4().to_string(),
op: "subscribe".to_string(),
args: topics,
};

let msg = serde_json::to_string(&sub_msg).map_err(|e| BybitError::Parse(e.to_string()))? ;

self.send(msg).await
}

/// Unsubscribe from topics.
pub async fn unsubscribe(&mut self, topics: Vec<String>) -> Result<()> {
// Remove callbacks
Expand All @@ -314,8 +379,7 @@ impl BybitWebSocket {
args: topics,
};

let msg =
serde_json::to_string(&unsub_msg).map_err(|e| BybitError::Parse(e.to_string()))?;
let msg = serde_json::to_string(&unsub_msg).map_err(|e| BybitError::Parse(e.to_string()))?;

self.send(msg).await
}
Expand All @@ -332,6 +396,19 @@ impl BybitWebSocket {
Ok(())
}



pub async fn send_order(&self, order: WsTradeOrder) -> Result<()> {
debug!("{:#?}", order);
if !self.is_trade {
error!("can t execute a trade on a non trade socket");
return Err(BybitError::Parse(("can t execute a trade on a non trade socket".to_string())));
}
let msg = serde_json::to_string(&order).map_err(|e| BybitError::Parse(e.to_string()))?;
self.send(msg).await

}

/// Check if connected.
pub async fn is_connected(&self) -> bool {
*self.is_connected.read().await
Expand Down
Loading