Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/dtls12/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,17 @@ impl Client {
fn make_progress(&mut self) -> Result<(), InternalError> {
loop {
let prev_state = self.state;

let new_state = prev_state.make_progress(self)?;
let snapshot = self.engine.handshake_progress_snapshot();

let new_state = match prev_state.make_progress(self) {
Ok(new_state) => new_state,
Err(err) => {
if err.is_transient() {
self.engine.rollback_handshake_progress(snapshot);
}
return Err(err);
}
};
if prev_state != new_state {
self.state = new_state;
trace!("{:?} -> {:?}", prev_state, new_state);
Expand Down
227 changes: 207 additions & 20 deletions src/dtls12/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ impl Engine {

/// Insert a parsed datagram into the receive queue.
fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> {
self.purge_handled_queue_rx();

// Capacity guard before iterating records.
if self.queue_rx.len() >= self.config.max_queue_rx() {
warn!(
Expand All @@ -241,25 +243,23 @@ impl Engine {
return Err(Error::ReceiveQueueFull);
}

// Dispatch to specialized handlers
if incoming.first().first_handshake().is_some() {
// Dispatch to specialized handlers. A datagram can start with a
// non-handshake record (for example CCS) and still carry a handshake
// record later in the same packet.
if incoming.first().first_handshake().is_some()
|| (!self.release_app_data
&& incoming
.records()
.iter()
.any(|r| !r.handshakes().is_empty()))
{
self.insert_incoming_handshake(incoming)
} else {
self.insert_incoming_non_handshake(incoming)
}
}

fn insert_incoming_handshake(&mut self, incoming: Incoming) -> Result<(), Error> {
let first_record = incoming.first();
let handshake = first_record
.first_handshake()
.expect("caller ensures handshake");

let key_current = (
handshake.header.message_seq,
handshake.header.fragment_offset,
);

let maybe_dupe_seq = incoming
.records()
.iter()
Expand All @@ -278,12 +278,20 @@ impl Engine {
}
}

// Drop old duplicates we've already processed - don't let them block newer messages.
if handshake.header.message_seq < self.peer_handshake_seq_no {
mark_stale_handshakes(&incoming, self.peer_handshake_seq_no);

let Some((record, handshake)) =
first_relevant_handshake(&incoming, self.peer_handshake_seq_no)
else {
return Ok(());
}
};

if self.peer_encryption_enabled && first_record.record().sequence.epoch == 0 {
let key_current = (
handshake.header.message_seq,
handshake.header.fragment_offset,
);

if self.peer_encryption_enabled && record.record().sequence.epoch == 0 {
// Keep old plaintext handshake records available long enough to
// trigger flight resends above, but never queue or process them as
// new messages after peer encryption is enabled.
Expand All @@ -296,10 +304,8 @@ impl Engine {
}

let search_result = self.queue_rx.binary_search_by(|item| {
let key_other = item
.first()
.first_handshake()
.as_ref()
let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no)
.map(|(_, h)| h)
.map(|h| (h.header.message_seq, h.header.fragment_offset))
.unwrap_or((u16::MAX, u32::MAX));
key_other.cmp(&key_current)
Expand Down Expand Up @@ -674,6 +680,30 @@ impl Engine {
Ok(Some(handshake))
}

pub(crate) fn handshake_progress_snapshot(&self) -> HandshakeProgressSnapshot {
HandshakeProgressSnapshot {
peer_handshake_seq_no: self.peer_handshake_seq_no,
transcript_len: self.transcript.len(),
}
}

pub(crate) fn rollback_handshake_progress(&mut self, snapshot: HandshakeProgressSnapshot) {
self.peer_handshake_seq_no = snapshot.peer_handshake_seq_no;
self.transcript.resize(snapshot.transcript_len, 0);

for incoming in self.queue_rx.iter() {
let reject_incoming = incoming
.records()
.iter()
.flat_map(|r| r.handshakes())
.any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no);

if reject_incoming {
mark_rejected_handshakes(incoming, snapshot.peer_handshake_seq_no);
}
}
}

pub(crate) fn next_record(&mut self, ctype: ContentType) -> Option<&Record> {
let record = self
.queue_rx
Expand Down Expand Up @@ -1233,6 +1263,47 @@ impl Engine {
}
}

#[derive(Clone, Copy)]
pub(crate) struct HandshakeProgressSnapshot {
peer_handshake_seq_no: u16,
transcript_len: usize,
}

fn first_relevant_handshake(
incoming: &Incoming,
peer_handshake_seq_no: u16,
) -> Option<(&Record, &Handshake)> {
incoming
.records()
.iter()
.flat_map(|r| r.handshakes().iter().map(move |h| (r, h)))
.filter(|(_, h)| !h.is_handled())
.find(|(_, h)| h.header.message_seq >= peer_handshake_seq_no)
}

fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) {
for handshake in incoming
.records()
.iter()
.flat_map(|r| r.handshakes())
.filter(|h| h.header.message_seq < peer_handshake_seq_no)
{
handshake.set_handled();
}
}

fn mark_rejected_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) {
for record in incoming.records().iter() {
for handshake in record
.handshakes()
.iter()
.filter(|h| h.header.message_seq >= peer_handshake_seq_no)
{
handshake.set_handled();
}
}
}

impl RecordHandler for Engine {
fn classify_record(&mut self, record: Record) -> Result<Option<Record>, Error> {
let epoch = record.record().sequence.epoch;
Expand Down Expand Up @@ -1377,3 +1448,119 @@ impl RecordHandler for Engine {
self.release_app_data
}
}

#[cfg(test)]
mod tests {
use super::*;

struct PassthroughRecordHandler;

impl RecordHandler for PassthroughRecordHandler {
fn classify_record(&mut self, record: Record) -> Result<Option<Record>, Error> {
Ok(Some(record))
}

fn is_peer_encryption_enabled(&self) -> bool {
true
}

fn replay_check(&self, _seq: Sequence) -> bool {
true
}

fn replay_update(&mut self, _seq: Sequence) {}

fn decryption_aad_and_nonce(&self, _dtls: &DTLSRecord, _buf: &[u8]) -> (Aad, Nonce) {
(
Aad::new_dtls12(ContentType::Handshake, Sequence::new(1), 0),
Nonce::xor(&[0; 12], 0),
)
}

fn explicit_nonce_len(&self) -> usize {
0
}

fn min_protected_fragment_len(&self) -> usize {
0
}

fn decrypt_data(
&mut self,
_ciphertext: &mut TmpBuf,
_aad: Aad,
_nonce: Nonce,
) -> Result<(), Error> {
Ok(())
}
}

fn test_engine() -> Engine {
Engine::new(Arc::new(Config::default()), AuthMode::Psk)
}

fn handshake_record(content_type: ContentType, epoch: u16, seq: u64, msg_seq: u16) -> Vec<u8> {
let mut fragment = Vec::new();
fragment.push(MessageType::Finished.as_u8());
fragment.extend_from_slice(&0u32.to_be_bytes()[1..]);
fragment.extend_from_slice(&msg_seq.to_be_bytes());
fragment.extend_from_slice(&0u32.to_be_bytes()[1..]);
fragment.extend_from_slice(&0u32.to_be_bytes()[1..]);

let mut out = Vec::new();
out.push(content_type.as_u8());
out.extend_from_slice(&[0xFE, 0xFD]);
out.extend_from_slice(&epoch.to_be_bytes());
out.extend_from_slice(&seq.to_be_bytes()[2..]);
out.extend_from_slice(&(fragment.len() as u16).to_be_bytes());
out.extend_from_slice(&fragment);
out
}

fn ccs_record(seq: u64) -> Vec<u8> {
let mut out = Vec::new();
out.push(ContentType::ChangeCipherSpec.as_u8());
out.extend_from_slice(&[0xFE, 0xFD]);
out.extend_from_slice(&0u16.to_be_bytes());
out.extend_from_slice(&seq.to_be_bytes()[2..]);
out.extend_from_slice(&1u16.to_be_bytes());
out.push(1);
out
}

#[test]
fn encrypted_relevant_handshake_after_stale_plaintext_is_queued() {
let mut engine = test_engine();
engine.peer_encryption_enabled = true;
engine.peer_handshake_seq_no = 1;

let mut packet = handshake_record(ContentType::Handshake, 0, 0, 0);
packet.extend_from_slice(&ccs_record(1));
packet.extend_from_slice(&handshake_record(ContentType::Handshake, 1, 0, 1));

let incoming = Incoming::parse_packet(
&packet,
&mut PassthroughRecordHandler,
Some(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256),
)
.expect("parse packet")
.expect("packet contains records");

engine
.insert_incoming(incoming)
.expect("insert mixed retransmission");

assert_eq!(engine.queue_rx.len(), 1);
assert!(
engine.queue_rx[0]
.records()
.iter()
.any(|record| record.record().sequence.epoch == 1
&& record
.handshakes()
.iter()
.any(|handshake| handshake.header.message_seq == 1)),
"relevant encrypted Finished must remain queued even when stale plaintext starts the datagram"
);
}
}
11 changes: 10 additions & 1 deletion src/dtls12/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,17 @@ impl Server {
fn make_progress(&mut self) -> Result<(), InternalError> {
loop {
let prev_state = self.state;
let snapshot = self.engine.handshake_progress_snapshot();

let new_state = prev_state.make_progress(self)?;
let new_state = match prev_state.make_progress(self) {
Ok(new_state) => new_state,
Err(err) => {
if err.is_transient() {
self.engine.rollback_handshake_progress(snapshot);
}
return Err(err);
}
};
if prev_state != new_state {
self.state = new_state;
trace!("{:?} -> {:?}", prev_state, new_state);
Expand Down
Loading
Loading