Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ default-members = ["cli", "client-cli", "daemon", "."]

[workspace.dependencies]
base64 = "0.22"
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.5", features = ["cargo", "derive", "env"] }
defguard_wireguard_rs = "0.10"
dirs-next = "2.0"
Expand Down Expand Up @@ -76,7 +77,7 @@ redundant_closure = "warn"
anyhow = "1.0"
base64.workspace = true
clap.workspace = true
chrono = { version = "0.4", features = ["serde"] }
chrono.workspace = true
defguard-client-proto = { path = "client-proto" }
defguard-client-core = { path = "core" }
defguard-client-posture = { path = "enterprise/posture" }
Expand Down
1 change: 1 addition & 0 deletions src-tauri/client-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ version.workspace = true
name = "defguard-cli"

[dependencies]
chrono.workspace = true
clap = { workspace = true, features = ["cargo", "derive", "env"] }
owo-colors = { version = "4", features = ["supports-colors"] }

Expand Down
6 changes: 4 additions & 2 deletions src-tauri/client-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use common::check_version_flag;
mod brand;
mod cli;
mod commands;
mod config_poll;
mod exit;
mod logging;
mod mfa;
mod mfa_code;
mod mfa_qr;
mod monitor;
mod output;
mod polling;
mod resolve;
mod state;
#[cfg(all(test, target_os = "linux"))]
Expand Down Expand Up @@ -52,7 +53,8 @@ async fn main() -> ExitCode {
}
};

config_poll::poll_config(&state).await;
polling::poll_config(&state).await;
monitor::tear_down_stale_connections(&state).await;

