From 239eb42615bc8228251e0c9a366c23f7a55cf475 Mon Sep 17 00:00:00 2001 From: jotabulacios Date: Fri, 24 Apr 2026 16:01:02 -0300 Subject: [PATCH] Parallelize in_place_bit_reverse_permute --- crypto/math/src/fft/cpu/bit_reversing.rs | 41 ++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/crypto/math/src/fft/cpu/bit_reversing.rs b/crypto/math/src/fft/cpu/bit_reversing.rs index f225dd5e0..fd6936ff7 100644 --- a/crypto/math/src/fft/cpu/bit_reversing.rs +++ b/crypto/math/src/fft/cpu/bit_reversing.rs @@ -1,7 +1,42 @@ /// In-place bit-reverse permutation algorithm. Requires input length to be a power of two. -pub fn in_place_bit_reverse_permute(input: &mut [E]) { - for i in 0..input.len() { - let bit_reversed_index = reverse_index(i, input.len() as u64); +pub fn in_place_bit_reverse_permute(input: &mut [E]) { + let n = input.len(); + #[cfg(feature = "parallel")] + { + // Pair-parallel swap: each pair (i, br(i)) with i < br(i) is independent of all + // other pairs (disjoint indices), so threads can swap concurrently provided they + // never touch the same memory location. `if br > i` selects exactly one owner + // per pair, so no two threads ever write the same slot. + const PARALLEL_BITREV_THRESHOLD: usize = 1 << 14; + if n >= PARALLEL_BITREV_THRESHOLD { + use rayon::prelude::*; + struct SendPtr(*mut E); + impl Copy for SendPtr {} + impl Clone for SendPtr { + fn clone(&self) -> Self { + *self + } + } + unsafe impl Send for SendPtr {} + unsafe impl Sync for SendPtr {} + let ptr = SendPtr(input.as_mut_ptr()); + (0..n).into_par_iter().for_each(|i| { + let br = reverse_index(i, n as u64); + if br > i { + // SAFETY: (i, br) uniquely identifies this pair (smaller index is owner), + // so no two threads race on the same `ptr.0.add(k)` slot. Both indices + // are in-bounds since i < n and br < n. + let p = ptr; + unsafe { + core::ptr::swap(p.0.add(i), p.0.add(br)); + } + } + }); + return; + } + } + for i in 0..n { + let bit_reversed_index = reverse_index(i, n as u64); if bit_reversed_index > i { input.swap(i, bit_reversed_index); }