diff --git a/rust/spark-rs/src/sort.rs b/rust/spark-rs/src/sort.rs index 52130c99..056f0e47 100644 --- a/rust/spark-rs/src/sort.rs +++ b/rust/spark-rs/src/sort.rs @@ -23,7 +23,11 @@ impl SortBuffers { } pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result { - let SortBuffers { readback, ordering, buckets } = buffers; + let SortBuffers { + readback, + ordering, + buckets, + } = buffers; let readback = &readback[..num_splats]; // Set the bucket counts to zero @@ -57,15 +61,17 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result, /// output indices pub ordering: Vec, - /// bucket counts / offsets (length == RADIX_BASE) - pub buckets16lo: Vec, - /// bucket counts / offsets (length == RADIX_BASE) - pub buckets16hi: Vec, - /// scratch space for indices - pub scratch: Vec, + pub scratch: Vec, // (key, index) + pub buckets: Vec, // 2 * 65536 } impl Sort32Buffers { @@ -93,84 +95,151 @@ impl Sort32Buffers { if self.scratch.len() < max_splats { self.scratch.resize(max_splats, 0); } - if self.buckets16lo.len() < RADIX_BASE { - self.buckets16lo.resize(RADIX_BASE, 0); - } - if self.buckets16hi.len() < RADIX_BASE { - self.buckets16hi.resize(RADIX_BASE, 0); + if self.buckets.len() < RADIX_BASE * 2 { + self.buckets.resize(RADIX_BASE * 2, 0); } } } -/// Two‑pass radix sort (base 2¹⁶) of 32‑bit float bit‑patterns, -/// descending order (largest keys first). Mirrors the JS `sort32Splats`. +#[inline(always)] +fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 { + let mut sum = 0u32; + for b in buckets.iter_mut() { + let tmp = *b; + *b = sum; + sum = sum.wrapping_add(tmp); + } + sum +} + pub fn sort32_internal( buffers: &mut Sort32Buffers, max_splats: usize, num_splats: usize, ) -> Result { - // make sure our buffers can hold `max_splats` buffers.ensure_size(max_splats); - let Sort32Buffers { readback, ordering, buckets16lo, buckets16hi, scratch } = buffers; + let Sort32Buffers { + readback, + ordering, + scratch, + buckets, + } = buffers; let keys = &readback[..num_splats]; - // tally low and high buckets - buckets16lo.fill(0); - buckets16hi.fill(0); - for &key in keys.iter() { - if key < DEPTH_INFINITY_F32 { - let inv = !key; - buckets16lo[(inv & 0xFFFF) as usize] += 1; - buckets16hi[(inv >> 16) as usize] += 1; - } + // Split buckets + let (b0, b1) = buckets.split_at_mut(RADIX_BASE); + let b1 = &mut b1[..RADIX_BASE]; + + b0.fill(0); + b1.fill(0); + + // pass 1: Histogram (branchless) + macro_rules! tick { + ($k:expr) => {{ + let valid = ($k < DEPTH_INFINITY_F32) as u32; + let inv = !$k; + + let r0 = inv & RADIX_MASK; + let r1 = inv >> RADIX_BITS; + + b0[r0 as usize] += valid; + unsafe { *b1.get_unchecked_mut(r1 as usize) += valid }; + }}; } - // ——— Pass #1: bucket by inv(low 16 bits) ——— - // exclusive prefix‑sum → starting offsets - let mut total: u32 = 0; - for slot in buckets16lo.iter_mut() { - let cnt = *slot; - *slot = total; - total = total.wrapping_add(cnt); + let mut chunks = keys.chunks_exact(8); + + for chunk in chunks.by_ref() { + tick!(chunk[0]); + tick!(chunk[1]); + tick!(chunk[2]); + tick!(chunk[3]); + tick!(chunk[4]); + tick!(chunk[5]); + tick!(chunk[6]); + tick!(chunk[7]); } - let active_splats = total; - // scatter into scratch by low bits of inv - for (i, &key) in keys.iter().enumerate() { - if key < DEPTH_INFINITY_F32 { - let inv = !key; - let lo = (inv & 0xFFFF) as usize; - scratch[buckets16lo[lo] as usize] = i as u32; - buckets16lo[lo] += 1; - } + for &k in chunks.remainder() { + tick!(k); } - // ——— Pass #2: bucket by inv(high 16 bits) ——— - // exclusive prefix‑sum again - let mut sum: u32 = 0; - for slot in buckets16hi.iter_mut() { - let cnt = *slot; - *slot = sum; - sum = sum.wrapping_add(cnt); + let active = prefix_sum_exclusive(b0) as usize; + prefix_sum_exclusive(b1); + + // pass 1: scatter into scratch + macro_rules! place { + ($k:expr, $idx:expr) => {{ + let valid = ($k < DEPTH_INFINITY_F32) as u32; + let inv = !$k; + + let r0 = (inv & RADIX_MASK) as usize; + let pos = unsafe { *b0.get_unchecked(r0) } as usize; + + // Always write (branchless), but only advance if valid + unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | ($idx as u64) }; + unsafe { *b0.get_unchecked_mut(r0) += valid }; + }}; } - // scatter into final ordering by high bits of inv - for &idx in scratch.iter().take(active_splats as usize) { - let key = keys[idx as usize]; - let inv = !key; - let hi = (inv >> 16) as usize; - ordering[buckets16hi[hi] as usize] = idx; - buckets16hi[hi] += 1; + + let mut chunks = keys.chunks_exact(8); + let mut i = 0; + + for chunk in chunks.by_ref() { + place!(chunk[0], i); + place!(chunk[1], i + 1); + place!(chunk[2], i + 2); + place!(chunk[3], i + 3); + place!(chunk[4], i + 4); + place!(chunk[5], i + 5); + place!(chunk[6], i + 6); + place!(chunk[7], i + 7); + + i += 8; + } + + for &k in chunks.remainder() { + place!(k, i); + i += 1; + } + + // pass 2: scatter into final ordering + macro_rules! place2 { + ($kv:expr) => {{ + let r1 = (($kv >> 48) & RADIX_MASK as u64) as usize; + let pos = unsafe { *b1.get_unchecked(r1) } as usize; + + unsafe { *ordering.get_unchecked_mut(pos) = $kv as u32 }; + unsafe { *b1.get_unchecked_mut(r1) += 1 }; + }}; + } + + let mut chunks = scratch[..active].chunks_exact(8); + + for chunk in chunks.by_ref() { + place2!(chunk[0]); + place2!(chunk[1]); + place2!(chunk[2]); + place2!(chunk[3]); + place2!(chunk[4]); + place2!(chunk[5]); + place2!(chunk[6]); + place2!(chunk[7]); + } + + for &kv in chunks.remainder() { + place2!(kv); } // sanity‑check: last bucket should have consumed all entries - if buckets16hi[RADIX_BASE - 1] != active_splats { + if b1[RADIX_BASE - 1] != active as u32 { return Err(format!( "Expected {} active splats but got {}", - active_splats, - buckets16hi[RADIX_BASE - 1] + active, + b1[RADIX_BASE - 1] )); } - Ok(active_splats) -} \ No newline at end of file + Ok(active as u32) +}