From 34fca0519690c3205f39691a35b1b53db6114fd1 Mon Sep 17 00:00:00 2001 From: 39ali Date: Wed, 29 Apr 2026 11:48:32 +0300 Subject: [PATCH 1/3] make sort32 faster --- rust/spark-rs/src/sort.rs | 207 +++++++++++++++++++++++++++----------- 1 file changed, 146 insertions(+), 61 deletions(-) diff --git a/rust/spark-rs/src/sort.rs b/rust/spark-rs/src/sort.rs index 52130c99..d839dff3 100644 --- a/rust/spark-rs/src/sort.rs +++ b/rust/spark-rs/src/sort.rs @@ -23,7 +23,11 @@ impl SortBuffers { } pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result { - let SortBuffers { readback, ordering, buckets } = buffers; + let SortBuffers { + readback, + ordering, + buckets, + } = buffers; let readback = &readback[..num_splats]; // Set the bucket counts to zero @@ -57,15 +61,17 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result, /// output indices pub ordering: Vec, - /// bucket counts / offsets (length == RADIX_BASE) - pub buckets16lo: Vec, - /// bucket counts / offsets (length == RADIX_BASE) - pub buckets16hi: Vec, - /// scratch space for indices - pub scratch: Vec, + pub scratch: Vec, // (key, index) + pub buckets: Vec, // 2 * 65536 } impl Sort32Buffers { @@ -93,84 +95,167 @@ impl Sort32Buffers { if self.scratch.len() < max_splats { self.scratch.resize(max_splats, 0); } - if self.buckets16lo.len() < RADIX_BASE { - self.buckets16lo.resize(RADIX_BASE, 0); - } - if self.buckets16hi.len() < RADIX_BASE { - self.buckets16hi.resize(RADIX_BASE, 0); + if self.buckets.len() < RADIX_BASE * 2 { + self.buckets.resize(RADIX_BASE * 2, 0); } } } -/// Two‑pass radix sort (base 2¹⁶) of 32‑bit float bit‑patterns, -/// descending order (largest keys first). Mirrors the JS `sort32Splats`. +#[inline(always)] +unsafe fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 { + let mut sum = 0u32; + for b in buckets.iter_mut() { + let tmp = *b; + *b = sum; + sum = sum.wrapping_add(tmp); + } + sum +} + pub fn sort32_internal( buffers: &mut Sort32Buffers, max_splats: usize, num_splats: usize, ) -> Result { - // make sure our buffers can hold `max_splats` buffers.ensure_size(max_splats); - let Sort32Buffers { readback, ordering, buckets16lo, buckets16hi, scratch } = buffers; + let Sort32Buffers { + readback, + ordering, + scratch, + buckets, + } = buffers; let keys = &readback[..num_splats]; - // tally low and high buckets - buckets16lo.fill(0); - buckets16hi.fill(0); - for &key in keys.iter() { - if key < DEPTH_INFINITY_F32 { - let inv = !key; - buckets16lo[(inv & 0xFFFF) as usize] += 1; - buckets16hi[(inv >> 16) as usize] += 1; + // Split buckets + let (b0, b1) = buckets.split_at_mut(RADIX_BASE); + + b0.fill(0); + b1.fill(0); + + // pass 1: Histogram (branchless) + let mut chunks = keys.chunks_exact(8); + + for chunk in chunks.by_ref() { + macro_rules! tick { + ($k:expr) => {{ + let valid = ($k < DEPTH_INFINITY_F32) as u32; + let inv = !$k; + + let r0 = inv & RADIX_MASK; + let r1 = inv >> RADIX_BITS; + + b0[r0 as usize] += valid; + b1[r1 as usize] += valid; + }}; } + + tick!(chunk[0]); + tick!(chunk[1]); + tick!(chunk[2]); + tick!(chunk[3]); + tick!(chunk[4]); + tick!(chunk[5]); + tick!(chunk[6]); + tick!(chunk[7]); + } + + for &k in chunks.remainder() { + let valid = (k < DEPTH_INFINITY_F32) as u32; + let inv = !k; + b0[(inv & RADIX_MASK) as usize] += valid; + b1[(inv >> RADIX_BITS) as usize] += valid; } - // ——— Pass #1: bucket by inv(low 16 bits) ——— // exclusive prefix‑sum → starting offsets - let mut total: u32 = 0; - for slot in buckets16lo.iter_mut() { - let cnt = *slot; - *slot = total; - total = total.wrapping_add(cnt); + let active = unsafe { prefix_sum_exclusive(b0) } as usize; + unsafe { + prefix_sum_exclusive(b1); } - let active_splats = total; - - // scatter into scratch by low bits of inv - for (i, &key) in keys.iter().enumerate() { - if key < DEPTH_INFINITY_F32 { - let inv = !key; - let lo = (inv & 0xFFFF) as usize; - scratch[buckets16lo[lo] as usize] = i as u32; - buckets16lo[lo] += 1; + + // pass 1: scatter into scratch + let mut chunks = keys.chunks_exact(8); + let mut i = 0; + + for chunk in chunks.by_ref() { + macro_rules! place { + ($k:expr, $idx:expr) => {{ + let valid = ($k < DEPTH_INFINITY_F32) as u32; + let inv = !$k; + + let r0 = (inv & RADIX_MASK) as usize; + let pos = b0[r0] as usize; + + // Always write (branchless), but only advance if valid + scratch[pos] = ((inv as u64) << 32) | ($idx as u64); + b0[r0] += valid; + }}; } + + place!(chunk[0], i); + place!(chunk[1], i + 1); + place!(chunk[2], i + 2); + place!(chunk[3], i + 3); + place!(chunk[4], i + 4); + place!(chunk[5], i + 5); + place!(chunk[6], i + 6); + place!(chunk[7], i + 7); + + i += 8; } - // ——— Pass #2: bucket by inv(high 16 bits) ——— - // exclusive prefix‑sum again - let mut sum: u32 = 0; - for slot in buckets16hi.iter_mut() { - let cnt = *slot; - *slot = sum; - sum = sum.wrapping_add(cnt); + for &k in chunks.remainder() { + let valid = (k < DEPTH_INFINITY_F32) as u32; + let inv = !k; + + let r0 = (inv & RADIX_MASK) as usize; + let pos = b0[r0] as usize; + + scratch[pos] = ((inv as u64) << 32) | (i as u64); + b0[r0] += valid; + + i += 1; } - // scatter into final ordering by high bits of inv - for &idx in scratch.iter().take(active_splats as usize) { - let key = keys[idx as usize]; - let inv = !key; - let hi = (inv >> 16) as usize; - ordering[buckets16hi[hi] as usize] = idx; - buckets16hi[hi] += 1; + + // pass 2: scatter into final ordering + let mut chunks = scratch[..active].chunks_exact(8); + + for chunk in chunks.by_ref() { + macro_rules! place2 { + ($kv:expr) => {{ + let r1 = (($kv >> 48) & RADIX_MASK as u64) as usize; + let pos = b1[r1] as usize; + + ordering[pos] = $kv as u32; + b1[r1] += 1; + }}; + } + + place2!(chunk[0]); + place2!(chunk[1]); + place2!(chunk[2]); + place2!(chunk[3]); + place2!(chunk[4]); + place2!(chunk[5]); + place2!(chunk[6]); + place2!(chunk[7]); + } + + for &kv in chunks.remainder() { + let r1 = ((kv >> 48) & RADIX_MASK as u64) as usize; + let pos = b1[r1] as usize; + + ordering[pos] = kv as u32; + b1[r1] += 1; } // sanity‑check: last bucket should have consumed all entries - if buckets16hi[RADIX_BASE - 1] != active_splats { + if b1[RADIX_BASE - 1] != active as u32 { return Err(format!( "Expected {} active splats but got {}", - active_splats, - buckets16hi[RADIX_BASE - 1] + active, b1[RADIX_BASE - 1] )); } - Ok(active_splats) -} \ No newline at end of file + Ok(active as u32) +} From d76eba7d25d1580dc203209bb0a693c63436a621 Mon Sep 17 00:00:00 2001 From: 39ali Date: Wed, 29 Apr 2026 20:17:12 +0300 Subject: [PATCH 2/3] remove more branches in wasm hot loops --- rust/spark-rs/src/sort.rs | 43 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/rust/spark-rs/src/sort.rs b/rust/spark-rs/src/sort.rs index d839dff3..19e26338 100644 --- a/rust/spark-rs/src/sort.rs +++ b/rust/spark-rs/src/sort.rs @@ -102,7 +102,7 @@ impl Sort32Buffers { } #[inline(always)] -unsafe fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 { +fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 { let mut sum = 0u32; for b in buckets.iter_mut() { let tmp = *b; @@ -129,6 +129,7 @@ pub fn sort32_internal( // Split buckets let (b0, b1) = buckets.split_at_mut(RADIX_BASE); + let b1 = &mut b1[..RADIX_BASE]; b0.fill(0); b1.fill(0); @@ -146,7 +147,7 @@ pub fn sort32_internal( let r1 = inv >> RADIX_BITS; b0[r0 as usize] += valid; - b1[r1 as usize] += valid; + unsafe { *b1.get_unchecked_mut(r1 as usize) += valid }; }}; } @@ -164,16 +165,13 @@ pub fn sort32_internal( let valid = (k < DEPTH_INFINITY_F32) as u32; let inv = !k; b0[(inv & RADIX_MASK) as usize] += valid; - b1[(inv >> RADIX_BITS) as usize] += valid; + unsafe { *b1.get_unchecked_mut((inv >> RADIX_BITS) as usize) += valid }; } - // exclusive prefix‑sum → starting offsets - let active = unsafe { prefix_sum_exclusive(b0) } as usize; - unsafe { - prefix_sum_exclusive(b1); - } + let active = prefix_sum_exclusive(b0) as usize; + prefix_sum_exclusive(b1); - // pass 1: scatter into scratch + // pass 1: scatter into scratch let mut chunks = keys.chunks_exact(8); let mut i = 0; @@ -184,11 +182,11 @@ pub fn sort32_internal( let inv = !$k; let r0 = (inv & RADIX_MASK) as usize; - let pos = b0[r0] as usize; + let pos = unsafe { *b0.get_unchecked(r0) } as usize; // Always write (branchless), but only advance if valid - scratch[pos] = ((inv as u64) << 32) | ($idx as u64); - b0[r0] += valid; + unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | ($idx as u64) }; + unsafe { *b0.get_unchecked_mut(r0) += valid }; }}; } @@ -209,10 +207,10 @@ pub fn sort32_internal( let inv = !k; let r0 = (inv & RADIX_MASK) as usize; - let pos = b0[r0] as usize; + let pos = unsafe { *b0.get_unchecked(r0) } as usize; - scratch[pos] = ((inv as u64) << 32) | (i as u64); - b0[r0] += valid; + unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | (i as u64) }; + unsafe { *b0.get_unchecked_mut(r0) += valid }; i += 1; } @@ -224,10 +222,10 @@ pub fn sort32_internal( macro_rules! place2 { ($kv:expr) => {{ let r1 = (($kv >> 48) & RADIX_MASK as u64) as usize; - let pos = b1[r1] as usize; + let pos = unsafe { *b1.get_unchecked(r1) } as usize; - ordering[pos] = $kv as u32; - b1[r1] += 1; + unsafe { *ordering.get_unchecked_mut(pos) = $kv as u32 }; + unsafe { *b1.get_unchecked_mut(r1) += 1 }; }}; } @@ -243,17 +241,18 @@ pub fn sort32_internal( for &kv in chunks.remainder() { let r1 = ((kv >> 48) & RADIX_MASK as u64) as usize; - let pos = b1[r1] as usize; + let pos = unsafe { *b1.get_unchecked(r1) } as usize; - ordering[pos] = kv as u32; - b1[r1] += 1; + unsafe { *ordering.get_unchecked_mut(pos) = kv as u32 }; + unsafe { *b1.get_unchecked_mut(r1) += 1 }; } // sanity‑check: last bucket should have consumed all entries if b1[RADIX_BASE - 1] != active as u32 { return Err(format!( "Expected {} active splats but got {}", - active, b1[RADIX_BASE - 1] + active, + b1[RADIX_BASE - 1] )); } From a72213bad06c452e704be631c96bc07dda83d625 Mon Sep 17 00:00:00 2001 From: 39ali Date: Tue, 2 Jun 2026 08:25:28 +0300 Subject: [PATCH 3/3] cleanup --- rust/spark-rs/src/sort.rs | 93 ++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 54 deletions(-) diff --git a/rust/spark-rs/src/sort.rs b/rust/spark-rs/src/sort.rs index 19e26338..056f0e47 100644 --- a/rust/spark-rs/src/sort.rs +++ b/rust/spark-rs/src/sort.rs @@ -135,22 +135,22 @@ pub fn sort32_internal( b1.fill(0); // pass 1: Histogram (branchless) - let mut chunks = keys.chunks_exact(8); + macro_rules! tick { + ($k:expr) => {{ + let valid = ($k < DEPTH_INFINITY_F32) as u32; + let inv = !$k; - for chunk in chunks.by_ref() { - macro_rules! tick { - ($k:expr) => {{ - let valid = ($k < DEPTH_INFINITY_F32) as u32; - let inv = !$k; + let r0 = inv & RADIX_MASK; + let r1 = inv >> RADIX_BITS; - let r0 = inv & RADIX_MASK; - let r1 = inv >> RADIX_BITS; + b0[r0 as usize] += valid; + unsafe { *b1.get_unchecked_mut(r1 as usize) += valid }; + }}; + } - b0[r0 as usize] += valid; - unsafe { *b1.get_unchecked_mut(r1 as usize) += valid }; - }}; - } + let mut chunks = keys.chunks_exact(8); + for chunk in chunks.by_ref() { tick!(chunk[0]); tick!(chunk[1]); tick!(chunk[2]); @@ -162,34 +162,31 @@ pub fn sort32_internal( } for &k in chunks.remainder() { - let valid = (k < DEPTH_INFINITY_F32) as u32; - let inv = !k; - b0[(inv & RADIX_MASK) as usize] += valid; - unsafe { *b1.get_unchecked_mut((inv >> RADIX_BITS) as usize) += valid }; + tick!(k); } let active = prefix_sum_exclusive(b0) as usize; prefix_sum_exclusive(b1); // pass 1: scatter into scratch + macro_rules! place { + ($k:expr, $idx:expr) => {{ + let valid = ($k < DEPTH_INFINITY_F32) as u32; + let inv = !$k; + + let r0 = (inv & RADIX_MASK) as usize; + let pos = unsafe { *b0.get_unchecked(r0) } as usize; + + // Always write (branchless), but only advance if valid + unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | ($idx as u64) }; + unsafe { *b0.get_unchecked_mut(r0) += valid }; + }}; + } + let mut chunks = keys.chunks_exact(8); let mut i = 0; for chunk in chunks.by_ref() { - macro_rules! place { - ($k:expr, $idx:expr) => {{ - let valid = ($k < DEPTH_INFINITY_F32) as u32; - let inv = !$k; - - let r0 = (inv & RADIX_MASK) as usize; - let pos = unsafe { *b0.get_unchecked(r0) } as usize; - - // Always write (branchless), but only advance if valid - unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | ($idx as u64) }; - unsafe { *b0.get_unchecked_mut(r0) += valid }; - }}; - } - place!(chunk[0], i); place!(chunk[1], i + 1); place!(chunk[2], i + 2); @@ -203,32 +200,24 @@ pub fn sort32_internal( } for &k in chunks.remainder() { - let valid = (k < DEPTH_INFINITY_F32) as u32; - let inv = !k; - - let r0 = (inv & RADIX_MASK) as usize; - let pos = unsafe { *b0.get_unchecked(r0) } as usize; - - unsafe { *scratch.get_unchecked_mut(pos) = ((inv as u64) << 32) | (i as u64) }; - unsafe { *b0.get_unchecked_mut(r0) += valid }; - + place!(k, i); i += 1; } // pass 2: scatter into final ordering + macro_rules! place2 { + ($kv:expr) => {{ + let r1 = (($kv >> 48) & RADIX_MASK as u64) as usize; + let pos = unsafe { *b1.get_unchecked(r1) } as usize; + + unsafe { *ordering.get_unchecked_mut(pos) = $kv as u32 }; + unsafe { *b1.get_unchecked_mut(r1) += 1 }; + }}; + } + let mut chunks = scratch[..active].chunks_exact(8); for chunk in chunks.by_ref() { - macro_rules! place2 { - ($kv:expr) => {{ - let r1 = (($kv >> 48) & RADIX_MASK as u64) as usize; - let pos = unsafe { *b1.get_unchecked(r1) } as usize; - - unsafe { *ordering.get_unchecked_mut(pos) = $kv as u32 }; - unsafe { *b1.get_unchecked_mut(r1) += 1 }; - }}; - } - place2!(chunk[0]); place2!(chunk[1]); place2!(chunk[2]); @@ -240,11 +229,7 @@ pub fn sort32_internal( } for &kv in chunks.remainder() { - let r1 = ((kv >> 48) & RADIX_MASK as u64) as usize; - let pos = unsafe { *b1.get_unchecked(r1) } as usize; - - unsafe { *ordering.get_unchecked_mut(pos) = kv as u32 }; - unsafe { *b1.get_unchecked_mut(r1) += 1 }; + place2!(kv); } // sanity‑check: last bucket should have consumed all entries