diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 6126e667..7f21cafd 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1575,6 +1575,7 @@ name = "defguard-cli" version = "2.1.0" dependencies = [ "base64 0.22.1", + "chrono", "clap", "defguard-client-common", "defguard-client-config-sync", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 371d6696..dbd72fbb 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -4,6 +4,7 @@ default-members = ["cli", "client-cli", "daemon", "."] [workspace.dependencies] base64 = "0.22" +chrono = { version = "0.4", features = ["serde"] } clap = { version = "4.5", features = ["cargo", "derive", "env"] } defguard_wireguard_rs = "0.10" dirs-next = "2.0" @@ -76,7 +77,7 @@ redundant_closure = "warn" anyhow = "1.0" base64.workspace = true clap.workspace = true -chrono = { version = "0.4", features = ["serde"] } +chrono.workspace = true defguard-client-proto = { path = "client-proto" } defguard-client-core = { path = "core" } defguard-client-posture = { path = "enterprise/posture" } diff --git a/src-tauri/client-cli/Cargo.toml b/src-tauri/client-cli/Cargo.toml index cb40bcba..6b5a46cc 100644 --- a/src-tauri/client-cli/Cargo.toml +++ b/src-tauri/client-cli/Cargo.toml @@ -12,6 +12,7 @@ version.workspace = true name = "defguard-cli" [dependencies] +chrono.workspace = true clap = { workspace = true, features = ["cargo", "derive", "env"] } owo-colors = { version = "4", features = ["supports-colors"] } diff --git a/src-tauri/client-cli/src/main.rs b/src-tauri/client-cli/src/main.rs index 81850fd2..edf34dd8 100644 --- a/src-tauri/client-cli/src/main.rs +++ b/src-tauri/client-cli/src/main.rs @@ -6,13 +6,14 @@ use common::check_version_flag; mod brand; mod cli; mod commands; -mod config_poll; mod exit; mod logging; mod mfa; mod mfa_code; mod mfa_qr; +mod monitor; mod output; +mod polling; mod resolve; mod state; #[cfg(all(test, target_os = "linux"))] @@ -52,7 +53,8 @@ async fn main() -> ExitCode { } }; - config_poll::poll_config(&state).await; + polling::poll_config(&state).await; + monitor::tear_down_stale_connections(&state).await; // Dispatch command. match cli.command { diff --git a/src-tauri/client-cli/src/monitor.rs b/src-tauri/client-cli/src/monitor.rs new file mode 100644 index 00000000..432efb4b --- /dev/null +++ b/src-tauri/client-cli/src/monitor.rs @@ -0,0 +1,47 @@ +use chrono::Utc; +use defguard_core::connection::active_state::{active_state, ActiveConnectionInfo}; +use tracing::error; + +use crate::state::State; + +/// Determine whether a connection is stale based on its latest WireGuard handshake. +/// +/// Returns `None` when live backend stats are unavailable or the connection has no +/// recorded handshake, because in that case the CLI cannot safely decide whether the +/// connection is stale. +fn is_stale(connection: &ActiveConnectionInfo, _peer_alive_period: u32) -> Option { + let last_handshake = connection.stats.as_ref()?.last_handshake?; + let now: u64 = Utc::now().timestamp().try_into().ok()?; + + // Some(now.saturating_sub(last_handshake) > u64::from(peer_alive_period)) + Some(now.saturating_sub(last_handshake) > u64::from(20u64)) +} + +/// Disconnect active connections whose latest handshake is older than the configured +/// peer alive period. +/// +/// Connections without usable live stats are left untouched. Failures are logged and do +/// not stop cleanup of the remaining connections. +pub async fn tear_down_stale_connections(state: &State) { + let connections = match active_state(&state.pool).await { + Ok(connections) => connections, + Err(err) => { + error!("Failed to retrieve active connections: {err}"); + return; + } + }; + if connections.is_empty() { + return; + } + + for connection in connections { + if is_stale(&connection, state.app_config.peer_alive_period).is_some_and(|v| v) { + use defguard_core::connection::tear_down; + + let result = tear_down(&connection).await; + if let Err(err) = result { + error!("Error removing stale connection {}: {err}", connection.name); + } + } + } +} diff --git a/src-tauri/client-cli/src/config_poll.rs b/src-tauri/client-cli/src/polling.rs similarity index 100% rename from src-tauri/client-cli/src/config_poll.rs rename to src-tauri/client-cli/src/polling.rs diff --git a/src-tauri/core/src/connection/active_state.rs b/src-tauri/core/src/connection/active_state.rs index 46559079..ae826173 100644 --- a/src-tauri/core/src/connection/active_state.rs +++ b/src-tauri/core/src/connection/active_state.rs @@ -18,6 +18,8 @@ use objc2_network_extension::NEVPNStatus; #[cfg(not(target_os = "macos"))] use tonic::Code; +#[cfg(target_os = "macos")] +use crate::connection::apple::tunnel_stats; #[cfg(target_os = "macos")] use crate::database::models::get_all_tunnels_locations; #[cfg(not(target_os = "macos"))] @@ -62,7 +64,7 @@ pub struct InterfaceStats { /// **unfiltered** snapshot of all managed interfaces (unlike `ReadInterfaceData`, which /// drops peers that haven't completed a handshake or whose stats haven't changed). /// -/// On macOS the Network Extension path is stubbed (pending the NE spike). +/// On macOS this queries Network Extension managers and asks connected providers for stats. #[cfg(target_os = "macos")] pub async fn active_state(_pool: &DbPool) -> Result, Error> { let (tunnels, locations) = get_all_tunnels_locations().await; @@ -73,12 +75,20 @@ pub async fn active_state(_pool: &DbPool) -> Result, E let mut result = Vec::new(); for location in locations { if let Some(NEVPNStatus::Connected) = location.status() { + let stats = tunnel_stats(location.id, &ConnectionType::Location).map(|stats| { + InterfaceStats { + listen_port: 0, + tx_bytes: stats.tx_bytes, + rx_bytes: stats.rx_bytes, + last_handshake: (stats.last_handshake != 0).then_some(stats.last_handshake), + } + }); let info = ActiveConnectionInfo { connection_type: ConnectionType::Location, target_id: location.id, name: location.name, interface_name: String::new(), - stats: None, // TODO + stats, }; result.push(info); } @@ -86,12 +96,19 @@ pub async fn active_state(_pool: &DbPool) -> Result, E for tunnel in tunnels { if let Some(NEVPNStatus::Connected) = tunnel.status() { + let stats = + tunnel_stats(tunnel.id, &ConnectionType::Tunnel).map(|stats| InterfaceStats { + listen_port: 0, + tx_bytes: stats.tx_bytes, + rx_bytes: stats.rx_bytes, + last_handshake: (stats.last_handshake != 0).then_some(stats.last_handshake), + }); let info = ActiveConnectionInfo { connection_type: ConnectionType::Tunnel, target_id: tunnel.id, name: tunnel.name, interface_name: String::new(), - stats: None, // TODO + stats, }; result.push(info); } diff --git a/src-tauri/core/src/connection/apple.rs b/src-tauri/core/src/connection/apple.rs index 70b9d34a..ae9420eb 100644 --- a/src-tauri/core/src/connection/apple.rs +++ b/src-tauri/core/src/connection/apple.rs @@ -17,17 +17,21 @@ const OBSERVER_CLEANUP_INTERVAL: Duration = Duration::from_secs(30); use block2::RcBlock; use objc2::{rc::Retained, runtime::ProtocolObject}; use objc2_foundation::{ - ns_string, NSArray, NSDate, NSError, NSNotification, NSNotificationCenter, NSNumber, + ns_string, NSArray, NSData, NSDate, NSError, NSNotification, NSNotificationCenter, NSNumber, NSObjectProtocol, NSOperationQueue, NSRunLoop, NSString, }; use objc2_network_extension::{ - NETunnelProviderManager, NETunnelProviderProtocol, NEVPNConnection, + NETunnelProviderManager, NETunnelProviderProtocol, NETunnelProviderSession, NEVPNConnection, NEVPNStatusDidChangeNotification, }; +use serde::Deserialize; -use crate::database::{ - models::{location::Location, tunnel::Tunnel, Id}, - DB_POOL, +use crate::{ + database::{ + models::{location::Location, tunnel::Tunnel, Id}, + DB_POOL, + }, + ConnectionType, }; pub const PLUGIN_BUNDLE_ID: &str = "net.defguard.VPNExtension"; @@ -147,6 +151,92 @@ pub fn spawn_runloop_and_wait_for(semaphore: &Arc) { } } +/// Tunnel statistics shared with VPNExtension (written in Swift). +#[derive(Deserialize)] +#[repr(C)] +#[serde(rename_all = "camelCase")] +pub struct Stats { + pub location_id: Option, + pub tunnel_id: Option, + pub tx_bytes: u64, + pub rx_bytes: u64, + pub last_handshake: u64, +} + +/// Retrieve VPN tunnel statistics from VPNExtension. +pub fn tunnel_stats(id: Id, connection_type: &ConnectionType) -> Option { + let new_stats = Arc::new(Mutex::new(None)); + let plugin_bundle_id = ns_string!(PLUGIN_BUNDLE_ID); + + let new_stats_clone = Arc::clone(&new_stats); + + let finished = Arc::new(AtomicBool::new(false)); + let finished_clone = Arc::clone(&finished); + + let response_handler = RcBlock::new(move |data_ptr: *mut NSData| { + if let Some(data) = unsafe { data_ptr.as_ref() } { + if let Ok(stats) = serde_json::from_slice(data.to_vec().as_slice()) { + if let Ok(mut new_stats_locked) = new_stats_clone.lock() { + *new_stats_locked = Some(stats); + } + } else { + warn!("Failed to deserialize tunnel stats"); + } + } else { + debug!("No data received in tunnel stats response, skipping"); + } + finished_clone.store(true, Ordering::Release); + }); + + let manager = manager_for_key_and_value( + match connection_type { + ConnectionType::Location => LOCATION_ID, + ConnectionType::Tunnel => TUNNEL_ID, + }, + id, + )?; + + let vpn_protocol = (unsafe { manager.protocolConfiguration() })?; + let Ok(tunnel_protocol) = vpn_protocol.downcast::() else { + error!("Failed to downcast to NETunnelProviderProtocol"); + return None; + }; + + // Sometimes all managers from all apps come through, so filter by bundle ID. + if let Some(bundle_id) = unsafe { tunnel_protocol.providerBundleIdentifier() } { + if &*bundle_id != plugin_bundle_id { + return None; + } + } + + let Ok(session) = unsafe { manager.connection() }.downcast::() else { + error!("Failed to downcast to NETunnelProviderSession"); + return None; + }; + + let message_data = NSData::new(); + if unsafe { + session.sendProviderMessage_returnError_responseHandler( + &message_data, + None, + Some(&response_handler), + ) + } { + debug!("Message sent to NETunnelProviderSession"); + } else { + error!("Failed to send to NETunnelProviderSession while requesting stats"); + } + + // Wait for the response handler to complete. + while !finished.load(Ordering::Acquire) { + spin_loop(); + } + + new_stats + .lock() + .map_or(None, |mut new_stats_locked| new_stats_locked.take()) +} + /// Handle VPN status change. fn vpn_status_change_handler(notification: &NSNotification) { let name = notification.name(); diff --git a/src-tauri/src/apple.rs b/src-tauri/src/apple.rs index d100dcae..ac1dde28 100644 --- a/src-tauri/src/apple.rs +++ b/src-tauri/src/apple.rs @@ -1,28 +1,13 @@ //! Interchangeability and communication with VPNExtension (written in Swift). -use std::{ - collections::HashMap, - hint::spin_loop, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, Mutex, - }, - time::Duration, -}; +use std::{collections::HashMap, time::Duration}; -use block2::RcBlock; use defguard_client_core::connection::{ active_connections::find_connection, - apple::{ - manager_for_key_and_value, LOCATION_ID, PLUGIN_BUNDLE_ID, TUNNEL_ID, VPN_STATE_UPDATE_COMMS, - }, + apple::{manager_for_key_and_value, LOCATION_ID, TUNNEL_ID, VPN_STATE_UPDATE_COMMS}, }; use objc2::rc::Retained; -use objc2_foundation::{ns_string, NSData}; -use objc2_network_extension::{ - NETunnelProviderManager, NETunnelProviderProtocol, NETunnelProviderSession, NEVPNStatus, -}; -use serde::Deserialize; +use objc2_network_extension::{NETunnelProviderManager, NEVPNStatus}; use tauri::{AppHandle, Emitter, Manager}; use tokio::time::sleep; use tracing::Level; @@ -265,18 +250,6 @@ async fn sync_connections_with_system(app_handle: &AppHandle) { } } -/// Tunnel statistics shared with VPNExtension (written in Swift). -#[derive(Deserialize)] -#[repr(C)] -#[serde(rename_all = "camelCase")] -pub(crate) struct Stats { - pub(crate) location_id: Option, - pub(crate) tunnel_id: Option, - pub(crate) tx_bytes: u64, - pub(crate) rx_bytes: u64, - pub(crate) last_handshake: u64, -} - #[must_use] pub fn get_managers_for_tunnels_and_locations( tunnels: &[Tunnel], @@ -298,79 +271,3 @@ pub fn get_managers_for_tunnels_and_locations( managers } - -/// Retrieve VPN tunnel statistics from VPNExtension. -pub(crate) fn tunnel_stats(id: Id, connection_type: &ConnectionType) -> Option { - let new_stats = Arc::new(Mutex::new(None)); - let plugin_bundle_id = ns_string!(PLUGIN_BUNDLE_ID); - - let new_stats_clone = Arc::clone(&new_stats); - - let finished = Arc::new(AtomicBool::new(false)); - let finished_clone = Arc::clone(&finished); - - let response_handler = RcBlock::new(move |data_ptr: *mut NSData| { - if let Some(data) = unsafe { data_ptr.as_ref() } { - if let Ok(stats) = serde_json::from_slice(data.to_vec().as_slice()) { - if let Ok(mut new_stats_locked) = new_stats_clone.lock() { - *new_stats_locked = Some(stats); - } - } else { - warn!("Failed to deserialize tunnel stats"); - } - } else { - debug!("No data received in tunnel stats response, skipping"); - } - finished_clone.store(true, Ordering::Release); - }); - - let manager = manager_for_key_and_value( - match connection_type { - ConnectionType::Location => LOCATION_ID, - ConnectionType::Tunnel => TUNNEL_ID, - }, - id, - )?; - - let vpn_protocol = (unsafe { manager.protocolConfiguration() })?; - let Ok(tunnel_protocol) = vpn_protocol.downcast::() else { - error!("Failed to downcast to NETunnelProviderProtocol"); - return None; - }; - - // Sometimes all managers from all apps come through, so filter by bundle ID. - if let Some(bundle_id) = unsafe { tunnel_protocol.providerBundleIdentifier() } { - if &*bundle_id != plugin_bundle_id { - return None; - } - } - - let Ok(session) = unsafe { manager.connection() }.downcast::() else { - error!("Failed to downcast to NETunnelProviderSession"); - return None; - }; - - let message_data = NSData::new(); - if unsafe { - session.sendProviderMessage_returnError_responseHandler( - &message_data, - None, - Some(&response_handler), - ) - } { - debug!("Message sent to NETunnelProviderSession"); - } else { - error!("Failed to send to NETunnelProviderSession while requesting stats"); - } - - // Wait for all handlers to complete. - while !finished.load(Ordering::Acquire) { - spin_loop(); - } - - let stats = new_stats - .lock() - .map_or(None, |mut new_stats_locked| new_stats_locked.take()); - - stats -} diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index f68a7db8..0d3a118f 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -30,8 +30,6 @@ use windows_service::{ #[cfg(windows)] use windows_sys::Win32::Foundation::ERROR_SERVICE_DOES_NOT_EXIST; -#[cfg(target_os = "macos")] -use crate::apple::tunnel_stats; #[cfg(not(target_os = "macos"))] use crate::database::models::{ location_stats::peer_to_location_stats, tunnel::peer_to_tunnel_stats, @@ -48,6 +46,8 @@ use crate::{ log_watcher::service_log_watcher::spawn_log_watcher_task, ConnectionType, }; +#[cfg(target_os = "macos")] +use defguard_client_core::connection::apple::tunnel_stats; // Work-around MFA propagation delay. FIXME: remove once Core API is corrected. #[cfg(target_os = "macos")]