// Dispatch command.
match cli.command {
Expand Down
47 changes: 47 additions & 0 deletions src-tauri/client-cli/src/monitor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use chrono::Utc;
use defguard_core::connection::active_state::{active_state, ActiveConnectionInfo};
use tracing::error;

use crate::state::State;

/// Determine whether a connection is stale based on its latest WireGuard handshake.
///
/// Returns `None` when live backend stats are unavailable or the connection has no
/// recorded handshake, because in that case the CLI cannot safely decide whether the
/// connection is stale.
fn is_stale(connection: &ActiveConnectionInfo, _peer_alive_period: u32) -> Option<bool> {
let last_handshake = connection.stats.as_ref()?.last_handshake?;
let now: u64 = Utc::now().timestamp().try_into().ok()?;

// Some(now.saturating_sub(last_handshake) > u64::from(peer_alive_period))
Some(now.saturating_sub(last_handshake) > u64::from(20u64))
}

/// Disconnect active connections whose latest handshake is older than the configured
/// peer alive period.
///
/// Connections without usable live stats are left untouched. Failures are logged and do
/// not stop cleanup of the remaining connections.
pub async fn tear_down_stale_connections(state: &State) {
let connections = match active_state(&state.pool).await {
Ok(connections) => connections,
Err(err) => {
error!("Failed to retrieve active connections: {err}");
return;
}
};
if connections.is_empty() {
return;
}

for connection in connections {
if is_stale(&connection, state.app_config.peer_alive_period).is_some_and(|v| v) {
use defguard_core::connection::tear_down;

let result = tear_down(&connection).await;
if let Err(err) = result {
error!("Error removing stale connection {}: {err}", connection.name);
}
}
}
}
23 changes: 20 additions & 3 deletions src-tauri/core/src/connection/active_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use objc2_network_extension::NEVPNStatus;
#[cfg(not(target_os = "macos"))]
use tonic::Code;

#[cfg(target_os = "macos")]
use crate::connection::apple::tunnel_stats;
#[cfg(target_os = "macos")]
use crate::database::models::get_all_tunnels_locations;
#[cfg(not(target_os = "macos"))]
Expand Down Expand Up @@ -62,7 +64,7 @@ pub struct InterfaceStats {
/// **unfiltered** snapshot of all managed interfaces (unlike `ReadInterfaceData`, which
/// drops peers that haven't completed a handshake or whose stats haven't changed).
///
/// On macOS the Network Extension path is stubbed (pending the NE spike).
/// On macOS this queries Network Extension managers and asks connected providers for stats.
#[cfg(target_os = "macos")]
pub async fn active_state(_pool: &DbPool) -> Result<Vec<ActiveConnectionInfo>, Error> {
let (tunnels, locations) = get_all_tunnels_locations().await;
Expand All @@ -73,25 +75,40 @@ pub async fn active_state(_pool: &DbPool) -> Result<Vec<ActiveConnectionInfo>, E
let mut result = Vec::new();
for location in locations {
if let Some(NEVPNStatus::Connected) = location.status() {
let stats = tunnel_stats(location.id, &ConnectionType::Location).map(|stats| {
InterfaceStats {
listen_port: 0,
tx_bytes: stats.tx_bytes,
rx_bytes: stats.rx_bytes,
last_handshake: (stats.last_handshake != 0).then_some(stats.last_handshake),
}
});
let info = ActiveConnectionInfo {
connection_type: ConnectionType::Location,
target_id: location.id,
name: location.name,
interface_name: String::new(),
stats: None, // TODO
stats,
};
result.push(info);
}
}

for tunnel in tunnels {
if let Some(NEVPNStatus::Connected) = tunnel.status() {
let stats =
tunnel_stats(tunnel.id, &ConnectionType::Tunnel).map(|stats| InterfaceStats {
listen_port: 0,
tx_bytes: stats.tx_bytes,
rx_bytes: stats.rx_bytes,
last_handshake: (stats.last_handshake != 0).then_some(stats.last_handshake),
});
let info = ActiveConnectionInfo {
connection_type: ConnectionType::Tunnel,
target_id: tunnel.id,
name: tunnel.name,
interface_name: String::new(),
stats: None, // TODO
stats,
};
result.push(info);
}
Expand Down
100 changes: 95 additions & 5 deletions src-tauri/core/src/connection/apple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@ const OBSERVER_CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
use block2::RcBlock;
use objc2::{rc::Retained, runtime::ProtocolObject};
use objc2_foundation::{
ns_string, NSArray, NSDate, NSError, NSNotification, NSNotificationCenter, NSNumber,
ns_string, NSArray, NSData, NSDate, NSError, NSNotification, NSNotificationCenter, NSNumber,
NSObjectProtocol, NSOperationQueue, NSRunLoop, NSString,
};
use objc2_network_extension::{
NETunnelProviderManager, NETunnelProviderProtocol, NEVPNConnection,
NETunnelProviderManager, NETunnelProviderProtocol, NETunnelProviderSession, NEVPNConnection,
NEVPNStatusDidChangeNotification,
};
use serde::Deserialize;

use crate::database::{
models::{location::Location, tunnel::Tunnel, Id},
DB_POOL,
use crate::{
database::{
models::{location::Location, tunnel::Tunnel, Id},
DB_POOL,
},
ConnectionType,
};

pub const PLUGIN_BUNDLE_ID: &str = "net.defguard.VPNExtension";
Expand Down Expand Up @@ -147,6 +151,92 @@ pub fn spawn_runloop_and_wait_for(semaphore: &Arc<AtomicBool>) {
}
}

/// Tunnel statistics shared with VPNExtension (written in Swift).
#[derive(Deserialize)]
#[repr(C)]
#[serde(rename_all = "camelCase")]
pub struct Stats {
pub location_id: Option<Id>,
pub tunnel_id: Option<Id>,
pub tx_bytes: u64,
pub rx_bytes: u64,
pub last_handshake: u64,
}

/// Retrieve VPN tunnel statistics from VPNExtension.
pub fn tunnel_stats(id: Id, connection_type: &ConnectionType) -> Option<Stats> {
let new_stats = Arc::new(Mutex::new(None));
let plugin_bundle_id = ns_string!(PLUGIN_BUNDLE_ID);

let new_stats_clone = Arc::clone(&new_stats);

let finished = Arc::new(AtomicBool::new(false));
let finished_clone = Arc::clone(&finished);

let response_handler = RcBlock::new(move |data_ptr: *mut NSData| {
if let Some(data) = unsafe { data_ptr.as_ref() } {
if let Ok(stats) = serde_json::from_slice(data.to_vec().as_slice()) {
if let Ok(mut new_stats_locked) = new_stats_clone.lock() {
*new_stats_locked = Some(stats);
}
} else {
warn!("Failed to deserialize tunnel stats");
}
} else {
debug!("No data received in tunnel stats response, skipping");
}
finished_clone.store(true, Ordering::Release);
});

let manager = manager_for_key_and_value(
match connection_type {
ConnectionType::Location => LOCATION_ID,
ConnectionType::Tunnel => TUNNEL_ID,
},
id,
)?;

let vpn_protocol = (unsafe { manager.protocolConfiguration() })?;
let Ok(tunnel_protocol) = vpn_protocol.downcast::<NETunnelProviderProtocol>() else {
error!("Failed to downcast to NETunnelProviderProtocol");
return None;
};

// Sometimes all managers from all apps come through, so filter by bundle ID.
if let Some(bundle_id) = unsafe { tunnel_protocol.providerBundleIdentifier() } {
if &*bundle_id != plugin_bundle_id {
return None;
}
}

let Ok(session) = unsafe { manager.connection() }.downcast::<NETunnelProviderSession>() else {
error!("Failed to downcast to NETunnelProviderSession");
return None;
};

let message_data = NSData::new();
if unsafe {
session.sendProviderMessage_returnError_responseHandler(
&message_data,
None,
Some(&response_handler),
)
} {
debug!("Message sent to NETunnelProviderSession");
} else {
error!("Failed to send to NETunnelProviderSession while requesting stats");
}

// Wait for the response handler to complete.
while !finished.load(Ordering::Acquire) {
spin_loop();
}

new_stats
.lock()
.map_or(None, |mut new_stats_locked| new_stats_locked.take())
}

/// Handle VPN status change.
fn vpn_status_change_handler(notification: &NSNotification) {
let name = notification.name();
Expand Down
Loading
Loading