Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,12 @@ impl FlightDataEncoder {
DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
};

for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
let batches = split_batch_for_grpc_response(batch, self.max_flight_data_size);
let last = batches.len().saturating_sub(1); // handle empty batches
for (i, batch) in batches.into_iter().enumerate() {
self.encoder
.ipc_write_context
.set_reserve_scratch(i != last);
let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;
for dict in flight_dictionaries {
self.queue_message(dict);
Expand Down Expand Up @@ -666,7 +671,7 @@ fn prepare_schema_for_flight(
fn split_batch_for_grpc_response(
batch: RecordBatch,
max_flight_data_size: usize,
) -> impl Iterator<Item = RecordBatch> {
) -> Vec<RecordBatch> {
let size = batch
.columns()
.iter()
Expand All @@ -678,17 +683,15 @@ fn split_batch_for_grpc_response(
let num_rows = batch.num_rows();
let rows_per_batch = (num_rows / n_batches).max(1);
let mut offset = 0;
let mut batches = Vec::with_capacity(n_batches);

std::iter::from_fn(move || {
if offset < num_rows {
let length = rows_per_batch.min(num_rows - offset);
let slice = batch.slice(offset, length);
offset += length;
Some(slice)
} else {
None
}
})
while offset < num_rows {
let length = rows_per_batch.min(num_rows - offset);
batches.push(batch.slice(offset, length));
offset += length;
}

batches
}

/// The data needed to encode a stream of flight data, holding on to
Expand Down Expand Up @@ -1858,8 +1861,7 @@ mod tests {
let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect();
let split: Vec<_> = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
assert_eq!(split.len(), 1);
assert_eq!(batch, split[0]);

Expand All @@ -1869,8 +1871,7 @@ mod tests {
let c = UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::<Vec<_>>());
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)])
.expect("cannot create record batch");
let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size).collect();
let split: Vec<_> = split_batch_for_grpc_response(batch.clone(), max_flight_data_size);
assert_eq!(split.len(), 3);
assert_eq!(
split.iter().map(|batch| batch.num_rows()).sum::<usize>(),
Expand Down Expand Up @@ -1915,7 +1916,7 @@ mod tests {
let input_rows = batch.num_rows();

let split: Vec<_> =
split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes).collect();
split_batch_for_grpc_response(batch.clone(), max_flight_data_size_bytes);
let sizes: Vec<_> = split.iter().map(RecordBatch::num_rows).collect();
let output_rows: usize = sizes.iter().sum();

Expand Down
9 changes: 8 additions & 1 deletion arrow-ipc/src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ const DEFAULT_ZSTD_COMPRESSION_LEVEL: i32 = 3;
/// compression. Also holds a [`FlatBufferBuilder`] that is reused across IPC writes.
#[derive(Default)]
pub struct IpcWriteContext {
#[expect(dead_code)]
pub(crate) scratch: Vec<u8>,
pub(crate) reserve_scratch: bool,
fbb: FlatBufferBuilder<'static>,
#[cfg(feature = "zstd")]
compressor: Option<zstd::bulk::Compressor<'static>>,
Expand All @@ -44,6 +44,13 @@ impl IpcWriteContext {
&mut self.fbb
}

/// Set whether the scratch buffer capacity should be reserved after each encode for reuse
/// on the next call. Set to `false` for the final batch in a sequence to avoid a
/// pointless allocation.
pub fn set_reserve_scratch(&mut self, reserve: bool) {
self.reserve_scratch = reserve;
}

#[cfg(feature = "zstd")]
fn zstd_compressor(&mut self, level: i32) -> &mut zstd::bulk::Compressor<'static> {
self.compressor.get_or_insert_with(|| {
Expand Down
6 changes: 5 additions & 1 deletion arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -622,14 +622,18 @@ impl IpcDataGenerator {
) -> Result<(Vec<EncodedData>, EncodedData), ArrowError> {
let encoded_dictionaries =
self.encode_all_dicts(batch, dictionary_tracker, write_options, ipc_write_context)?;
let mut arrow_data = Vec::new();
let mut arrow_data = std::mem::take(&mut ipc_write_context.scratch);
let (ipc_message, _, tail_pad) = self.record_batch_to_bytes(
batch,
write_options,
ipc_write_context,
&mut IpcBodySink::Write(&mut arrow_data),
)?;
arrow_data.extend_from_slice(&PADDING[..tail_pad]);
let final_capcity = arrow_data.capacity();
if ipc_write_context.reserve_scratch {
ipc_write_context.scratch.reserve(final_capcity);
}
Ok((
encoded_dictionaries,
EncodedData {
Expand Down
Loading