diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs
index 71ac7a70..84bb12fd 100644
--- a/src/dtls12/client.rs
+++ b/src/dtls12/client.rs
@@ -263,8 +263,17 @@ impl Client {
fn make_progress(&mut self) -> Result<(), InternalError> {
loop {
let prev_state = self.state;
-
- let new_state = prev_state.make_progress(self)?;
+ let snapshot = self.engine.handshake_progress_snapshot();
+
+ let new_state = match prev_state.make_progress(self) {
+ Ok(new_state) => new_state,
+ Err(err) => {
+ if err.is_transient() {
+ self.engine.rollback_handshake_progress(snapshot);
+ }
+ return Err(err);
+ }
+ };
if prev_state != new_state {
self.state = new_state;
trace!("{:?} -> {:?}", prev_state, new_state);
diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs
index 3bd86e99..72cf6dbf 100644
--- a/src/dtls12/engine.rs
+++ b/src/dtls12/engine.rs
@@ -231,6 +231,8 @@ impl Engine {
/// Insert a parsed datagram into the receive queue.
fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
+ self.purge_handled_queue_rx();
+
// Capacity guard before iterating records.
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
@@ -241,8 +243,16 @@ impl Engine {
return Err(Error::ReceiveQueueFull);
}
- // Dispatch to specialized handlers
- if incoming.first().first_handshake().is_some() {
+ // Dispatch to specialized handlers. A datagram can start with a
+ // non-handshake record (for example CCS) and still carry a handshake
+ // record later in the same packet.
+ if incoming.first().first_handshake().is_some()
+ || (!self.release_app_data
+ && incoming
+ .records()
+ .iter()
+ .any(|r| !r.handshakes().is_empty()))
+ {
self.insert_incoming_handshake(incoming)
} else {
self.insert_incoming_non_handshake(incoming)
@@ -250,16 +260,6 @@ impl Engine {
}
fn insert_incoming_handshake(&mut self, incoming: Incoming) -> Result<(), Error> {
- let first_record = incoming.first();
- let handshake = first_record
- .first_handshake()
- .expect("caller ensures handshake");
-
- let key_current = (
- handshake.header.message_seq,
- handshake.header.fragment_offset,
- );
-
let maybe_dupe_seq = incoming
.records()
.iter()
@@ -278,12 +278,20 @@ impl Engine {
}
}
- // Drop old duplicates we've already processed - don't let them block newer messages.
- if handshake.header.message_seq < self.peer_handshake_seq_no {
+ mark_stale_handshakes(&incoming, self.peer_handshake_seq_no);
+
+ let Some((record, handshake)) =
+ first_relevant_handshake(&incoming, self.peer_handshake_seq_no)
+ else {
return Ok(());
- }
+ };
- if self.peer_encryption_enabled && first_record.record().sequence.epoch == 0 {
+ let key_current = (
+ handshake.header.message_seq,
+ handshake.header.fragment_offset,
+ );
+
+ if self.peer_encryption_enabled && record.record().sequence.epoch == 0 {
// Keep old plaintext handshake records available long enough to
// trigger flight resends above, but never queue or process them as
// new messages after peer encryption is enabled.
@@ -296,10 +304,8 @@ impl Engine {
}
let search_result = self.queue_rx.binary_search_by(|item| {
- let key_other = item
- .first()
- .first_handshake()
- .as_ref()
+ let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no)
+ .map(|(_, h)| h)
.map(|h| (h.header.message_seq, h.header.fragment_offset))
.unwrap_or((u16::MAX, u32::MAX));
key_other.cmp(&key_current)
@@ -674,6 +680,30 @@ impl Engine {
Ok(Some(handshake))
}
+ pub(crate) fn handshake_progress_snapshot(&self) -> HandshakeProgressSnapshot {
+ HandshakeProgressSnapshot {
+ peer_handshake_seq_no: self.peer_handshake_seq_no,
+ transcript_len: self.transcript.len(),
+ }
+ }
+
+ pub(crate) fn rollback_handshake_progress(&mut self, snapshot: HandshakeProgressSnapshot) {
+ self.peer_handshake_seq_no = snapshot.peer_handshake_seq_no;
+ self.transcript.resize(snapshot.transcript_len, 0);
+
+ for incoming in self.queue_rx.iter() {
+ let reject_incoming = incoming
+ .records()
+ .iter()
+ .flat_map(|r| r.handshakes())
+ .any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no);
+
+ if reject_incoming {
+ mark_rejected_handshakes(incoming, snapshot.peer_handshake_seq_no);
+ }
+ }
+ }
+
pub(crate) fn next_record(&mut self, ctype: ContentType) -> Option<&Record> {
let record = self
.queue_rx
@@ -1233,6 +1263,47 @@ impl Engine {
}
}
+#[derive(Clone, Copy)]
+pub(crate) struct HandshakeProgressSnapshot {
+ peer_handshake_seq_no: u16,
+ transcript_len: usize,
+}
+
+fn first_relevant_handshake(
+ incoming: &Incoming,
+ peer_handshake_seq_no: u16,
+) -> Option<(&Record, &Handshake)> {
+ incoming
+ .records()
+ .iter()
+ .flat_map(|r| r.handshakes().iter().map(move |h| (r, h)))
+ .filter(|(_, h)| !h.is_handled())
+ .find(|(_, h)| h.header.message_seq >= peer_handshake_seq_no)
+}
+
+fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) {
+ for handshake in incoming
+ .records()
+ .iter()
+ .flat_map(|r| r.handshakes())
+ .filter(|h| h.header.message_seq < peer_handshake_seq_no)
+ {
+ handshake.set_handled();
+ }
+}
+
+fn mark_rejected_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) {
+ for record in incoming.records().iter() {
+ for handshake in record
+ .handshakes()
+ .iter()
+ .filter(|h| h.header.message_seq >= peer_handshake_seq_no)
+ {
+ handshake.set_handled();
+ }
+ }
+}
+
impl RecordHandler for Engine {
fn classify_record(&mut self, record: Record) -> Result