diff --git a/CHANGELOG.md b/CHANGELOG.md index b35cc0e8..1bc3c04a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased + * Return `BufferTooSmall` instead of panicking on undersized poll buffers #150 * Fix DTLS 1.3 RFC 9147 conformance issues #147 * Reject malformed fragmented DTLS handshakes before consuming fragments #144 * Represent DTLS wire-code identifiers as compact newtypes (breaking) #137 diff --git a/README.md b/README.md index 88c16c63..cbc924a3 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ Drive the engine with three calls: The output is an [`Output`][output] enum with borrowed references into your provided buffer: - `Packet(&[u8])`: send on your UDP socket +- `BufferTooSmall { needed }`: resize the poll buffer and retry - `Timeout(Instant)`: schedule a timer and call `handle_timeout` at/after it - `Connected`: handshake complete - `PeerCert(&[u8])`: peer leaf certificate (DER) — validate in your app @@ -113,6 +114,9 @@ fn example_event_loop(mut dtls: Dtls) -> Result<(), dimpl::Error> { loop { match dtls.poll_output(&mut out_buf) { Output::Packet(p) => send_udp(p), + Output::BufferTooSmall { needed } => { + out_buf.resize(needed, 0); + } Output::Timeout(t) => { next_wake = Some(t); break; } Output::Connected => { // DTLS established — application may start sending diff --git a/src/auto.rs b/src/auto.rs index 410b4bd4..c32219cc 100644 --- a/src/auto.rs +++ b/src/auto.rs @@ -318,12 +318,8 @@ impl ClientPending { if self.needs_send { let len = self.wire_packet.len(); if buf.len() < len { - // Buffer too small; keep needs_send armed so the packet - // is emitted on the next poll with a sufficiently large buffer. - let next = self - .retransmit_at - .unwrap_or(self.last_now + Duration::from_secs(1)); - return Output::Timeout(next); + // Keep needs_send armed so the packet is emitted on retry. + return Output::BufferTooSmall { needed: len }; } self.needs_send = false; buf[..len].copy_from_slice(&self.wire_packet); diff --git a/src/dtls12/engine.rs b/src/dtls12/engine.rs index 3bd86e99..adce4d58 100644 --- a/src/dtls12/engine.rs +++ b/src/dtls12/engine.rs @@ -128,6 +128,12 @@ struct Entry { fragment: Buf, } +enum PollOutput<'a> { + Data(&'a [u8]), + BufferTooSmall { needed: usize }, + None(&'a mut [u8]), +} + impl Engine { pub fn new(config: Arc, auth: AuthMode) -> Self { let mut rng = SeededRng::new(config.rng_seed()); @@ -414,14 +420,16 @@ impl Engine { // Drain incoming queue of processed records. self.purge_handled_queue_rx(); - // First check if we have any decrypted app data. let buf = match self.poll_app_data(buf) { - Ok(p) => return Output::ApplicationData(p), - Err(b) => b, + PollOutput::Data(p) => return Output::ApplicationData(p), + PollOutput::BufferTooSmall { needed } => return Output::BufferTooSmall { needed }, + PollOutput::None(b) => b, }; - if let Ok(p) = self.poll_packet_tx(buf) { - return Output::Packet(p); + match self.poll_packet_tx(buf) { + PollOutput::Data(p) => return Output::Packet(p), + PollOutput::BufferTooSmall { needed } => return Output::BufferTooSmall { needed }, + PollOutput::None(_) => {} } if self.close_notify_received && !self.close_notify_reported { @@ -434,9 +442,9 @@ impl Engine { Output::Timeout(next_timeout) } - fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { + fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> PollOutput<'a> { if !self.release_app_data { - return Err(buf); + return PollOutput::None(buf); } let mut unhandled = self @@ -447,24 +455,21 @@ impl Engine { .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { - return Err(buf); + return PollOutput::None(buf); }; let record_buffer = next.buffer(); let fragment = next.record().fragment(record_buffer); let len = fragment.len(); - assert!( - len <= buf.len(), - "Output buffer too small for application data {} > {}", - len, - buf.len() - ); + if len > buf.len() { + return PollOutput::BufferTooSmall { needed: len }; + } buf[..len].copy_from_slice(fragment); next.set_handled(); - Ok(&buf[..len]) + PollOutput::Data(&buf[..len]) } fn purge_handled_queue_rx(&mut self) { @@ -482,22 +487,23 @@ impl Engine { } } - fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { - let Some(p) = self.queue_tx.pop_front() else { - return Err(buf); + fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> PollOutput<'a> { + let Some(p) = self.queue_tx.front() else { + return PollOutput::None(buf); }; - assert!( - p.len() <= buf.len(), - "Output buffer too small for packet {} > {}", - p.len(), - buf.len() - ); + if p.len() > buf.len() { + return PollOutput::BufferTooSmall { needed: p.len() }; + } + let p = self + .queue_tx + .pop_front() + .expect("queue front checked above"); let len = p.len(); buf[..len].copy_from_slice(&p); - Ok(&buf[..len]) + PollOutput::Data(&buf[..len]) } fn poll_timeout(&self, now: Instant) -> Instant { diff --git a/src/dtls13/engine.rs b/src/dtls13/engine.rs index 5afd1901..d2c8e73b 100644 --- a/src/dtls13/engine.rs +++ b/src/dtls13/engine.rs @@ -201,6 +201,12 @@ struct Entry { acked: bool, } +enum PollOutput<'a> { + Data(&'a [u8]), + BufferTooSmall { needed: usize }, + None(&'a mut [u8]), +} + impl Engine { pub fn new(config: Arc, certificate: DtlsCertificate) -> Self { let mut rng = SeededRng::new(config.rng_seed()); @@ -535,14 +541,17 @@ impl Engine { self.purge_handled_queue_rx(); let buf = match self.poll_app_data(buf) { - Ok(p) => return Output::ApplicationData(p), - Err(b) => b, + PollOutput::Data(p) => return Output::ApplicationData(p), + PollOutput::BufferTooSmall { needed } => return Output::BufferTooSmall { needed }, + PollOutput::None(b) => b, }; self.maybe_schedule_handshake_ack(now); - if let Ok(p) = self.poll_packet_tx(buf) { - return Output::Packet(p); + match self.poll_packet_tx(buf) { + PollOutput::Data(p) => return Output::Packet(p), + PollOutput::BufferTooSmall { needed } => return Output::BufferTooSmall { needed }, + PollOutput::None(_) => {} } if self.close_notify_sequence.is_some() && !self.close_notify_reported { @@ -555,9 +564,9 @@ impl Engine { Output::Timeout(next_timeout) } - fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { + fn poll_app_data<'a>(&mut self, buf: &'a mut [u8]) -> PollOutput<'a> { if !self.release_app_data { - return Err(buf); + return PollOutput::None(buf); } let mut unhandled = self @@ -568,24 +577,21 @@ impl Engine { .skip_while(|r| r.is_handled()); let Some(next) = unhandled.next() else { - return Err(buf); + return PollOutput::None(buf); }; let record_buffer = next.buffer(); let fragment = next.record().fragment(record_buffer); let len = fragment.len(); - assert!( - len <= buf.len(), - "Output buffer too small for application data {} > {}", - len, - buf.len() - ); + if len > buf.len() { + return PollOutput::BufferTooSmall { needed: len }; + } buf[..len].copy_from_slice(fragment); next.set_handled(); - Ok(&buf[..len]) + PollOutput::Data(&buf[..len]) } fn purge_handled_queue_rx(&mut self) { @@ -603,22 +609,23 @@ impl Engine { } } - fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], &'a mut [u8]> { - let Some(p) = self.queue_tx.pop_front() else { - return Err(buf); + fn poll_packet_tx<'a>(&mut self, buf: &'a mut [u8]) -> PollOutput<'a> { + let Some(p) = self.queue_tx.front() else { + return PollOutput::None(buf); }; - assert!( - p.len() <= buf.len(), - "Output buffer too small for packet {} > {}", - p.len(), - buf.len() - ); + if p.len() > buf.len() { + return PollOutput::BufferTooSmall { needed: p.len() }; + } + let p = self + .queue_tx + .pop_front() + .expect("queue front checked above"); let len = p.len(); buf[..len].copy_from_slice(&p); - Ok(&buf[..len]) + PollOutput::Data(&buf[..len]) } /// Prevent subsequent records from being appended to the current last diff --git a/src/lib.rs b/src/lib.rs index e1680c0d..7fa84a3b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,6 +66,7 @@ //! The output is an [`Output`][output] enum with borrowed //! references into your provided buffer: //! - `Packet(&[u8])`: send on your UDP socket +//! - `BufferTooSmall { needed }`: resize the poll buffer and retry //! - `Timeout(Instant)`: schedule a timer and call `handle_timeout` at/after it //! - `Connected`: handshake complete //! - `PeerCert(&[u8])`: peer leaf certificate (DER) — validate in your app @@ -113,6 +114,9 @@ //! loop { //! match dtls.poll_output(&mut out_buf) { //! Output::Packet(p) => send_udp(p), +//! Output::BufferTooSmall { needed } => { +//! out_buf.resize(needed, 0); +//! } //! Output::Timeout(t) => { next_wake = Some(t); break; } //! Output::Connected => { //! // DTLS established — application may start sending @@ -839,6 +843,14 @@ impl fmt::Debug for Dtls { pub enum Output<'a> { /// A DTLS record to transmit on the wire. Packet(&'a [u8]), + /// The provided output buffer is too small for the next pending output. + /// + /// Retry [`Dtls::poll_output`] with a buffer of at least `needed` bytes. + /// The pending output is retained until it can be emitted. + BufferTooSmall { + /// Minimum buffer length required to emit the pending output. + needed: usize, + }, /// Schedule a timer and call [`Dtls::handle_timeout`] at this instant. /// /// This is always the last variant returned by a poll cycle. @@ -863,6 +875,7 @@ impl fmt::Debug for Output<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Packet(v) => write!(f, "Packet({})", v.len()), + Self::BufferTooSmall { needed } => write!(f, "BufferTooSmall({needed})"), Self::Timeout(v) => write!(f, "Timeout({:?})", v), Self::Connected => write!(f, "Connected"), Self::PeerCert(v) => write!(f, "PeerCert({})", v.len()), diff --git a/tests/auto/handshake.rs b/tests/auto/handshake.rs index ea421e00..920d1692 100644 --- a/tests/auto/handshake.rs +++ b/tests/auto/handshake.rs @@ -546,10 +546,10 @@ fn auto_client_poll_output_undersized_buffer() { let mut tiny_buf = [0u8; 4]; let output = client.poll_output(&mut tiny_buf); - // Should return Timeout (packet deferred), not a Packet. + // Should report the required size while keeping the packet deferred. assert!( - matches!(output, Output::Timeout(_)), - "undersized buffer should yield Timeout, got: {output:?}" + matches!(output, Output::BufferTooSmall { .. }), + "undersized buffer should yield BufferTooSmall, got: {output:?}" ); // Now poll with a large buffer — the deferred packet should come through. diff --git a/tests/dtls12/edge.rs b/tests/dtls12/edge.rs index 6fd95d38..592637b3 100644 --- a/tests/dtls12/edge.rs +++ b/tests/dtls12/edge.rs @@ -32,6 +32,41 @@ fn dtls12_epoch1_record(seq: u64, len: usize) -> Vec { out } +#[test] +#[cfg(feature = "rcgen")] +fn oversized_application_data_reports_buffer_too_small() { + let now = Instant::now(); + let (mut client, mut server, _now) = setup_connected_12_pair(now); + let payload = vec![0x5a; 4000]; + + client + .send_application_data(&payload) + .expect("send application data"); + + let mut large_buf = vec![0u8; 8192]; + let mut packets = Vec::new(); + loop { + match client.poll_output(&mut large_buf) { + Output::Packet(packet) => packets.push(packet.to_vec()), + Output::Timeout(_) => break, + _ => {} + } + } + + deliver_packets(&packets, &mut server); + + let mut small_buf = vec![0u8; 2048]; + match server.poll_output(&mut small_buf) { + Output::BufferTooSmall { needed } => assert_eq!(needed, payload.len()), + output => panic!("expected BufferTooSmall, got {output:?}"), + } + + match server.poll_output(&mut large_buf) { + Output::ApplicationData(data) => assert_eq!(data, payload.as_slice()), + output => panic!("expected retained application data, got {output:?}"), + } +} + fn dtls12_config_for_suite(suite: Dtls12CipherSuite) -> Arc { let mut provider = Config::default().crypto_provider().clone(); let selected = provider diff --git a/tests/dtls13/edge.rs b/tests/dtls13/edge.rs index 33e94423..644a4f1f 100644 --- a/tests/dtls13/edge.rs +++ b/tests/dtls13/edge.rs @@ -50,6 +50,41 @@ fn dtls13_ack_record_for_records(seq: u64, records: &[(u64, u64)]) -> Vec { out } +#[test] +#[cfg(feature = "rcgen")] +fn oversized_application_data_reports_buffer_too_small() { + let now = Instant::now(); + let (mut client, mut server, _now) = setup_connected_13_pair(now); + let payload = vec![0xa5; 4000]; + + client + .send_application_data(&payload) + .expect("send application data"); + + let mut large_buf = vec![0u8; 8192]; + let mut packets = Vec::new(); + loop { + match client.poll_output(&mut large_buf) { + Output::Packet(packet) => packets.push(packet.to_vec()), + Output::Timeout(_) => break, + _ => {} + } + } + + deliver_packets(&packets, &mut server); + + let mut small_buf = vec![0u8; 2048]; + match server.poll_output(&mut small_buf) { + Output::BufferTooSmall { needed } => assert_eq!(needed, payload.len()), + output => panic!("expected BufferTooSmall, got {output:?}"), + } + + match server.poll_output(&mut large_buf) { + Output::ApplicationData(data) => assert_eq!(data, payload.as_slice()), + output => panic!("expected retained application data, got {output:?}"), + } +} + #[test] #[cfg(feature = "rcgen")] fn dtls13_malformed_datagram_is_discarded_without_processing_alerts() {