Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions crates/mqtt5/src/broker/client_handler/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ enum AuthOutcome {
Failed(MqttError),
}

fn clamp_keep_alive_to_u16(keep_alive: Duration) -> u16 {
let secs = keep_alive.as_secs();
u16::try_from(secs).unwrap_or_else(|_| {
warn!(
"Configured server_keep_alive {}s exceeds the u16 wire range; clamping to {}s",
secs,
u16::MAX,
);
u16::MAX
})
}

impl ClientHandler {
pub(super) async fn validate_protocol_version(&mut self, protocol_version: u8) -> Result<()> {
match protocol_version {
Expand Down Expand Up @@ -362,7 +374,7 @@ impl ClientHandler {
}

fn build_connack_properties(
&self,
&mut self,
connack: &mut ConnAckPacket,
assigned_client_id: Option<&String>,
) {
Expand Down Expand Up @@ -401,9 +413,14 @@ impl ClientHandler {
}

if let Some(keep_alive) = self.config.server_keep_alive {
connack
.properties
.set_server_keep_alive(u16::try_from(keep_alive.as_secs()).unwrap_or(u16::MAX));
let secs = clamp_keep_alive_to_u16(keep_alive);
connack.properties.set_server_keep_alive(secs);
self.keep_alive = Duration::from_secs(u64::from(secs));
debug!(
client_id = ?self.client_id,
negotiated_keep_alive_secs = secs,
"Broker overrode client keep-alive via ServerKeepAlive"
);
}

if self.request_response_information {
Expand Down
63 changes: 63 additions & 0 deletions crates/mqtt5/tests/keepalive_negotiation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ use common::TestBroker;
use mqtt5::broker::config::{BrokerConfig, StorageBackend, StorageConfig};
use mqtt5::time::Duration;
use mqtt5::{ConnectOptions, ConnectionEvent, MqttClient};
use mqtt5_protocol::packet::connect::ConnectPacket;
use mqtt5_protocol::packet::MqttPacket;
use mqtt5_protocol::protocol::v5::properties::Properties;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;

#[tokio::test]
Expand Down Expand Up @@ -122,6 +127,64 @@ async fn keep_alive_renegotiates_against_each_broker_on_reconnect() {
client.disconnect().await.unwrap();
}

#[tokio::test]
async fn broker_read_timeout_uses_negotiated_keep_alive_not_client_request() {
let storage_config = StorageConfig {
backend: StorageBackend::Memory,
enable_persistence: true,
..Default::default()
};
let mut config = BrokerConfig::default()
.with_bind_address("127.0.0.1:0".parse::<SocketAddr>().unwrap())
.with_storage(storage_config);
config.server_keep_alive = Some(Duration::from_secs(1));

let broker = TestBroker::start_with_config(config).await;
let addr = broker.address().trim_start_matches("mqtt://").to_string();

let mut stream = TcpStream::connect(&addr).await.expect("connect tcp");

let connect = ConnectPacket {
protocol_version: 5,
clean_start: true,
keep_alive: 600,
client_id: "kanego-broker-timeout".to_string(),
username: None,
password: None,
will: None,
properties: Properties::new(),
will_properties: Properties::new(),
};
let mut buf = Vec::new();
connect.encode(&mut buf).expect("encode CONNECT");
stream.write_all(&buf).await.expect("write CONNECT");
stream.flush().await.expect("flush");

let mut connack = [0u8; 64];
let n = tokio::time::timeout(Duration::from_secs(2), stream.read(&mut connack))
.await
.expect("CONNACK read timed out")
.expect("read CONNACK");
assert!(n > 0, "broker closed before sending CONNACK");

let start = tokio::time::Instant::now();
let mut drain = [0u8; 256];
let read_result = tokio::time::timeout(Duration::from_secs(4), stream.read(&mut drain)).await;
let elapsed = start.elapsed();

let bytes = read_result
.expect("broker did not close connection within 4s — read_timeout still using client's 600s request")
.expect("socket read error");
assert_eq!(
bytes, 0,
"expected EOF after broker timeout, got {bytes} bytes",
);
assert!(
elapsed < Duration::from_secs(3),
"broker disconnect took {elapsed:?}; expected ~1.5s (negotiated 1s * 1.5)",
);
}

#[tokio::test]
async fn server_keep_alive_zero_disables_keepalive_even_with_nonzero_client_request() {
let storage_config = StorageConfig {
Expand Down
Loading