diff --git a/src/dtls12/client.rs b/src/dtls12/client.rs index 71ac7a70..84bb12fd 100644 --- a/src/dtls12/client.rs +++ b/src/dtls12/client.rs @@ -263,8 +263,17 @@ impl Client { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; - - let new_state = prev_state.make_progress(self)?; + let snapshot = self.engine.handshake_progress_snapshot(); + + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 3bd86e99..72cf6dbf 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -231,6 +231,8 @@ impl Engine { /// Insert a parsed datagram into the receive queue. fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { + self.purge_handled_queue_rx(); + // Capacity guard before iterating records. if self.queue_rx.len() >= self.config.max_queue_rx() { warn!( @@ -241,8 +243,16 @@ impl Engine { return Err(Error::ReceiveQueueFull); } - // Dispatch to specialized handlers - if incoming.first().first_handshake().is_some() { + // Dispatch to specialized handlers. A datagram can start with a + // non-handshake record (for example CCS) and still carry a handshake + // record later in the same packet. + if incoming.first().first_handshake().is_some() + || (!self.release_app_data + && incoming + .records() + .iter() + .any(|r| !r.handshakes().is_empty())) + { self.insert_incoming_handshake(incoming) } else { self.insert_incoming_non_handshake(incoming) @@ -250,16 +260,6 @@ impl Engine { } fn insert_incoming_handshake(&mut self, incoming: Incoming) -> Result<(), Error> { - let first_record = incoming.first(); - let handshake = first_record - .first_handshake() - .expect("caller ensures handshake"); - - let key_current = ( - handshake.header.message_seq, - handshake.header.fragment_offset, - ); - let maybe_dupe_seq = incoming .records() .iter() @@ -278,12 +278,20 @@ impl Engine { } } - // Drop old duplicates we've already processed - don't let them block newer messages. - if handshake.header.message_seq < self.peer_handshake_seq_no { + mark_stale_handshakes(&incoming, self.peer_handshake_seq_no); + + let Some((record, handshake)) = + first_relevant_handshake(&incoming, self.peer_handshake_seq_no) + else { return Ok(()); - } + }; - if self.peer_encryption_enabled && first_record.record().sequence.epoch == 0 { + let key_current = ( + handshake.header.message_seq, + handshake.header.fragment_offset, + ); + + if self.peer_encryption_enabled && record.record().sequence.epoch == 0 { // Keep old plaintext handshake records available long enough to // trigger flight resends above, but never queue or process them as // new messages after peer encryption is enabled. @@ -296,10 +304,8 @@ impl Engine { } let search_result = self.queue_rx.binary_search_by(|item| { - let key_other = item - .first() - .first_handshake() - .as_ref() + let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no) + .map(|(_, h)| h) .map(|h| (h.header.message_seq, h.header.fragment_offset)) .unwrap_or((u16::MAX, u32::MAX)); key_other.cmp(&key_current) @@ -674,6 +680,30 @@ impl Engine { Ok(Some(handshake)) } + pub(crate) fn handshake_progress_snapshot(&self) -> HandshakeProgressSnapshot { + HandshakeProgressSnapshot { + peer_handshake_seq_no: self.peer_handshake_seq_no, + transcript_len: self.transcript.len(), + } + } + + pub(crate) fn rollback_handshake_progress(&mut self, snapshot: HandshakeProgressSnapshot) { + self.peer_handshake_seq_no = snapshot.peer_handshake_seq_no; + self.transcript.resize(snapshot.transcript_len, 0); + + for incoming in self.queue_rx.iter() { + let reject_incoming = incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); + + if reject_incoming { + mark_rejected_handshakes(incoming, snapshot.peer_handshake_seq_no); + } + } + } + pub(crate) fn next_record(&mut self, ctype: ContentType) -> Option<&Record> { let record = self .queue_rx @@ -1233,6 +1263,47 @@ impl Engine { } } +#[derive(Clone, Copy)] +pub(crate) struct HandshakeProgressSnapshot { + peer_handshake_seq_no: u16, + transcript_len: usize, +} + +fn first_relevant_handshake( + incoming: &Incoming, + peer_handshake_seq_no: u16, +) -> Option<(&Record, &Handshake)> { + incoming + .records() + .iter() + .flat_map(|r| r.handshakes().iter().map(move |h| (r, h))) + .filter(|(_, h)| !h.is_handled()) + .find(|(_, h)| h.header.message_seq >= peer_handshake_seq_no) +} + +fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { + for handshake in incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .filter(|h| h.header.message_seq < peer_handshake_seq_no) + { + handshake.set_handled(); + } +} + +fn mark_rejected_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { + for record in incoming.records().iter() { + for handshake in record + .handshakes() + .iter() + .filter(|h| h.header.message_seq >= peer_handshake_seq_no) + { + handshake.set_handled(); + } + } +} + impl RecordHandler for Engine { fn classify_record(&mut self, record: Record) -> Result, Error> { let epoch = record.record().sequence.epoch; @@ -1377,3 +1448,119 @@ impl RecordHandler for Engine { self.release_app_data } } + +#[cfg(test)] +mod tests { + use super::*; + + struct PassthroughRecordHandler; + + impl RecordHandler for PassthroughRecordHandler { + fn classify_record(&mut self, record: Record) -> Result, Error> { + Ok(Some(record)) + } + + fn is_peer_encryption_enabled(&self) -> bool { + true + } + + fn replay_check(&self, _seq: Sequence) -> bool { + true + } + + fn replay_update(&mut self, _seq: Sequence) {} + + fn decryption_aad_and_nonce(&self, _dtls: &DTLSRecord, _buf: &[u8]) -> (Aad, Nonce) { + ( + Aad::new_dtls12(ContentType::Handshake, Sequence::new(1), 0), + Nonce::xor(&[0; 12], 0), + ) + } + + fn explicit_nonce_len(&self) -> usize { + 0 + } + + fn min_protected_fragment_len(&self) -> usize { + 0 + } + + fn decrypt_data( + &mut self, + _ciphertext: &mut TmpBuf, + _aad: Aad, + _nonce: Nonce, + ) -> Result<(), Error> { + Ok(()) + } + } + + fn test_engine() -> Engine { + Engine::new(Arc::new(Config::default()), AuthMode::Psk) + } + + fn handshake_record(content_type: ContentType, epoch: u16, seq: u64, msg_seq: u16) -> Vec { + let mut fragment = Vec::new(); + fragment.push(MessageType::Finished.as_u8()); + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); + fragment.extend_from_slice(&msg_seq.to_be_bytes()); + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); + fragment.extend_from_slice(&0u32.to_be_bytes()[1..]); + + let mut out = Vec::new(); + out.push(content_type.as_u8()); + out.extend_from_slice(&[0xFE, 0xFD]); + out.extend_from_slice(&epoch.to_be_bytes()); + out.extend_from_slice(&seq.to_be_bytes()[2..]); + out.extend_from_slice(&(fragment.len() as u16).to_be_bytes()); + out.extend_from_slice(&fragment); + out + } + + fn ccs_record(seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(ContentType::ChangeCipherSpec.as_u8()); + out.extend_from_slice(&[0xFE, 0xFD]); + out.extend_from_slice(&0u16.to_be_bytes()); + out.extend_from_slice(&seq.to_be_bytes()[2..]); + out.extend_from_slice(&1u16.to_be_bytes()); + out.push(1); + out + } + + #[test] + fn encrypted_relevant_handshake_after_stale_plaintext_is_queued() { + let mut engine = test_engine(); + engine.peer_encryption_enabled = true; + engine.peer_handshake_seq_no = 1; + + let mut packet = handshake_record(ContentType::Handshake, 0, 0, 0); + packet.extend_from_slice(&ccs_record(1)); + packet.extend_from_slice(&handshake_record(ContentType::Handshake, 1, 0, 1)); + + let incoming = Incoming::parse_packet( + &packet, + &mut PassthroughRecordHandler, + Some(Dtls12CipherSuite::ECDHE_ECDSA_AES128_GCM_SHA256), + ) + .expect("parse packet") + .expect("packet contains records"); + + engine + .insert_incoming(incoming) + .expect("insert mixed retransmission"); + + assert_eq!(engine.queue_rx.len(), 1); + assert!( + engine.queue_rx[0] + .records() + .iter() + .any(|record| record.record().sequence.epoch == 1 + && record + .handshakes() + .iter() + .any(|handshake| handshake.header.message_seq == 1)), + "relevant encrypted Finished must remain queued even when stale plaintext starts the datagram" + ); + } +} diff --git a/src/dtls12/server.rs b/src/dtls12/server.rs index 48e22f1a..2eb239b0 100644 --- a/src/dtls12/server.rs +++ b/src/dtls12/server.rs @@ -261,8 +261,17 @@ impl Server { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; + let snapshot = self.engine.handshake_progress_snapshot(); - let new_state = prev_state.make_progress(self)?; + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); diff --git a/src/dtls13/client.rs b/src/dtls13/client.rs index a6c60639..db03d6dc 100644 --- a/src/dtls13/client.rs +++ b/src/dtls13/client.rs @@ -302,8 +302,17 @@ impl Client { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; + let snapshot = self.engine.handshake_progress_snapshot(); - let new_state = prev_state.make_progress(self)?; + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); @@ -474,22 +483,31 @@ impl State { debug!("Received HelloRetryRequest"); - // Extract selected group and cookie from HRR extensions + let mut hrr_selected_group = None; + let mut saved_cookie = None; + let mut hrr_version_ok = false; + if let Some(ref extensions) = server_hello.extensions { for ext in extensions { match ext.extension_type { ExtensionType::KeyShare => { let ext_data = ext.extension_data(&client.defragment_buffer); - if let Ok((_, hrr_ks)) = KeyShareHelloRetryRequest::parse(ext_data) { - client.hrr_selected_group = Some(hrr_ks.selected_group); - } + let (_, hrr_ks) = KeyShareHelloRetryRequest::parse(ext_data) + .map_err(InternalError::from)?; + hrr_selected_group = Some(hrr_ks.selected_group); } ExtensionType::Cookie => { let ext_data = ext.extension_data(&client.defragment_buffer); parse_cookie_extension(ext_data).map_err(InternalError::from)?; let mut cookie = Buf::new(); cookie.extend_from_slice(ext_data); - client.saved_cookie = Some(cookie); + saved_cookie = Some(cookie); + } + ExtensionType::SupportedVersions => { + let ext_data = ext.extension_data(&client.defragment_buffer); + let (_, sv) = SupportedVersionsServerHello::parse(ext_data) + .map_err(InternalError::from)?; + hrr_version_ok = sv.selected_version == ProtocolVersion::DTLS1_3; } _ => {} } @@ -506,26 +524,17 @@ impl State { )) .into()); } - client.engine.set_cipher_suite(server_hello.cipher_suite); - // Validate HRR supported_versions - let mut hrr_version_ok = false; - if let Some(ref extensions) = server_hello.extensions { - for ext in extensions { - if ext.extension_type == ExtensionType::SupportedVersions { - let ext_data = ext.extension_data(&client.defragment_buffer); - if let Ok((_, sv)) = SupportedVersionsServerHello::parse(ext_data) { - hrr_version_ok = sv.selected_version == ProtocolVersion::DTLS1_3; - } - } - } - } if !hrr_version_ok { return Err( (Error::SecurityError(crate::SecurityError::HrrDidNotSelectDtls13)).into(), ); } + client.engine.set_cipher_suite(server_hello.cipher_suite); + client.hrr_selected_group = hrr_selected_group; + client.saved_cookie = saved_cookie; + // Replace transcript with message_hash per RFC 8446 Section 4.4.1. // The HRR was already appended to the transcript by next_handshake(). // We must hash only CH1, then re-append the HRR bytes. diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index d67ccc34..55d85fd8 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -340,6 +340,8 @@ impl Engine { } fn insert_incoming(&mut self, incoming: Incoming) -> Result<(), Error> { + self.purge_handled_queue_rx(); + if self.queue_rx.len() >= self.config.max_queue_rx() { warn!( "Receive queue full (max {}): {:?}", @@ -349,7 +351,13 @@ impl Engine { return Err(Error::ReceiveQueueFull); } - if incoming.first().first_handshake().is_some() { + if incoming.first().first_handshake().is_some() + || (!self.release_app_data + && incoming + .records() + .iter() + .any(|r| !r.handshakes().is_empty())) + { self.insert_incoming_handshake(incoming) } else { self.insert_incoming_non_handshake(incoming) @@ -357,16 +365,6 @@ impl Engine { } fn insert_incoming_handshake(&mut self, incoming: Incoming) -> Result<(), Error> { - let first_record = incoming.first(); - let handshake = first_record - .first_handshake() - .expect("caller ensures handshake"); - - let key_current = ( - handshake.header.message_seq, - handshake.header.fragment_offset, - ); - let maybe_dupe_seq = incoming .records() .iter() @@ -380,10 +378,17 @@ impl Engine { } } - // Drop old duplicates we've already processed - if handshake.header.message_seq < self.peer_handshake_seq_no { + mark_stale_handshakes(&incoming, self.peer_handshake_seq_no); + + let Some((_, handshake)) = first_relevant_handshake(&incoming, self.peer_handshake_seq_no) + else { return Ok(()); - } + }; + + let key_current = ( + handshake.header.message_seq, + handshake.header.fragment_offset, + ); // Reject new handshakes after initial handshake is complete, // but allow KeyUpdate (a post-handshake message). @@ -395,10 +400,8 @@ impl Engine { } let search_result = self.queue_rx.binary_search_by(|item| { - let key_other = item - .first() - .first_handshake() - .as_ref() + let key_other = first_relevant_handshake(item, self.peer_handshake_seq_no) + .map(|(_, h)| h) .map(|h| (h.header.message_seq, h.header.fragment_offset)) .unwrap_or((u16::MAX, u32::MAX)); key_other.cmp(&key_current) @@ -858,6 +861,29 @@ impl Engine { Ok(Some(handshake)) } + pub(crate) fn handshake_progress_snapshot(&self) -> HandshakeProgressSnapshot { + HandshakeProgressSnapshot { + peer_handshake_seq_no: self.peer_handshake_seq_no, + transcript_len: self.transcript.len(), + } + } + + pub(crate) fn rollback_handshake_progress(&mut self, snapshot: HandshakeProgressSnapshot) { + self.transcript.resize(snapshot.transcript_len, 0); + + for incoming in self.queue_rx.iter() { + let reject_incoming = incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .any(|h| h.header.message_seq >= snapshot.peer_handshake_seq_no); + + if reject_incoming { + mark_rejected_handshakes(incoming, snapshot.peer_handshake_seq_no); + } + } + } + /// Advance the expected peer handshake sequence number. /// /// Must be called by the caller of `next_handshake` / `next_handshake_no_transcript` @@ -2495,6 +2521,47 @@ impl RecordHandler for Engine { } } +#[derive(Clone, Copy)] +pub(crate) struct HandshakeProgressSnapshot { + peer_handshake_seq_no: u16, + transcript_len: usize, +} + +fn first_relevant_handshake( + incoming: &Incoming, + peer_handshake_seq_no: u16, +) -> Option<(&Record, &Handshake)> { + incoming + .records() + .iter() + .flat_map(|r| r.handshakes().iter().map(move |h| (r, h))) + .filter(|(_, h)| !h.is_handled()) + .find(|(_, h)| h.header.message_seq >= peer_handshake_seq_no) +} + +fn mark_stale_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { + for handshake in incoming + .records() + .iter() + .flat_map(|r| r.handshakes()) + .filter(|h| h.header.message_seq < peer_handshake_seq_no) + { + handshake.set_handled(); + } +} + +fn mark_rejected_handshakes(incoming: &Incoming, peer_handshake_seq_no: u16) { + for record in incoming.records().iter() { + for handshake in record + .handshakes() + .iter() + .filter(|h| h.header.message_seq >= peer_handshake_seq_no) + { + handshake.set_handled(); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -2579,16 +2646,38 @@ mod tests { packet } - fn parsed_key_update(seq: u16) -> Incoming { + fn encrypted_app_data_record(seq: u16, payload: &[u8]) -> Vec { + let mut fragment = Vec::new(); + fragment.extend_from_slice(payload); + fragment.push(ContentType::ApplicationData.as_u8()); + + let mut packet = Vec::new(); + packet.push( + 0b0010_0000 + | 0b0000_1000 // 2-byte sequence number. + | 0b0000_0100 // explicit length. + | 0b0000_0010, // epoch bits resolved by PassthroughRecordHandler. + ); + packet.extend_from_slice(&seq.to_be_bytes()); + packet.extend_from_slice(&(fragment.len() as u16).to_be_bytes()); + packet.extend_from_slice(&fragment); + packet + } + + fn parsed_packet(packet: &[u8]) -> Incoming { Incoming::parse_packet( - &encrypted_key_update_record(seq), + packet, &mut PassthroughRecordHandler, Some(Dtls13CipherSuite::AES_128_GCM_SHA256), ) - .expect("parse key update packet") + .expect("parse packet") .expect("packet contains a record") } + fn parsed_key_update(seq: u16) -> Incoming { + parsed_packet(&encrypted_key_update_record(seq)) + } + /// Issue 2: Epoch-0 sequence number must have an overflow guard. /// /// Per RFC 9147 ยง4.2, implementations MUST NOT allow the sequence number @@ -2756,4 +2845,42 @@ mod tests { "malformed ACK vector length must not partially acknowledge records" ); } + + #[test] + #[cfg(feature = "rcgen")] + fn rollback_preserves_coalesced_application_data() { + let mut packet = encrypted_app_data_record(0, b"still-pending"); + packet.extend_from_slice(&encrypted_key_update_record(1)); + let incoming = parsed_packet(&packet); + + let app_record = incoming + .records() + .iter() + .find(|record| record.record().content_type == ContentType::ApplicationData) + .expect("packet contains application data"); + let key_update = incoming + .records() + .iter() + .flat_map(|record| record.handshakes()) + .find(|handshake| handshake.header.msg_type == MessageType::KeyUpdate) + .expect("packet contains KeyUpdate"); + + assert!(!app_record.is_handled()); + assert!(!key_update.is_handled()); + + let snapshot = HandshakeProgressSnapshot { + peer_handshake_seq_no: key_update.header.message_seq, + transcript_len: 0, + }; + mark_rejected_handshakes(&incoming, snapshot.peer_handshake_seq_no); + + assert!( + !app_record.is_handled(), + "rollback must not consume application data coalesced with a rejected handshake" + ); + assert!( + key_update.is_handled(), + "rollback should discard the rejected handshake" + ); + } } diff --git a/src/dtls13/server.rs b/src/dtls13/server.rs index 4711c8d4..7e8d5a1e 100644 --- a/src/dtls13/server.rs +++ b/src/dtls13/server.rs @@ -344,8 +344,17 @@ impl Server { fn make_progress(&mut self) -> Result<(), InternalError> { loop { let prev_state = self.state; + let snapshot = self.engine.handshake_progress_snapshot(); - let new_state = prev_state.make_progress(self)?; + let new_state = match prev_state.make_progress(self) { + Ok(new_state) => new_state, + Err(err) => { + if err.is_transient() { + self.engine.rollback_handshake_progress(snapshot); + } + return Err(err); + } + }; if prev_state != new_state { self.state = new_state; trace!("{:?} -> {:?}", prev_state, new_state); diff --git a/src/error.rs b/src/error.rs index 5856c569..21b90c5b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -684,6 +684,10 @@ impl InternalError { Self::Transient(TransientError::TooManyRecords) } + pub(crate) fn is_transient(&self) -> bool { + matches!(self, Self::Transient(_)) + } + pub(crate) fn into_public_error(self) -> Option { match self { Self::Transient(err) => { diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 6fd95d38..2f4af92f 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -32,6 +32,17 @@ fn dtls12_epoch1_record(seq: u64, len: usize) -> Vec { out } +fn dtls12_ccs_record(seq: u64) -> Vec { + let mut out = Vec::new(); + out.push(20); // ChangeCipherSpec + out.extend_from_slice(&[0xFE, 0xFD]); // DTLS 1.2 + out.extend_from_slice(&0u16.to_be_bytes()); // epoch 0 + out.extend_from_slice(&seq.to_be_bytes()[2..]); // u48 sequence number + out.extend_from_slice(&1u16.to_be_bytes()); // payload length + out.push(1); // change_cipher_spec + out +} + fn dtls12_config_for_suite(suite: Dtls12CipherSuite) -> Arc { let mut provider = Config::default().crypto_provider().clone(); let selected = provider @@ -62,6 +73,29 @@ fn dtls12_min_protected_fragment_len(suite: Dtls12CipherSuite) -> usize { } } +fn poison_extension_vector_len(packet: &mut [u8], extension_type: u16) { + let marker = extension_type.to_be_bytes(); + + for i in 0..packet.len().saturating_sub(6) { + if packet[i..i + 2] != marker { + continue; + } + + let extension_len = u16::from_be_bytes([packet[i + 2], packet[i + 3]]) as usize; + let extension_data_start = i + 4; + let extension_data_end = extension_data_start + extension_len; + if extension_len < 2 || extension_data_end > packet.len() { + continue; + } + + packet[extension_data_start..extension_data_start + 2] + .copy_from_slice(&(extension_len as u16).to_be_bytes()); + return; + } + + panic!("extension {extension_type:#06x} not found"); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_malformed_datagram_is_discarded_without_processing_alerts() { @@ -107,6 +141,50 @@ fn dtls12_too_many_control_records_are_discarded() { .expect("too many records should be discarded"); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls12_malformed_client_hello_extension_does_not_consume_sequence() { + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = dtls12_config(); + let now = Instant::now(); + + let mut client = Dtls::new_12(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_12(config, server_cert, now); + server.set_active(false); + + client.handle_timeout(now).expect("client timeout start"); + client.handle_timeout(now).expect("client arm flight 1"); + let client_hello = collect_packets(&mut client); + assert_eq!(client_hello.len(), 1, "client should emit one ClientHello"); + + let mut poisoned = client_hello[0].clone(); + poison_extension_vector_len(&mut poisoned, 0x000e); // use_srtp + let mut mixed_poisoned = dtls12_ccs_record(42); + mixed_poisoned.extend_from_slice(&poisoned); + server + .handle_packet(&mixed_poisoned) + .expect("malformed ClientHello extension should be discarded"); + + server + .handle_packet(&client_hello[0]) + .expect("clean retransmission should still be accepted"); + server.handle_timeout(now).expect("server timeout"); + + let server_flight = collect_packets(&mut server); + assert!( + server_flight + .iter() + .flat_map(|packet| parse_handshake_types(packet)) + .any(|ty| ty == HELLO_VERIFY_REQUEST), + "server should respond to the clean ClientHello retransmission" + ); +} + #[test] #[cfg(feature = "rcgen")] fn dtls12_recovers_from_corrupted_packet() { diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 33e94423..64747426 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -50,6 +50,29 @@ fn dtls13_ack_record_for_records(seq: u64, records: &[(u64, u64)]) -> Vec { out } +fn poison_extension_vector_len(packet: &mut [u8], extension_type: u16) { + let marker = extension_type.to_be_bytes(); + + for i in 0..packet.len().saturating_sub(6) { + if packet[i..i + 2] != marker { + continue; + } + + let extension_len = u16::from_be_bytes([packet[i + 2], packet[i + 3]]) as usize; + let extension_data_start = i + 4; + let extension_data_end = extension_data_start + extension_len; + if extension_len < 2 || extension_data_end > packet.len() { + continue; + } + + packet[extension_data_start..extension_data_start + 2] + .copy_from_slice(&(extension_len as u16).to_be_bytes()); + return; + } + + panic!("extension {extension_type:#06x} not found"); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_malformed_datagram_is_discarded_without_processing_alerts() { @@ -95,6 +118,40 @@ fn dtls13_too_many_control_records_are_discarded() { .expect("too many records should be discarded"); } +#[test] +#[cfg(feature = "rcgen")] +fn dtls13_malformed_client_hello_extension_does_not_poison_transcript() { + let _ = env_logger::try_init(); + + let client_cert = generate_self_signed_certificate().expect("gen client cert"); + let server_cert = generate_self_signed_certificate().expect("gen server cert"); + let config = dtls13_config(); + let now = Instant::now(); + + let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now); + client.set_active(true); + + let mut server = Dtls::new_13(config, server_cert, now); + server.set_active(false); + + client.handle_timeout(now).expect("client timeout start"); + let client_hello = collect_packets(&mut client); + assert_eq!(client_hello.len(), 1, "client should emit one ClientHello"); + + let mut poisoned = client_hello[0].clone(); + poison_extension_vector_len(&mut poisoned, 0x0033); // key_share + let mut mixed_poisoned = dtls13_ack_record(42); + mixed_poisoned.extend_from_slice(&poisoned); + server + .handle_packet(&mixed_poisoned) + .expect("malformed ClientHello extension should be discarded"); + + server + .handle_packet(&client_hello[0]) + .expect("clean retransmission should still be accepted"); + complete_dtls13_handshake(&mut client, &mut server, now); +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_discards_too_short_ciphertext_record() { diff --git a/tests/dtls13_cookie.rs b/tests/dtls13_cookie.rs index ac1391f1..e4009f35 100644 --- a/tests/dtls13_cookie.rs +++ b/tests/dtls13_cookie.rs @@ -6,8 +6,8 @@ mod common; use std::sync::Arc; use std::time::Instant; -use dimpl::Dtls; use dimpl::certificate::generate_self_signed_certificate; +use dimpl::{Dtls, NamedGroup}; use crate::common::{drain_outputs, dtls13_config}; @@ -94,6 +94,131 @@ fn shrink_dtls13_cookie_extension_inner_len(packet: &mut [u8]) -> bool { false } +fn dtls13_hrr_extension_types(packet: &[u8]) -> Option> { + const RECORD_HEADER_LEN: usize = 13; + const HANDSHAKE_HEADER_LEN: usize = 12; + + if packet.len() < RECORD_HEADER_LEN + HANDSHAKE_HEADER_LEN || packet[0] != 22 { + return None; + } + + let handshake = &packet[RECORD_HEADER_LEN..]; + let msg_type = handshake[0]; + let body_len = + ((handshake[1] as usize) << 16) | ((handshake[2] as usize) << 8) | handshake[3] as usize; + if handshake.len() < HANDSHAKE_HEADER_LEN + body_len { + return None; + } + + let body = &handshake[HANDSHAKE_HEADER_LEN..HANDSHAKE_HEADER_LEN + body_len]; + let mut pos = cookie_extensions_start(body, msg_type)?; + if body.len() < pos + 2 { + return None; + } + + let extensions_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; + pos += 2; + let extensions_end = pos + extensions_len; + if body.len() < extensions_end { + return None; + } + + let mut extensions = Vec::new(); + while pos + 4 <= extensions_end { + let extension_type = u16::from_be_bytes([body[pos], body[pos + 1]]); + let extension_len = u16::from_be_bytes([body[pos + 2], body[pos + 3]]) as usize; + pos += 4 + extension_len; + if pos > extensions_end { + return None; + } + extensions.push(extension_type); + } + + Some(extensions) +} + +fn insert_dtls13_hrr_key_share_before_cookie(packet: &mut Vec, selected_group: u16) -> bool { + const RECORD_HEADER_LEN: usize = 13; + const HANDSHAKE_HEADER_LEN: usize = 12; + const KEY_SHARE_EXTENSION: u16 = 0x0033; + const COOKIE_EXTENSION: u16 = 0x002C; + + if packet.len() < RECORD_HEADER_LEN + HANDSHAKE_HEADER_LEN || packet[0] != 22 { + return false; + } + + let handshake_start = RECORD_HEADER_LEN; + let body_start = handshake_start + HANDSHAKE_HEADER_LEN; + let body_len = ((packet[handshake_start + 1] as usize) << 16) + | ((packet[handshake_start + 2] as usize) << 8) + | packet[handshake_start + 3] as usize; + if packet.len() < body_start + body_len { + return false; + } + + let body = &packet[body_start..body_start + body_len]; + let mut pos = match cookie_extensions_start(body, packet[handshake_start]) { + Some(pos) => body_start + pos, + None => return false, + }; + if packet.len() < pos + 2 { + return false; + } + + let extensions_len = u16::from_be_bytes([packet[pos], packet[pos + 1]]) as usize; + let extensions_len_pos = pos; + pos += 2; + let extensions_end = pos + extensions_len; + if packet.len() < extensions_end { + return false; + } + + while pos + 4 <= extensions_end { + let extension_type = u16::from_be_bytes([packet[pos], packet[pos + 1]]); + let extension_len = u16::from_be_bytes([packet[pos + 2], packet[pos + 3]]) as usize; + let next = pos + 4 + extension_len; + if next > extensions_end { + return false; + } + + if extension_type == KEY_SHARE_EXTENSION { + return true; + } + + if extension_type == COOKIE_EXTENSION { + let key_share = [ + (KEY_SHARE_EXTENSION >> 8) as u8, + KEY_SHARE_EXTENSION as u8, + 0, + 2, + (selected_group >> 8) as u8, + selected_group as u8, + ]; + packet.splice(pos..pos, key_share); + + let new_extensions_len = extensions_len + key_share.len(); + packet[extensions_len_pos..extensions_len_pos + 2] + .copy_from_slice(&(new_extensions_len as u16).to_be_bytes()); + + let new_body_len = body_len + key_share.len(); + packet[handshake_start + 1] = ((new_body_len >> 16) & 0xff) as u8; + packet[handshake_start + 2] = ((new_body_len >> 8) & 0xff) as u8; + packet[handshake_start + 3] = (new_body_len & 0xff) as u8; + packet[handshake_start + 9] = ((new_body_len >> 16) & 0xff) as u8; + packet[handshake_start + 10] = ((new_body_len >> 8) & 0xff) as u8; + packet[handshake_start + 11] = (new_body_len & 0xff) as u8; + + let new_record_len = HANDSHAKE_HEADER_LEN + new_body_len; + packet[11..13].copy_from_slice(&(new_record_len as u16).to_be_bytes()); + return true; + } + + pos = next; + } + + false +} + #[test] fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { let _ = env_logger::try_init(); @@ -124,12 +249,30 @@ fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { .next() .expect("server should emit HRR"); assert!( - shrink_dtls13_cookie_extension_inner_len(&mut hrr), + insert_dtls13_hrr_key_share_before_cookie(&mut hrr, NamedGroup::X25519.as_u16()), + "fixture should insert HRR KeyShare before Cookie" + ); + let mut malformed_hrr = hrr.clone(); + let hrr_extensions = dtls13_hrr_extension_types(&hrr).expect("parse HRR extensions"); + let key_share_index = hrr_extensions + .iter() + .position(|ext| *ext == 0x0033) + .expect("HRR should contain KeyShare"); + let cookie_index = hrr_extensions + .iter() + .position(|ext| *ext == 0x002C) + .expect("HRR should contain Cookie"); + assert!( + key_share_index < cookie_index, + "fixture must parse KeyShare before rejecting malformed Cookie" + ); + assert!( + shrink_dtls13_cookie_extension_inner_len(&mut malformed_hrr), "fixture should contain a Cookie extension" ); client - .handle_packet(&hrr) + .handle_packet(&malformed_hrr) .expect("malformed HRR Cookie extension should be discarded"); client @@ -140,6 +283,18 @@ fn dtls13_client_rejects_hrr_cookie_extension_trailing_bytes() { client_out.packets.is_empty(), "client must not send CH2 after malformed HRR Cookie" ); + + client + .handle_packet(&hrr) + .expect("clean HRR retransmission should be accepted"); + client + .handle_timeout(now) + .expect("client timeout after clean HRR"); + let client_out = drain_outputs(&mut client); + assert!( + !client_out.packets.is_empty(), + "client should send CH2 after clean HRR retransmission" + ); } #[test]