diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index c88ab71ce476..6adf4153c06a 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -373,7 +373,12 @@ impl FlightDataEncoder { DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?, }; - for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { + let batches = split_batch_for_grpc_response(batch, self.max_flight_data_size); + let last = batches.len().saturating_sub(1); // handle empty batches + for (i, batch) in batches.into_iter().enumerate() { + self.encoder + .ipc_write_context + .set_reserve_scratch(i != last); let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?; for dict in flight_dictionaries { self.queue_message(dict); @@ -666,7 +671,7 @@ fn prepare_schema_for_flight( fn split_batch_for_grpc_response( batch: RecordBatch, max_flight_data_size: usize, -) -> impl Iterator<Item = RecordBatch> { +) -> Vec<RecordBatch> { let size = batch .columns() .iter() @@ -678,17 +683,15 @@ fn split_batch_for_grpc_response( let num_rows = batch.num_rows(); let rows_per_batch = (num_rows / n_batches).max(1); let mut offset = 0; + let mut batches = Vec::with_capacity(n_batches); - std::iter::from_fn(move || { - if offset < num_rows { - let length = rows_per_batch.min(num_rows - offset); - let slice = batch.slice(offset, length); - offset += length; - Some(slice) - } else { - None - } - }) + while offset < num_rows { + let length = rows_per_batch.min(num_rows - offset); + batches.push(batch.slice(offset, length)); + offset += length; + } + + batches } /// The data needed to encode a stream of flight data, holding on to @@ -1858,8 +1861,7 @@ mod tests { let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) .expect("cannot create record batch"); - let split: Vec<_> = - split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect(); + let split: Vec<_> = 
split_batch_for_grpc_response(batch.clone(), max_flight_data_size); assert_eq!(split.len(), 1); assert_eq!(batch, split[0]); @@ -1869,8 +1871,7 @@ mod tests { let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>()); let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) .expect("cannot create record batch"); - let split: Vec<_> = - split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect(); + let split: Vec<_> = split_batch_for_grpc_response(batch.clone(), max_flight_data_size); assert_eq!(split.len(), 3); assert_eq!( split.iter().map(|batch| batch.num_rows()).sum::<usize>(), @@ -1915,7 +1916,7 @@ mod tests { let input_rows = batch.num_rows(); let split: Vec<_> = - split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes).collect(); + split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes); let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect(); let output_rows: usize = sizes.iter().sum(); diff --git a/arrow-ipc/src/compression.rs b/arrow-ipc/src/compression.rs index 79879332d4e0..9bf596f838bc 100644 --- a/arrow-ipc/src/compression.rs +++ b/arrow-ipc/src/compression.rs @@ -31,8 +31,8 @@ const DEFAULT_ZSTD_COMPRESSION_LEVEL: i32 = 3; /// compression. Also holds a [`FlatBufferBuilder`] that is reused across IPC writes. #[derive(Default)] pub struct IpcWriteContext { - #[expect(dead_code)] pub(crate) scratch: Vec<u8>, + pub(crate) reserve_scratch: bool, fbb: FlatBufferBuilder<'static>, #[cfg(feature = "zstd")] compressor: Option<zstd::bulk::Compressor<'static>>, @@ -44,6 +44,13 @@ impl IpcWriteContext { &mut self.fbb } + /// Set whether the scratch buffer capacity should be reserved after each encode for reuse + /// on the next call. Set to `false` for the final batch in a sequence to avoid a + /// pointless allocation. 
+ pub fn set_reserve_scratch(&mut self, reserve: bool) { + self.reserve_scratch = reserve; + } + #[cfg(feature = "zstd")] fn zstd_compressor(&mut self, level: i32) -> &mut zstd::bulk::Compressor<'static> { self.compressor.get_or_insert_with(|| { diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 6ae64843731f..ba667c648204 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -622,7 +622,7 @@ impl IpcDataGenerator { ) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> { let encoded_dictionaries = self.encode_all_dicts(batch, dictionary_tracker, write_options, ipc_write_context)?; - let mut arrow_data = Vec::new(); + let mut arrow_data = std::mem::take(&mut ipc_write_context.scratch); let (ipc_message, _, tail_pad) = self.record_batch_to_bytes( batch, write_options, @@ -630,6 +630,10 @@ impl IpcDataGenerator { &mut IpcBodySink::Write(&mut arrow_data), )?; arrow_data.extend_from_slice(&PADDING[..tail_pad]); + let final_capacity = arrow_data.capacity(); + if ipc_write_context.reserve_scratch { + ipc_write_context.scratch.reserve(final_capacity); + } Ok(( encoded_dictionaries, EncodedData {