From 55b88a8f06b8b48796ccf84835522e82bead1bd4 Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Thu, 11 Jun 2026 09:33:10 +0200 Subject: [PATCH 1/3] fix: rollback malformed handshake consumption Snapshot handshake consumption state before each state-machine step and roll back transcript, DTLS 1.2 peer sequence, and tentative queued handshakes when a transient parse error rejects the message body. Add DTLS 1.2 and DTLS 1.3 regression tests for malformed ClientHello extensions followed by clean retransmissions. Co-Authored-By: Claude --- src/dtls12/client.rs | 13 +++++++-- src/dtls12/engine.rs | 35 ++++++++++++++++++++++++ src/dtls12/server.rs | 11 +++++++- src/dtls13/client.rs | 11 +++++++- src/dtls13/engine.rs | 34 +++++++++++++++++++++++ src/dtls13/server.rs | 11 +++++++- src/error.rs | 4 +++ tests/dtls12/edge.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++ tests/dtls13/edge.rs | 55 +++++++++++++++++++++++++++++++++++++ 9 files changed, 234 insertions(+), 5 deletions(-) diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 71ac7a70..84bb12fd 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -263,8 +263,17 @@ impl Client { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; - - let new_state = prev_state.make_progress(self)?; + let snapshot = self.engine.handshake_progress_snapshot(); + + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 3bd86e99..4d88a7d9 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -674,6 +674,35 @@ impl Engine { Ok(Some(handshake)) } + pub(crate) fn handshake_progress_snapshot(&self) -> HandshakeProgressSnapshot { + HandshakeProgressSnapshot { + peer_handshake_seq_no: self.peer_handshake_seq_no, + transcript_len: self.transcript.len(), + } + } + + pub(crate) fn rollback_handshake_progress(&mut self, snapshot: HandshakeProgressSnapshot) { + self.peer_handshake_seq_no = snapshot.peer_handshake_seq_no; + self.transcript.resize(snapshot.transcript_len, 0); + + let mut keep = QueueRx::new(); + while let Some(incoming) = self.queue_rx.pop_front() { + let drop_incoming = incoming + .first() + .first_handshake() + .is_some_and(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); + + if drop_incoming { + incoming + .into_records() + .for_each(|r| self.buffers_free.push(r.into_buffer())); + } else { + keep.push_back(incoming); + } + } + self.queue_rx = keep; + } + pub(crate) fn next_record(&mut self, ctype: ContentType) -> Option<&Record> { let record = self .queue_rx @@ -1233,6 +1262,12 @@ impl Engine { } } +#[derive(Clone, Copy)] +pub(crate) struct HandshakeProgressSnapshot { + peer_handshake_seq_no: u16, + transcript_len: usize, +} + impl RecordHandler for Engine { fn classify_record(&mut self, record: Record) -> Result, Error> { let epoch = record.record().sequence.epoch; diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 48e22f1a..2eb239b0 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -261,8 +261,17 @@ impl Server { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; + let snapshot = self.engine.handshake_progress_snapshot(); - let new_state = prev_state.make_progress(self)?; + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index a6c60639..d0d4ecea 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -302,8 +302,17 @@ impl Client { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; + let snapshot = self.engine.handshake_progress_snapshot(); - let new_state = prev_state.make_progress(self)?; + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index d67ccc34..c722b194 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -858,6 +858,34 @@ impl Engine { Ok(Some(handshake)) } + pub(crate) fn handshake_progress_snapshot(&self) -> HandshakeProgressSnapshot { + HandshakeProgressSnapshot { + peer_handshake_seq_no: self.peer_handshake_seq_no, + transcript_len: self.transcript.len(), + } + } + + pub(crate) fn rollback_handshake_progress(&mut self, snapshot: HandshakeProgressSnapshot) { + self.transcript.resize(snapshot.transcript_len, 0); + + let mut keep = QueueRx::new(); + while let Some(incoming) = self.queue_rx.pop_front() { + let drop_incoming = incoming + .first() + .first_handshake() + .is_some_and(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); + + if drop_incoming { + incoming + .into_records() + .for_each(|r| self.buffers_free.push(r.into_buffer())); + } else { + keep.push_back(incoming); + } + } + self.queue_rx = keep; + } + /// Advance the expected peer handshake sequence number. /// /// Must be called by the caller of `next_handshake` / `next_handshake_no_transcript` @@ -2495,6 +2523,12 @@ impl RecordHandler for Engine { } } +#[derive(Clone, Copy)] +pub(crate) struct HandshakeProgressSnapshot { + peer_handshake_seq_no: u16, + transcript_len: usize, +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 4711c8d4..7e8d5a1e 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -344,8 +344,17 @@ impl Server { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; + let snapshot = self.engine.handshake_progress_snapshot(); - let new_state = prev_state.make_progress(self)?; + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); diff --git a/src/error.rs b/src/error.rs index 5856c569..21b90c5b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -684,6 +684,10 @@ impl InternalError { Self::Transient(TransientError::TooManyRecords) } + pub(crate) fn is_transient(&self) -> bool { + matches!(self, Self::Transient(_)) + } + pub(crate) fn into_public_error(self) -> Option { match self { Self::Transient(err) => { diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 6fd95d38..edf395b2 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -62,6 +62,29 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize { } } +fn poison_extension_vector_len(packet: &mut [u8], extension_type: u16) { + let marker = extension_type.to_be_bytes(); + + for i in 0..packet.len().saturating_sub(6) { + if packet[i..i + 2] != marker { + continue; + } + + let extension_len = u16::from_be_bytes([packet[i + 2], packet[i + 3]]) as usize; + let extension_data_start = i + 4; + let extension_data_end = extension_data_start + extension_len; + if extension_len < 2 || extension_data_end > packet.len() { + continue; + } + + packet[extension_data_start..extension_data_start + 2] + .copy_from_slice(&(extension_len as u16).to_be_bytes()); + return; + } + + panic!("extension {extension_type:#06x} not found"); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_malformed_datagram_is_discarded_without_processing_alerts() { @@ -107,6 +130,48 @@ fn dtls12_too_many_control_records_are_discarded() { .expect("too many records should be discarded"); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_malformed_client_hello_extension_does_not_consume_sequence() { + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = dtls12_config(); + let now = Instant::now(); + + let mut client = Dtls::new_12(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_12(config, server_cert, now); + server.set_active(false); + + client.handle_timeout(now).expect("client timeout start"); + client.handle_timeout(now).expect("client arm flight 1"); + let client_hello = collect_packets(&mut client); + assert_eq!(client_hello.len(), 1, "client should emit one ClientHello"); + + let mut poisoned = client_hello[0].clone(); + poison_extension_vector_len(&mut poisoned, 0x000e); // use_srtp + server + .handle_packet(&poisoned) + .expect("malformed ClientHello extension should be discarded"); + + server + .handle_packet(&client_hello[0]) + .expect("clean retransmission should still be accepted"); + server.handle_timeout(now).expect("server timeout"); + + let server_flight = collect_packets(&mut server); + assert!( + server_flight + .iter() + .flat_map(|packet| parse_handshake_types(packet)) + .any(|ty| ty == HELLO_VERIFY_REQUEST), + "server should respond to the clean ClientHello retransmission" + ); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_recovers_from_corrupted_packet() { diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 33e94423..229bc813 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -50,6 +50,29 @@ fn dtls13_ack_record_for_records(seq: u64, records: &[(u64, u64)]) -> Vec { out } +fn poison_extension_vector_len(packet: &mut [u8], extension_type: u16) { + let marker = extension_type.to_be_bytes(); + + for i in 0..packet.len().saturating_sub(6) { + if packet[i..i + 2] != marker { + continue; + } + + let extension_len = u16::from_be_bytes([packet[i + 2], packet[i + 3]]) as usize; + let extension_data_start = i + 4; + let extension_data_end = extension_data_start + extension_len; + if extension_len < 2 || extension_data_end > packet.len() { + continue; + } + + packet[extension_data_start..extension_data_start + 2] + .copy_from_slice(&(extension_len as u16).to_be_bytes()); + return; + } + + panic!("extension {extension_type:#06x} not found"); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_malformed_datagram_is_discarded_without_processing_alerts() { @@ -95,6 +118,38 @@ fn dtls13_too_many_control_records_are_discarded() { .expect("too many records should be discarded"); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_malformed_client_hello_extension_does_not_poison_transcript() { + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = dtls13_config(); + let now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + client.handle_timeout(now).expect("client timeout start"); + let client_hello = collect_packets(&mut client); + assert_eq!(client_hello.len(), 1, "client should emit one ClientHello"); + + let mut poisoned = client_hello[0].clone(); + poison_extension_vector_len(&mut poisoned, 0x0033); // key_share + server + .handle_packet(&poisoned) + .expect("malformed ClientHello extension should be discarded"); + + server + .handle_packet(&client_hello[0]) + .expect("clean retransmission should still be accepted"); + complete_dtls13_handshake(&mut client, &mut server, now); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_discards_too_short_ciphertext_record() { From 7f64bdcecf6126e16210e4b0111505778c80441c Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Thu, 11 Jun 2026 19:46:15 +0200 Subject: [PATCH 2/3] fix: handle mixed handshake rollback without queue rebuild Scan all handshakes in each incoming datagram when deciding rollback rejection, including mixed records where the handshake is not first. Mark rejected datagrams handled and let the existing purge path recycle buffers back into the reusable pool instead of rebuilding the receive queue. Key queued handshakes by the first unhandled relevant handshake while preserving duplicate resend behavior, and delay DTLS 1.3 HRR client-state commits until validation completes. Extend regressions for non-handshake-leading malformed handshakes and clean HRR retry after a rejected HRR. Co-Authored-By: Claude --- src/dtls12/engine.rs | 101 +++++++++++++++++++++++++++-------------- src/dtls13/client.rs | 36 +++++++-------- src/dtls13/engine.rs | 93 ++++++++++++++++++++++++------------- tests/dtls12/edge.rs | 15 +++++- tests/dtls13/edge.rs | 4 +- tests/dtls13_cookie.rs | 19 ++++++-- 6 files changed, 180 insertions(+), 88 deletions(-) diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 4d88a7d9..5119f6b3 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -231,6 +231,8 @@ impl Engine { /// Insert a parsed datagram into the receive queue. fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { + self.purge_handled_queue_rx(); + // Capacity guard before iterating records. if self.queue_rx.len() >= self.config.max_queue_rx() { warn!( @@ -241,8 +243,16 @@ impl Engine { return Err(Error::ReceiveQueueFull); } - // Dispatch to specialized handlers - if incoming.first().first_handshake().is_some() { + // Dispatch to specialized handlers. A datagram can start with a + // non-handshake record (for example CCS) and still carry a handshake + // record later in the same packet. + if incoming.first().first_handshake().is_some() + || (!self.release_app_data + && incoming + .records() + .iter() + .any(|r| !r.handshakes().is_empty())) + { self.insert_incoming_handshake(incoming) } else { self.insert_incoming_non_handshake(incoming) @@ -250,16 +260,6 @@ impl Engine { } fn insert_incoming_handshake(&mut self, incoming: Incoming) -> Result<(), Error> { - let first_record = incoming.first(); - let handshake = first_record - .first_handshake() - .expect("caller ensures handshake"); - - let key_current = ( - handshake.header.message_seq, - handshake.header.fragment_offset, - ); - let maybe_dupe_seq = incoming .records() .iter() @@ -278,12 +278,19 @@ impl Engine { } } - // Drop old duplicates we've already processed - don't let them block newer messages. - if handshake.header.message_seq < self.peer_handshake_seq_no { + mark_stale_handshakes(&incoming, self.peer_handshake_seq_no); + + let Some(handshake) = first_relevant_handshake(&incoming, self.peer_handshake_seq_no) + else { return Ok(()); - } + }; - if self.peer_encryption_enabled && first_record.record().sequence.epoch == 0 { + let key_current = ( + handshake.header.message_seq, + handshake.header.fragment_offset, + ); + + if self.peer_encryption_enabled && incoming.first().record().sequence.epoch == 0 { // Keep old plaintext handshake records available long enough to // trigger flight resends above, but never queue or process them as // new messages after peer encryption is enabled. @@ -296,10 +303,7 @@ impl Engine { } let search_result = self.queue_rx.binary_search_by(|item| { - let key_other = item - .first() - .first_handshake() - .as_ref() + let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no) .map(|h| (h.header.message_seq, h.header.fragment_offset)) .unwrap_or((u16::MAX, u32::MAX)); key_other.cmp(&key_current) @@ -685,22 +689,17 @@ impl Engine { self.peer_handshake_seq_no = snapshot.peer_handshake_seq_no; self.transcript.resize(snapshot.transcript_len, 0); - let mut keep = QueueRx::new(); - while let Some(incoming) = self.queue_rx.pop_front() { - let drop_incoming = incoming - .first() - .first_handshake() - .is_some_and(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); - - if drop_incoming { - incoming - .into_records() - .for_each(|r| self.buffers_free.push(r.into_buffer())); - } else { - keep.push_back(incoming); + for incoming in self.queue_rx.iter() { + let reject_incoming = incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); + + if reject_incoming { + mark_incoming_handled(incoming); } } - self.queue_rx = keep; } pub(crate) fn next_record(&mut self, ctype: ContentType) -> Option<&Record> { @@ -1268,6 +1267,40 @@ pub(crate) struct HandshakeProgressSnapshot { transcript_len: usize, } +fn first_relevant_handshake(incoming: &Incoming, peer_handshake_seq_no: u16) -> Option<&Handshake> { + incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .filter(|h| !h.is_handled()) + .find(|h| h.header.message_seq >= peer_handshake_seq_no) +} + +fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { + for handshake in incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .filter(|h| h.header.message_seq < peer_handshake_seq_no) + { + handshake.set_handled(); + } +} + +fn mark_incoming_handled(incoming: &Incoming) { + for record in incoming.records().iter() { + if record.handshakes().is_empty() { + if !record.is_handled() { + record.set_handled(); + } + } else { + for handshake in record.handshakes() { + handshake.set_handled(); + } + } + } +} + impl RecordHandler for Engine { fn classify_record(&mut self, record: Record) -> Result, Error> { let epoch = record.record().sequence.epoch; diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index d0d4ecea..db03d6dc 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -483,22 +483,31 @@ impl State { debug!("Received HelloRetryRequest"); - // Extract selected group and cookie from HRR extensions + let mut hrr_selected_group = None; + let mut saved_cookie = None; + let mut hrr_version_ok = false; + if let Some(ref extensions) = server_hello.extensions { for ext in extensions { match ext.extension_type { ExtensionType::KeyShare => { let ext_data = ext.extension_data(&client.defragment_buffer); - if let Ok((_, hrr_ks)) = KeyShareHelloRetryRequest::parse(ext_data) { - client.hrr_selected_group = Some(hrr_ks.selected_group); - } + let (_, hrr_ks) = KeyShareHelloRetryRequest::parse(ext_data) + .map_err(InternalError::from)?; + hrr_selected_group = Some(hrr_ks.selected_group); } ExtensionType::Cookie => { let ext_data = ext.extension_data(&client.defragment_buffer); parse_cookie_extension(ext_data).map_err(InternalError::from)?; let mut cookie = Buf::new(); cookie.extend_from_slice(ext_data); - client.saved_cookie = Some(cookie); + saved_cookie = Some(cookie); + } + ExtensionType::SupportedVersions => { + let ext_data = ext.extension_data(&client.defragment_buffer); + let (_, sv) = SupportedVersionsServerHello::parse(ext_data) + .map_err(InternalError::from)?; + hrr_version_ok = sv.selected_version == ProtocolVersion::DTLS1_3; } _ => {} } @@ -515,26 +524,17 @@ impl State { )) .into()); } - client.engine.set_cipher_suite(server_hello.cipher_suite); - // Validate HRR supported_versions - let mut hrr_version_ok = false; - if let Some(ref extensions) = server_hello.extensions { - for ext in extensions { - if ext.extension_type == ExtensionType::SupportedVersions { - let ext_data = ext.extension_data(&client.defragment_buffer); - if let Ok((_, sv)) = SupportedVersionsServerHello::parse(ext_data) { - hrr_version_ok = sv.selected_version == ProtocolVersion::DTLS1_3; - } - } - } - } if !hrr_version_ok { return Err( (Error::SecurityError(crate::SecurityError::HrrDidNotSelectDtls13)).into(), ); } + client.engine.set_cipher_suite(server_hello.cipher_suite); + client.hrr_selected_group = hrr_selected_group; + client.saved_cookie = saved_cookie; + // Replace transcript with message_hash per RFC 8446 Section 4.4.1. // The HRR was already appended to the transcript by next_handshake(). // We must hash only CH1, then re-append the HRR bytes. diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index c722b194..ee78ace2 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -340,6 +340,8 @@ impl Engine { } fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { + self.purge_handled_queue_rx(); + if self.queue_rx.len() >= self.config.max_queue_rx() { warn!( "Receive queue full (max {}): {:?}", @@ -349,7 +351,13 @@ impl Engine { return Err(Error::ReceiveQueueFull); } - if incoming.first().first_handshake().is_some() { + if incoming.first().first_handshake().is_some() + || (!self.release_app_data + && incoming + .records() + .iter() + .any(|r| !r.handshakes().is_empty())) + { self.insert_incoming_handshake(incoming) } else { self.insert_incoming_non_handshake(incoming) @@ -357,16 +365,6 @@ impl Engine { } fn insert_incoming_handshake(&mut self, incoming: Incoming) -> Result<(), Error> { - let first_record = incoming.first(); - let handshake = first_record - .first_handshake() - .expect("caller ensures handshake"); - - let key_current = ( - handshake.header.message_seq, - handshake.header.fragment_offset, - ); - let maybe_dupe_seq = incoming .records() .iter() @@ -380,10 +378,17 @@ impl Engine { } } - // Drop old duplicates we've already processed - if handshake.header.message_seq < self.peer_handshake_seq_no { + mark_stale_handshakes(&incoming, self.peer_handshake_seq_no); + + let Some(handshake) = first_relevant_handshake(&incoming, self.peer_handshake_seq_no) + else { return Ok(()); - } + }; + + let key_current = ( + handshake.header.message_seq, + handshake.header.fragment_offset, + ); // Reject new handshakes after initial handshake is complete, // but allow KeyUpdate (a post-handshake message). @@ -395,10 +400,7 @@ impl Engine { } let search_result = self.queue_rx.binary_search_by(|item| { - let key_other = item - .first() - .first_handshake() - .as_ref() + let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no) .map(|h| (h.header.message_seq, h.header.fragment_offset)) .unwrap_or((u16::MAX, u32::MAX)); key_other.cmp(&key_current) @@ -868,22 +870,17 @@ impl Engine { pub(crate) fn rollback_handshake_progress(&mut self, snapshot: HandshakeProgressSnapshot) { self.transcript.resize(snapshot.transcript_len, 0); - let mut keep = QueueRx::new(); - while let Some(incoming) = self.queue_rx.pop_front() { - let drop_incoming = incoming - .first() - .first_handshake() - .is_some_and(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); + for incoming in self.queue_rx.iter() { + let reject_incoming = incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); - if drop_incoming { - incoming - .into_records() - .for_each(|r| self.buffers_free.push(r.into_buffer())); - } else { - keep.push_back(incoming); + if reject_incoming { + mark_incoming_handled(incoming); } } - self.queue_rx = keep; } /// Advance the expected peer handshake sequence number. @@ -2529,6 +2526,40 @@ pub(crate) struct HandshakeProgressSnapshot { transcript_len: usize, } +fn first_relevant_handshake(incoming: &Incoming, peer_handshake_seq_no: u16) -> Option<&Handshake> { + incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .filter(|h| !h.is_handled()) + .find(|h| h.header.message_seq >= peer_handshake_seq_no) +} + +fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { + for handshake in incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .filter(|h| h.header.message_seq < peer_handshake_seq_no) + { + handshake.set_handled(); + } +} + +fn mark_incoming_handled(incoming: &Incoming) { + for record in incoming.records().iter() { + if record.handshakes().is_empty() { + if !record.is_handled() { + record.set_handled(); + } + } else { + for handshake in record.handshakes() { + handshake.set_handled(); + } + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index edf395b2..2f4af92f 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -32,6 +32,17 @@ fn dtls12_epoch1_record(seq: u64, len: usize) -> Vec { out } +fn dtls12_ccs_record(seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(20); // ChangeCipherSpec + out.extend_from_slice(&[0xFE, 0xFD]); // DTLS 1.2 + out.extend_from_slice(&0u16.to_be_bytes()); // epoch 0 + out.extend_from_slice(&seq.to_be_bytes()[2..]); // u48 sequence number + out.extend_from_slice(&1u16.to_be_bytes()); // payload length + out.push(1); // change_cipher_spec + out +} + fn dtls12_config_for_suite(suite: Dtls12CipherSuite) -> Arc { let mut provider = Config::default().crypto_provider().clone(); let selected = provider @@ -153,8 +164,10 @@ fn dtls12_malformed_client_hello_extension_does_not_consume_sequence() { let mut poisoned = client_hello[0].clone(); poison_extension_vector_len(&mut poisoned, 0x000e); // use_srtp + let mut mixed_poisoned = dtls12_ccs_record(42); + mixed_poisoned.extend_from_slice(&poisoned); server - .handle_packet(&poisoned) + .handle_packet(&mixed_poisoned) .expect("malformed ClientHello extension should be discarded"); server diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 229bc813..64747426 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -140,8 +140,10 @@ fn dtls13_malformed_client_hello_extension_does_not_poison_transcript() { let mut poisoned = client_hello[0].clone(); poison_extension_vector_len(&mut poisoned, 0x0033); // key_share + let mut mixed_poisoned = dtls13_ack_record(42); + mixed_poisoned.extend_from_slice(&poisoned); server - .handle_packet(&poisoned) + .handle_packet(&mixed_poisoned) .expect("malformed ClientHello extension should be discarded"); server diff --git a/tests/dtls13_cookie.rs b/tests/dtls13_cookie.rs index ac1391f1..c39c7e0b 100644 --- a/tests/dtls13_cookie.rs +++ b/tests/dtls13_cookie.rs @@ -118,18 +118,19 @@ fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { server.handle_timeout(now).expect("server timeout"); let server_out = drain_outputs(&mut server); - let mut hrr = server_out + let hrr = server_out .packets .into_iter() .next() .expect("server should emit HRR"); + let mut malformed_hrr = hrr.clone(); assert!( - shrink_dtls13_cookie_extension_inner_len(&mut hrr), + shrink_dtls13_cookie_extension_inner_len(&mut malformed_hrr), "fixture should contain a Cookie extension" ); client - .handle_packet(&hrr) + .handle_packet(&malformed_hrr) .expect("malformed HRR Cookie extension should be discarded"); client @@ -140,6 +141,18 @@ fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { client_out.packets.is_empty(), "client must not send CH2 after malformed HRR Cookie" ); + + client + .handle_packet(&hrr) + .expect("clean HRR retransmission should be accepted"); + client + .handle_timeout(now) + .expect("client timeout after clean HRR"); + let client_out = drain_outputs(&mut client); + assert!( + !client_out.packets.is_empty(), + "client should send CH2 after clean HRR retransmission" + ); } #[test] From 481d77f07c2c16ac041a71634c6725407b8baf71 Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Mon, 15 Jun 2026 21:29:29 +0200 Subject: [PATCH 3/3] fix: preserve mixed handshake rollback records Co-Authored-By: Codex --- src/dtls12/engine.rs | 151 ++++++++++++++++++++++++++++++++++++----- src/dtls13/engine.rs | 98 +++++++++++++++++++++----- tests/dtls13_cookie.rs | 146 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 359 insertions(+), 36 deletions(-) diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 5119f6b3..72cf6dbf 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -280,7 +280,8 @@ impl Engine { mark_stale_handshakes(&incoming, self.peer_handshake_seq_no); - let Some(handshake) = first_relevant_handshake(&incoming, self.peer_handshake_seq_no) + let Some((record, handshake)) = + first_relevant_handshake(&incoming, self.peer_handshake_seq_no) else { return Ok(()); }; @@ -290,7 +291,7 @@ impl Engine { handshake.header.fragment_offset, ); - if self.peer_encryption_enabled && incoming.first().record().sequence.epoch == 0 { + if self.peer_encryption_enabled && record.record().sequence.epoch == 0 { // Keep old plaintext handshake records available long enough to // trigger flight resends above, but never queue or process them as // new messages after peer encryption is enabled. @@ -304,6 +305,7 @@ impl Engine { let search_result = self.queue_rx.binary_search_by(|item| { let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no) + .map(|(_, h)| h) .map(|h| (h.header.message_seq, h.header.fragment_offset)) .unwrap_or((u16::MAX, u32::MAX)); key_other.cmp(&key_current) @@ -697,7 +699,7 @@ impl Engine { .any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); if reject_incoming { - mark_incoming_handled(incoming); + mark_rejected_handshakes(incoming, snapshot.peer_handshake_seq_no); } } } @@ -1267,13 +1269,16 @@ pub(crate) struct HandshakeProgressSnapshot { transcript_len: usize, } -fn first_relevant_handshake(incoming: &Incoming, peer_handshake_seq_no: u16) -> Option<&Handshake> { +fn first_relevant_handshake( + incoming: &Incoming, + peer_handshake_seq_no: u16, +) -> Option<(&Record, &Handshake)> { incoming .records() .iter() - .flat_map(|r| r.handshakes()) - .filter(|h| !h.is_handled()) - .find(|h| h.header.message_seq >= peer_handshake_seq_no) + .flat_map(|r| r.handshakes().iter().map(move |h| (r, h))) + .filter(|(_, h)| !h.is_handled()) + .find(|(_, h)| h.header.message_seq >= peer_handshake_seq_no) } fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { @@ -1287,16 +1292,14 @@ fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { } } -fn mark_incoming_handled(incoming: &Incoming) { +fn mark_rejected_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { for record in incoming.records().iter() { - if record.handshakes().is_empty() { - if !record.is_handled() { - record.set_handled(); - } - } else { - for handshake in record.handshakes() { - handshake.set_handled(); - } + for handshake in record + .handshakes() + .iter() + .filter(|h| h.header.message_seq >= peer_handshake_seq_no) + { + handshake.set_handled(); } } } @@ -1445,3 +1448,119 @@ impl RecordHandler for Engine { self.release_app_data } } + +#[cfg(test)] +mod tests { + use super::*; + + struct PassthroughRecordHandler; + + impl RecordHandler for PassthroughRecordHandler { + fn classify_record(&mut self, record: Record) -> Result, Error> { + Ok(Some(record)) + } + + fn is_peer_encryption_enabled(&self) -> bool { + true + } + + fn replay_check(&self, _seq: Sequence) -> bool { + true + } + + fn replay_update(&mut self, _seq: Sequence) {} + + fn decryption_aad_and_nonce(&self, _dtls: &DTLSRecord, _buf: &[u8]) -> (Aad, Nonce) { + ( + Aad::new_dtls12(ContentType::Handshake, Sequence::new(1), 0), + Nonce::xor(&[0; 12], 0), + ) + } + + fn explicit_nonce_len(&self) -> usize { + 0 + } + + fn min_protected_fragment_len(&self) -> usize { + 0 + } + + fn decrypt_data( + &mut self, + _ciphertext: &mut TmpBuf, + _aad: Aad, + _nonce: Nonce, + ) -> Result<(), Error> { + Ok(()) + } + } + + fn test_engine() -> Engine { + Engine::new(Arc::new(Config::default()), AuthMode::Psk) + } + + fn handshake_record(content_type: ContentType, epoch: u16, seq: u64, msg_seq: u16) -> Vec { + let mut fragment = Vec::new(); + fragment.push(MessageType::Finished.as_u8()); + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); + fragment.extend_from_slice(&msg_seq.to_be_bytes()); + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); + + let mut out = Vec::new(); + out.push(content_type.as_u8()); + out.extend_from_slice(&[0xFE, 0xFD]); + out.extend_from_slice(&epoch.to_be_bytes()); + out.extend_from_slice(&seq.to_be_bytes()[2..]); + out.extend_from_slice(&(fragment.len() as u16).to_be_bytes()); + out.extend_from_slice(&fragment); + out + } + + fn ccs_record(seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(ContentType::ChangeCipherSpec.as_u8()); + out.extend_from_slice(&[0xFE, 0xFD]); + out.extend_from_slice(&0u16.to_be_bytes()); + out.extend_from_slice(&seq.to_be_bytes()[2..]); + out.extend_from_slice(&1u16.to_be_bytes()); + out.push(1); + out + } + + #[test] + fn encrypted_relevant_handshake_after_stale_plaintext_is_queued() { + let mut engine = test_engine(); + engine.peer_encryption_enabled = true; + engine.peer_handshake_seq_no = 1; + + let mut packet = handshake_record(ContentType::Handshake, 0, 0, 0); + packet.extend_from_slice(&ccs_record(1)); + packet.extend_from_slice(&handshake_record(ContentType::Handshake, 1, 0, 1)); + + let incoming = Incoming::parse_packet( + &packet, + &mut PassthroughRecordHandler, + Some(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256), + ) + .expect("parse packet") + .expect("packet contains records"); + + engine + .insert_incoming(incoming) + .expect("insert mixed retransmission"); + + assert_eq!(engine.queue_rx.len(), 1); + assert!( + engine.queue_rx[0] + .records() + .iter() + .any(|record| record.record().sequence.epoch == 1 + && record + .handshakes() + .iter() + .any(|handshake| handshake.header.message_seq == 1)), + "relevant encrypted Finished must remain queued even when stale plaintext starts the datagram" + ); + } +} diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index ee78ace2..55d85fd8 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -380,7 +380,7 @@ impl Engine { mark_stale_handshakes(&incoming, self.peer_handshake_seq_no); - let Some(handshake) = first_relevant_handshake(&incoming, self.peer_handshake_seq_no) + let Some((_, handshake)) = first_relevant_handshake(&incoming, self.peer_handshake_seq_no) else { return Ok(()); }; @@ -401,6 +401,7 @@ impl Engine { let search_result = self.queue_rx.binary_search_by(|item| { let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no) + .map(|(_, h)| h) .map(|h| (h.header.message_seq, h.header.fragment_offset)) .unwrap_or((u16::MAX, u32::MAX)); key_other.cmp(&key_current) @@ -878,7 +879,7 @@ impl Engine { .any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); if reject_incoming { - mark_incoming_handled(incoming); + mark_rejected_handshakes(incoming, snapshot.peer_handshake_seq_no); } } } @@ -2526,13 +2527,16 @@ pub(crate) struct HandshakeProgressSnapshot { transcript_len: usize, } -fn first_relevant_handshake(incoming: &Incoming, peer_handshake_seq_no: u16) -> Option<&Handshake> { +fn first_relevant_handshake( + incoming: &Incoming, + peer_handshake_seq_no: u16, +) -> Option<(&Record, &Handshake)> { incoming .records() .iter() - .flat_map(|r| r.handshakes()) - .filter(|h| !h.is_handled()) - .find(|h| h.header.message_seq >= peer_handshake_seq_no) + .flat_map(|r| r.handshakes().iter().map(move |h| (r, h))) + .filter(|(_, h)| !h.is_handled()) + .find(|(_, h)| h.header.message_seq >= peer_handshake_seq_no) } fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { @@ -2546,16 +2550,14 @@ fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { } } -fn mark_incoming_handled(incoming: &Incoming) { +fn mark_rejected_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { for record in incoming.records().iter() { - if record.handshakes().is_empty() { - if !record.is_handled() { - record.set_handled(); - } - } else { - for handshake in record.handshakes() { - handshake.set_handled(); - } + for handshake in record + .handshakes() + .iter() + .filter(|h| h.header.message_seq >= peer_handshake_seq_no) + { + handshake.set_handled(); } } } @@ -2644,16 +2646,38 @@ mod tests { packet } - fn parsed_key_update(seq: u16) -> Incoming { + fn encrypted_app_data_record(seq: u16, payload: &[u8]) -> Vec { + let mut fragment = Vec::new(); + fragment.extend_from_slice(payload); + fragment.push(ContentType::ApplicationData.as_u8()); + + let mut packet = Vec::new(); + packet.push( + 0b0010_0000 + | 0b0000_1000 // 2-byte sequence number. + | 0b0000_0100 // explicit length. + | 0b0000_0010, // epoch bits resolved by PassthroughRecordHandler. + ); + packet.extend_from_slice(&seq.to_be_bytes()); + packet.extend_from_slice(&(fragment.len() as u16).to_be_bytes()); + packet.extend_from_slice(&fragment); + packet + } + + fn parsed_packet(packet: &[u8]) -> Incoming { Incoming::parse_packet( - &encrypted_key_update_record(seq), + packet, &mut PassthroughRecordHandler, Some(Dtls13CipherSuite::AES_128_GCM_SHA256), ) - .expect("parse key update packet") + .expect("parse packet") .expect("packet contains a record") } + fn parsed_key_update(seq: u16) -> Incoming { + parsed_packet(&encrypted_key_update_record(seq)) + } + /// Issue 2: Epoch-0 sequence number must have an overflow guard. /// /// Per RFC 9147 ยง4.2, implementations MUST NOT allow the sequence number @@ -2821,4 +2845,42 @@ mod tests { "malformed ACK vector length must not partially acknowledge records" ); } + + #[test] + #[cfg(feature = "rcgen")] + fn rollback_preserves_coalesced_application_data() { + let mut packet = encrypted_app_data_record(0, b"still-pending"); + packet.extend_from_slice(&encrypted_key_update_record(1)); + let incoming = parsed_packet(&packet); + + let app_record = incoming + .records() + .iter() + .find(|record| record.record().content_type == ContentType::ApplicationData) + .expect("packet contains application data"); + let key_update = incoming + .records() + .iter() + .flat_map(|record| record.handshakes()) + .find(|handshake| handshake.header.msg_type == MessageType::KeyUpdate) + .expect("packet contains KeyUpdate"); + + assert!(!app_record.is_handled()); + assert!(!key_update.is_handled()); + + let snapshot = HandshakeProgressSnapshot { + peer_handshake_seq_no: key_update.header.message_seq, + transcript_len: 0, + }; + mark_rejected_handshakes(&incoming, snapshot.peer_handshake_seq_no); + + assert!( + !app_record.is_handled(), + "rollback must not consume application data coalesced with a rejected handshake" + ); + assert!( + key_update.is_handled(), + "rollback should discard the rejected handshake" + ); + } } diff --git a/tests/dtls13_cookie.rs b/tests/dtls13_cookie.rs index c39c7e0b..e4009f35 100644 --- a/tests/dtls13_cookie.rs +++ b/tests/dtls13_cookie.rs @@ -6,8 +6,8 @@ mod common; use std::sync::Arc; use std::time::Instant; -use dimpl::Dtls; use dimpl::certificate::generate_self_signed_certificate; +use dimpl::{Dtls, NamedGroup}; use crate::common::{drain_outputs, dtls13_config}; @@ -94,6 +94,131 @@ fn shrink_dtls13_cookie_extension_inner_len(packet: &mut [u8]) -> bool { false } +fn dtls13_hrr_extension_types(packet: &[u8]) -> Option> { + const RECORD_HEADER_LEN: usize = 13; + const HANDSHAKE_HEADER_LEN: usize = 12; + + if packet.len() < RECORD_HEADER_LEN + HANDSHAKE_HEADER_LEN || packet[0] != 22 { + return None; + } + + let handshake = &packet[RECORD_HEADER_LEN..]; + let msg_type = handshake[0]; + let body_len = + ((handshake[1] as usize) << 16) | ((handshake[2] as usize) << 8) | handshake[3] as usize; + if handshake.len() < HANDSHAKE_HEADER_LEN + body_len { + return None; + } + + let body = &handshake[HANDSHAKE_HEADER_LEN..HANDSHAKE_HEADER_LEN + body_len]; + let mut pos = cookie_extensions_start(body, msg_type)?; + if body.len() < pos + 2 { + return None; + } + + let extensions_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; + pos += 2; + let extensions_end = pos + extensions_len; + if body.len() < extensions_end { + return None; + } + + let mut extensions = Vec::new(); + while pos + 4 <= extensions_end { + let extension_type = u16::from_be_bytes([body[pos], body[pos + 1]]); + let extension_len = u16::from_be_bytes([body[pos + 2], body[pos + 3]]) as usize; + pos += 4 + extension_len; + if pos > extensions_end { + return None; + } + extensions.push(extension_type); + } + + Some(extensions) +} + +fn insert_dtls13_hrr_key_share_before_cookie(packet: &mut Vec, selected_group: u16) -> bool { + const RECORD_HEADER_LEN: usize = 13; + const HANDSHAKE_HEADER_LEN: usize = 12; + const KEY_SHARE_EXTENSION: u16 = 0x0033; + const COOKIE_EXTENSION: u16 = 0x002C; + + if packet.len() < RECORD_HEADER_LEN + HANDSHAKE_HEADER_LEN || packet[0] != 22 { + return false; + } + + let handshake_start = RECORD_HEADER_LEN; + let body_start = handshake_start + HANDSHAKE_HEADER_LEN; + let body_len = ((packet[handshake_start + 1] as usize) << 16) + | ((packet[handshake_start + 2] as usize) << 8) + | packet[handshake_start + 3] as usize; + if packet.len() < body_start + body_len { + return false; + } + + let body = &packet[body_start..body_start + body_len]; + let mut pos = match cookie_extensions_start(body, packet[handshake_start]) { + Some(pos) => body_start + pos, + None => return false, + }; + if packet.len() < pos + 2 { + return false; + } + + let extensions_len = u16::from_be_bytes([packet[pos], packet[pos + 1]]) as usize; + let extensions_len_pos = pos; + pos += 2; + let extensions_end = pos + extensions_len; + if packet.len() < extensions_end { + return false; + } + + while pos + 4 <= extensions_end { + let extension_type = u16::from_be_bytes([packet[pos], packet[pos + 1]]); + let extension_len = u16::from_be_bytes([packet[pos + 2], packet[pos + 3]]) as usize; + let next = pos + 4 + extension_len; + if next > extensions_end { + return false; + } + + if extension_type == KEY_SHARE_EXTENSION { + return true; + } + + if extension_type == COOKIE_EXTENSION { + let key_share = [ + (KEY_SHARE_EXTENSION >> 8) as u8, + KEY_SHARE_EXTENSION as u8, + 0, + 2, + (selected_group >> 8) as u8, + selected_group as u8, + ]; + packet.splice(pos..pos, key_share); + + let new_extensions_len = extensions_len + key_share.len(); + packet[extensions_len_pos..extensions_len_pos + 2] + .copy_from_slice(&(new_extensions_len as u16).to_be_bytes()); + + let new_body_len = body_len + key_share.len(); + packet[handshake_start + 1] = ((new_body_len >> 16) & 0xff) as u8; + packet[handshake_start + 2] = ((new_body_len >> 8) & 0xff) as u8; + packet[handshake_start + 3] = (new_body_len & 0xff) as u8; + packet[handshake_start + 9] = ((new_body_len >> 16) & 0xff) as u8; + packet[handshake_start + 10] = ((new_body_len >> 8) & 0xff) as u8; + packet[handshake_start + 11] = (new_body_len & 0xff) as u8; + + let new_record_len = HANDSHAKE_HEADER_LEN + new_body_len; + packet[11..13].copy_from_slice(&(new_record_len as u16).to_be_bytes()); + return true; + } + + pos = next; + } + + false +} + #[test] fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { let _ = env_logger::try_init(); @@ -118,12 +243,29 @@ fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { server.handle_timeout(now).expect("server timeout"); let server_out = drain_outputs(&mut server); - let hrr = server_out + let mut hrr = server_out .packets .into_iter() .next() .expect("server should emit HRR"); + assert!( + insert_dtls13_hrr_key_share_before_cookie(&mut hrr, NamedGroup::X25519.as_u16()), + "fixture should insert HRR KeyShare before Cookie" + ); let mut malformed_hrr = hrr.clone(); + let hrr_extensions = dtls13_hrr_extension_types(&hrr).expect("parse HRR extensions"); + let key_share_index = hrr_extensions + .iter() + .position(|ext| *ext == 0x0033) + .expect("HRR should contain KeyShare"); + let cookie_index = hrr_extensions + .iter() + .position(|ext| *ext == 0x002C) + .expect("HRR should contain Cookie"); + assert!( + key_share_index < cookie_index, + "fixture must parse KeyShare before rejecting malformed Cookie" + ); assert!( shrink_dtls13_cookie_extension_inner_len(&mut malformed_hrr), "fixture should contain a Cookie extension"