From 6d3840b1814fd78cf50bbe16b79781a069c926f7 Mon Sep 17 00:00:00 2001 From: GordonYuanyc Date: Fri, 12 Jun 2026 19:18:04 -0400 Subject: [PATCH 1/2] feat(vector3d, count-hll): fill in Vector3D + per-bucket-HLL CountHll sketch Flesh out the Vector3D stub to full Vector2D parity: a rows x cols grid where each (row, col) cell is a contiguous `depth`-length bucket. Adds init/from_fn/fill, element + bucket accessors, fast_insert, the fast_query_min/median/max[_with_key] + aggregate family over bucket slices, Nitro hooks, custom serde (recomputes col mask), and Index/IndexMut. Add CountHll: a Count Sketch grid whose cells are per-bucket HyperLogLog sketches (Vector3D, depth = 2^precision). Each item is routed to one column per row and recorded in that bucket's HLL. Supports estimate() (per-key distinct count, median across rows) and estimate_total_cardinality() (per-row column partition sum), plus register-max merge and msgpack serde. The HLL register/rank math and classic estimator mirror sketches::hll. Purely additive: Vector3D had no real callers, and only a new module + re-export were added. No existing API changed. fmt clean; new code is clippy-clean; all 489 lib + integration tests pass. Co-Authored-By: Claude Opus 4.8 --- src/common/structures/vector3d.rs | 627 +++++++++++++++++++++++++++++- src/sketches/countsketch_hll.rs | 371 ++++++++++++++++++ src/sketches/mod.rs | 4 + 3 files changed, 991 insertions(+), 11 deletions(-) create mode 100644 src/sketches/countsketch_hll.rs diff --git a/src/common/structures/vector3d.rs b/src/common/structures/vector3d.rs index 46ceb99..1743cb9 100644 --- a/src/common/structures/vector3d.rs +++ b/src/common/structures/vector3d.rs @@ -1,22 +1,627 @@ use serde::{Deserialize, Serialize}; +use std::ops::{Index, IndexMut}; -/// Shared thin wrapper over `Vec` tailored for sketches. -#[derive(Clone, Debug, Serialize, Deserialize)] +use crate::{MatrixFastHash, MatrixHashType, Nitro, compute_median_inline_f64}; + +/// Shared thin wrapper over `Vec` tailored for layered / per-bucket sketches. +/// +/// `Vector3D` is the three-dimensional sibling of [`crate::Vector2D`]. It models +/// a `rows * cols` grid where **every `(row, col)` cell is itself a contiguous +/// run of `depth` elements** — a "bucket". This is the natural storage for +/// sketches that keep a small vector (e.g. a HyperLogLog register array) at each +/// matrix position. +/// +/// Storage is a single flat `Vec` in row-major / bucket-major order; the +/// element at `(row, col, d)` lives at `(row * cols + col) * depth + d`. +/// +/// The row/column addressing (mask bits, `col_for_row`, hashing) mirrors +/// [`crate::Vector2D`] exactly, so the same `MatrixFastHash` machinery selects a +/// column per row; the third dimension is addressed within the selected bucket. +#[derive(Clone, Debug, Serialize)] pub struct Vector3D { data: Vec, - layer: usize, - row: usize, - col: usize, + rows: usize, + cols: usize, + depth: usize, + mask_bits: u32, + mask: u128, + nitro: Nitro, +} + +// Helper type for deserialization: we only read stored fields and recompute +// derived ones (mask_bits, mask) from cols, mirroring `Vector2D`. +#[derive(Deserialize)] +struct Vector3DDeserialize { + data: Vec, + rows: usize, + cols: usize, + depth: usize, + #[serde(default)] + nitro: Nitro, +} + +#[inline] +fn mask_bits_for_cols(cols: usize) -> u32 { + if cols.is_power_of_two() { + cols.ilog2() + } else { + cols.ilog2() + 1 + } +} + +impl<'de, T> Deserialize<'de> for Vector3D +where + T: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let input = Vector3DDeserialize::deserialize(deserializer)?; + let mask_bits = mask_bits_for_cols(input.cols); + let mask = (1u128 << mask_bits) - 1; + Ok(Self { + data: input.data, + rows: input.rows, + cols: input.cols, + depth: input.depth, + mask_bits, + mask, + nitro: input.nitro, + }) + } } impl Vector3D { - /// Creates an empty 3D container with reserved capacity. - pub fn init(layer: usize, row: usize, col: usize) -> Self { + /// Creates an empty container with reserved capacity for `rows * cols * depth` + /// elements. The underlying storage is left uninitialized until `fill` or + /// similar methods are called, allowing callers to decide when and how cells + /// are populated. + pub fn init(rows: usize, cols: usize, depth: usize) -> Self { + let mask_bits = mask_bits_for_cols(cols); + let mask = (1u128 << mask_bits) - 1; + Self { + data: Vec::with_capacity(rows * cols * depth), + rows, + cols, + depth, + mask_bits, + mask, + nitro: Nitro::default(), + } + } + + /// Builds a container by invoking a generator for every `(row, col, d)` + /// position. Useful for types that require per-cell construction logic + /// instead of cloning a single value across all cells. + pub fn from_fn(rows: usize, cols: usize, depth: usize, mut f: F) -> Self + where + F: FnMut(usize, usize, usize) -> T, + { + let mask_bits = mask_bits_for_cols(cols); + let mask = (1u128 << mask_bits) - 1; + let mut data = Vec::with_capacity(rows * cols * depth); + for r in 0..rows { + for c in 0..cols { + for d in 0..depth { + data.push(f(r, c, d)); + } + } + } Self { - data: Vec::with_capacity(layer * row * col), - layer, - row, - col, + data, + rows, + cols, + depth, + mask_bits, + mask, + nitro: Nitro::default(), + } + } + + /// Enables Nitro sampling with the provided rate. + pub fn enable_nitro(&mut self, sampling_rate: f64) { + self.nitro = Nitro::init_nitro(sampling_rate); + } + + /// Disables Nitro sampling and resets the internal state. + pub fn disable_nitro(&mut self) { + self.nitro = Nitro::default(); + } + + #[inline(always)] + /// Decrements the Nitro skip counter by one. + pub fn reduce_to_skip(&mut self) { + self.nitro.reduce_to_skip(); + } + + /// Returns the Nitro configuration. + #[inline(always)] + pub fn nitro(&self) -> &Nitro { + &self.nitro + } + + #[inline(always)] + /// Returns the current Nitro delta weight. + pub fn get_delta(&self) -> u64 { + self.nitro.delta + } + + /// Returns a mutable Nitro configuration reference. + #[inline(always)] + pub fn nitro_mut(&mut self) -> &mut Nitro { + &mut self.nitro + } + + /// Replaces the entire container with `rows * cols * depth` clones of `value`, + /// reusing the existing allocation. This is the most efficient way to reset + /// cells to a baseline without reallocating. + pub fn fill(&mut self, value: T) + where + T: Clone, + { + self.data.clear(); + self.data.resize(self.rows * self.cols * self.depth, value); + } + + #[inline(always)] + fn col_for_row(&self, hashed_val: &Hash, row: usize) -> usize { + hashed_val.col_for_row(row, self.cols) + } + + #[inline(always)] + fn bucket_start(&self, row: usize, col: usize) -> usize { + (row * self.cols + col) * self.depth + } + + /// Returns the number of rows. + #[inline(always)] + pub fn rows(&self) -> usize { + self.rows + } + + /// Returns the number of columns. + #[inline(always)] + pub fn cols(&self) -> usize { + self.cols + } + + /// Returns the per-bucket depth (length of each `(row, col)` cell). + #[inline(always)] + pub fn depth(&self) -> usize { + self.depth + } + + /// Allocates one extra row initialized with `value`. + pub fn allocate_extra_row(&mut self, value: T) + where + T: Clone, + { + self.rows += 1; + self.data.resize(self.rows * self.cols * self.depth, value); + } + + /// Returns the total number of elements. + #[inline(always)] + pub fn len(&self) -> usize { + self.data.len() + } + + #[inline(always)] + /// Returns `true` when the container stores no elements. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Provides immutable access to the flattened storage. + #[inline(always)] + pub fn as_slice(&self) -> &[T] { + &self.data + } + + /// Provides mutable access to the flattened storage. + #[inline(always)] + pub fn as_mut_slice(&mut self) -> &mut [T] { + &mut self.data + } + + /// Returns a reference to a single element when it exists. + #[inline(always)] + pub fn get(&self, row: usize, col: usize, d: usize) -> Option<&T> { + if row < self.rows && col < self.cols && d < self.depth { + Some(&self.data[self.bucket_start(row, col) + d]) + } else { + None + } + } + + /// Returns a mutable reference to a single element when it exists. + #[inline(always)] + pub fn get_mut(&mut self, row: usize, col: usize, d: usize) -> Option<&mut T> { + if row < self.rows && col < self.cols && d < self.depth { + let idx = self.bucket_start(row, col) + d; + Some(&mut self.data[idx]) + } else { + None + } + } + + /// Returns the `(row, col)` bucket slice when it exists. + #[inline(always)] + pub fn bucket(&self, row: usize, col: usize) -> Option<&[T]> { + if row < self.rows && col < self.cols { + let start = self.bucket_start(row, col); + Some(&self.data[start..start + self.depth]) + } else { + None + } + } + + /// Returns the `(row, col)` bucket slice mutably when it exists. + #[inline(always)] + pub fn bucket_mut(&mut self, row: usize, col: usize) -> Option<&mut [T]> { + if row < self.rows && col < self.cols { + let start = self.bucket_start(row, col); + Some(&mut self.data[start..start + self.depth]) + } else { + None + } + } + + /// Returns the `(row, col)` bucket slice, debug-asserting bounds. + /// Faster sibling of [`Self::bucket`]. + #[inline(always)] + pub fn bucket_slice(&self, row: usize, col: usize) -> &[T] { + debug_assert!(row < self.rows && col < self.cols, "bucket out of bounds"); + let start = self.bucket_start(row, col); + &self.data[start..start + self.depth] + } + + /// Mutable sibling of [`Self::bucket_slice`]. + #[inline(always)] + pub fn bucket_slice_mut(&mut self, row: usize, col: usize) -> &mut [T] { + debug_assert!(row < self.rows && col < self.cols, "bucket out of bounds"); + let start = self.bucket_start(row, col); + &mut self.data[start..start + self.depth] + } + + /// Applies an update to a single element via the supplied operator. + #[inline(always)] + pub fn update_one_counter(&mut self, row: usize, col: usize, d: usize, op: F, value: V) + where + F: Fn(&mut T, V), + { + let idx = self.bucket_start(row, col) + d; + op(&mut self.data[idx], value); + } + + /// get the number of bits required to cover the col size + #[inline(always)] + /// Returns the bit width needed to represent a column index. + pub fn get_mask_bits(&self) -> u32 { + mask_bits_for_cols(self.cols) + } + + /// get the number of bits required for hashed value + /// only three case possible: 32, 64, 128 + #[inline] + /// Returns the packed hash width needed for all rows. + pub fn get_required_bits(&self) -> usize { + let mut bits_required = self.get_mask_bits() as usize; + bits_required *= self.rows; + bits_required = 32 << ((bits_required > 32) as u32 + (bits_required > 64) as u32); + bits_required = bits_required.min(128); + bits_required + } + + /// Inserts along every row using a hashed column selection. + /// + /// For each row a column is selected from `hashed_val`, yielding one + /// `(row, col)` bucket; the closure receives that **bucket slice**, the + /// value, and the current row index. This is the three-dimensional analogue + /// of [`crate::Vector2D::fast_insert`], where the per-row target is a whole + /// bucket rather than a single counter. + #[inline(always)] + pub fn fast_insert(&mut self, op: F, value: V, hashed_val: &Hash) + where + Hash: MatrixFastHash, + F: Fn(&mut [T], &V, usize), + V: Clone, + { + for row in 0..self.rows { + let col = self.col_for_row(hashed_val, row); + let start = self.bucket_start(row, col); + let end = start + self.depth; + op(&mut self.data[start..end], &value, row); + } + } + + #[inline(always)] + /// Decrements the Nitro skip counter by `c`. + pub fn reduce_nitro_skip(&mut self, c: usize) { + self.nitro.reduce_to_skip_by_count(c) + } + + #[inline(always)] + /// Sets the Nitro skip counter to `c`. + pub fn update_nitro_skip(&mut self, c: usize) { + self.nitro.to_skip = c + } + + #[inline(always)] + /// Returns the current Nitro skip counter. + pub fn get_nitro_skip(&mut self) -> usize { + self.nitro.to_skip + } + + /// Reads a single element by `(row, col, d)`. + #[inline(always)] + pub fn query_one_counter(&self, row: usize, col: usize, d: usize) -> T + where + T: Clone, + { + self.data[self.bucket_start(row, col) + d].clone() + } + + /// Queries all rows using precomputed hashed values to find the minimum. + /// + /// The closure receives: bucket slice, row index, and hash value. + #[inline(always)] + pub fn fast_query_min(&self, hashed_val: &Hash, op: F) -> R + where + Hash: MatrixFastHash, + F: Fn(&[T], usize, &Hash) -> R, + R: PartialOrd, + { + let c0 = self.col_for_row(hashed_val, 0); + let mut min = op(self.bucket_slice(0, c0), 0, hashed_val); + for row in 1..self.rows { + let col = self.col_for_row(hashed_val, row); + let candidate = op(self.bucket_slice(row, col), row, hashed_val); + if candidate < min { + min = candidate; + } + } + min + } + + /// Queries all rows using precomputed hashed values to find the median. + /// + /// The closure receives: bucket slice, row index, and hash value, and + /// returns `f64` values which are collected and reduced to a median. + #[inline(always)] + pub fn fast_query_median(&self, hashed_val: &Hash, op: F) -> f64 + where + Hash: MatrixFastHash, + F: Fn(&[T], usize, &Hash) -> f64, + { + let mut estimates = Vec::with_capacity(self.rows); + for row in 0..self.rows { + let col = self.col_for_row(hashed_val, row); + estimates.push(op(self.bucket_slice(row, col), row, hashed_val)); + } + compute_median_inline_f64(&mut estimates) + } + + /// Queries all rows using precomputed hashed values to find the maximum. + /// + /// The closure receives: bucket slice, row index, and hash value. + #[inline(always)] + pub fn fast_query_max(&self, hashed_val: &MatrixHashType, op: F) -> R + where + F: Fn(&[T], usize, &MatrixHashType) -> R, + R: PartialOrd, + { + let c0 = self.col_for_row(hashed_val, 0); + let mut max = op(self.bucket_slice(0, c0), 0, hashed_val); + for row in 1..self.rows { + let col = self.col_for_row(hashed_val, row); + let candidate = op(self.bucket_slice(row, col), row, hashed_val); + if candidate > max { + max = candidate; + } + } + max + } + + /// Queries all rows to find the minimum with a query key. + /// + /// The closure receives: bucket slice, query key, row index, and hash value. + #[inline(always)] + pub fn fast_query_min_with_key( + &self, + hashed_val: &MatrixHashType, + query_key: &Q, + op: F, + ) -> R + where + F: Fn(&[T], &Q, usize, &MatrixHashType) -> R, + R: PartialOrd, + { + let c0 = self.col_for_row(hashed_val, 0); + let mut min = op(self.bucket_slice(0, c0), query_key, 0, hashed_val); + for row in 1..self.rows { + let col = self.col_for_row(hashed_val, row); + let candidate = op(self.bucket_slice(row, col), query_key, row, hashed_val); + if candidate < min { + min = candidate; + } + } + min + } + + /// Queries all rows to find the maximum with a query key. + /// + /// The closure receives: bucket slice, query key, row index, and hash value. + #[inline(always)] + pub fn fast_query_max_with_key( + &self, + hashed_val: &MatrixHashType, + query_key: &Q, + op: F, + ) -> R + where + F: Fn(&[T], &Q, usize, &MatrixHashType) -> R, + R: PartialOrd, + { + let c0 = self.col_for_row(hashed_val, 0); + let mut max = op(self.bucket_slice(0, c0), query_key, 0, hashed_val); + for row in 1..self.rows { + let col = self.col_for_row(hashed_val, row); + let candidate = op(self.bucket_slice(row, col), query_key, row, hashed_val); + if candidate > max { + max = candidate; + } + } + max + } + + /// Queries all rows to find the median with a query key. + /// + /// The closure receives: bucket slice, query key, row index, and hash value. + #[inline(always)] + pub fn fast_query_median_with_key( + &self, + hashed_val: &MatrixHashType, + query_key: &Q, + op: F, + ) -> f64 + where + F: Fn(&[T], &Q, usize, &MatrixHashType) -> f64, + { + let mut estimates = Vec::with_capacity(self.rows); + for row in 0..self.rows { + let col = self.col_for_row(hashed_val, row); + estimates.push(op(self.bucket_slice(row, col), query_key, row, hashed_val)); } + compute_median_inline_f64(&mut estimates) + } + + /// Queries all rows with custom aggregation logic (fold/reduce pattern). + /// + /// The closure receives: accumulator, bucket slice, query key, row index, and + /// hash value. + #[inline(always)] + pub fn fast_query_aggregate( + &self, + hashed_val: &MatrixHashType, + query_key: &Q, + init: R, + fold_fn: F, + ) -> R + where + F: Fn(R, &[T], &Q, usize, &MatrixHashType) -> R, + { + let mut acc = init; + for row in 0..self.rows { + let col = self.col_for_row(hashed_val, row); + acc = fold_fn(acc, self.bucket_slice(row, col), query_key, row, hashed_val); + } + acc + } + + /// Returns an immutable slice corresponding to a full row plane + /// (`cols * depth` elements). + #[inline(always)] + pub fn row_slice(&self, row: usize) -> &[T] { + debug_assert!(row < self.rows, "row index out of bounds"); + let start = row * self.cols * self.depth; + let end = start + self.cols * self.depth; + &self.data[start..end] + } + + /// Returns a mutable slice corresponding to a full row plane. + #[inline(always)] + pub fn row_slice_mut(&mut self, row: usize) -> &mut [T] { + debug_assert!(row < self.rows, "row index out of bounds"); + let start = row * self.cols * self.depth; + let end = start + self.cols * self.depth; + &mut self.data[start..end] + } + + /// Returns the number of rows (legacy helper). + #[inline(always)] + pub fn get_row(&self) -> usize { + self.rows + } + + /// Returns the number of columns (legacy helper). + #[inline(always)] + pub fn get_col(&self) -> usize { + self.cols + } + + /// Returns the per-bucket depth (legacy helper). + #[inline(always)] + pub fn get_depth(&self) -> usize { + self.depth + } +} + +impl Index for Vector3D { + type Output = [T]; + + fn index(&self, index: usize) -> &Self::Output { + self.row_slice(index) + } +} + +impl IndexMut for Vector3D { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.row_slice_mut(index) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn required_bits_match_expected_thresholds() { + let default_dims: Vector3D = Vector3D::init(3, 4096, 8); + assert_eq!(default_dims.get_required_bits(), 64); + + let smaller_cols: Vector3D = Vector3D::init(3, 64, 8); + assert_eq!(smaller_cols.get_required_bits(), 32); + + let larger_shape: Vector3D = Vector3D::init(5, 1_048_576, 8); + assert_eq!(larger_shape.get_required_bits(), 128); + } + + #[test] + fn fill_initializes_every_cell() { + let mut v: Vector3D = Vector3D::init(2, 4, 3); + v.fill(0); + assert_eq!(v.len(), 2 * 4 * 3); + assert!(!v.is_empty()); + assert!(v.as_slice().iter().all(|&x| x == 0)); + assert_eq!(v.rows(), 2); + assert_eq!(v.cols(), 4); + assert_eq!(v.depth(), 3); + } + + #[test] + fn from_fn_addresses_every_position() { + let v = Vector3D::from_fn(2, 3, 2, |r, c, d| (r * 100 + c * 10 + d) as u32); + assert_eq!(v.get(0, 0, 0), Some(&0)); + assert_eq!(v.get(1, 2, 1), Some(&121)); + assert_eq!(v.get(2, 0, 0), None); + assert_eq!(v.bucket(1, 2), Some([120u32, 121u32].as_slice())); + assert_eq!(v.bucket(0, 3), None); + } + + #[test] + fn bucket_and_element_mutation_round_trips() { + let mut v: Vector3D = Vector3D::init(2, 2, 4); + v.fill(0); + v.bucket_slice_mut(1, 0)[2] = 7; + v.update_one_counter(0, 1, 3, |a, b| *a = b, 9); + assert_eq!(v.query_one_counter(1, 0, 2), 7); + assert_eq!(v.get(0, 1, 3), Some(&9)); + // Untouched bucket stays zero. + assert!(v.bucket_slice(0, 0).iter().all(|&x| x == 0)); + // Row plane spans cols * depth elements. + assert_eq!(v.row_slice(0).len(), 2 * 4); + assert_eq!(v[1].len(), 2 * 4); } } diff --git a/src/sketches/countsketch_hll.rs b/src/sketches/countsketch_hll.rs new file mode 100644 index 0000000..eec720e --- /dev/null +++ b/src/sketches/countsketch_hll.rs @@ -0,0 +1,371 @@ +//! Count Sketch + HyperLogLog hybrid (`CountHll`). +//! +//! A frequency-style hashing layout (the Count Sketch row/column grid) where +//! **every `(row, col)` bucket is a small HyperLogLog** instead of a single +//! signed counter. Each item is routed to one column per row (exactly like a +//! Count Sketch) and recorded into that bucket's HLL registers. This answers +//! *distinct-count* questions rather than frequency questions: +//! +//! - [`CountHll::estimate`] — the number of **distinct** items sharing a key's +//! buckets (median across rows, to suppress collision noise). +//! - [`CountHll::estimate_total_cardinality`] — total stream cardinality, +//! exploiting the fact that, within a row, items are partitioned across +//! columns, so the per-bucket distinct counts sum to the total. +//! +//! Storage is a [`Vector3D`] of shape `rows x cols x (2^precision)`: the +//! third dimension is the HLL register array for each bucket. +//! +//! The HyperLogLog register/rank math mirrors [`crate::sketches::hll`] (classic +//! estimator with small/large-range corrections). +//! +//! References: +//! - Charikar, Chen & Farach-Colton, "Finding Frequent Items in Data Streams," +//! ICALP 2002. +//! - Flajolet, Fusy, Gandouet & Meunier, "HyperLogLog: the analysis of a +//! near-optimal cardinality estimation algorithm," 2007. + +use crate::{DataInput, DefaultXxHasher, SketchHasher, Vector3D}; +use rmp_serde::{ + decode::Error as RmpDecodeError, encode::Error as RmpEncodeError, from_slice, to_vec_named, +}; +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; + +const DEFAULT_ROW_NUM: usize = 4; +const DEFAULT_COL_NUM: usize = 64; +const DEFAULT_PRECISION: u32 = 8; +const LOWER_32_MASK: u64 = (1u64 << 32) - 1; + +/// A Count Sketch grid whose cells are per-bucket HyperLogLog sketches. +/// +/// `rows` independent hash rows each route an item to one of `cols` columns; the +/// selected `(row, col)` bucket holds a `2^precision`-register HyperLogLog that +/// records the item. See the [module docs](crate::sketches::countsketch_hll) for +/// the supported queries. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct CountHll { + buckets: Vector3D, + rows: usize, + cols: usize, + precision: u32, + #[serde(skip)] + _hasher: PhantomData, +} + +impl Default for CountHll { + fn default() -> Self { + Self::with_dimensions(DEFAULT_ROW_NUM, DEFAULT_COL_NUM, DEFAULT_PRECISION) + } +} + +impl CountHll { + /// Creates a sketch with the requested grid size and per-bucket HLL precision. + /// + /// `precision` is the HyperLogLog precision `p`; each bucket holds `2^p` + /// registers. Panics if `precision` is not in `1..=18` (the range for which + /// the register layout and estimator are well-defined here). + pub fn with_dimensions(rows: usize, cols: usize, precision: u32) -> Self { + assert!( + (1..=18).contains(&precision), + "precision must be in 1..=18, got {precision}" + ); + assert!(rows > 0 && cols > 0, "rows and cols must be non-zero"); + let depth = 1usize << precision; + let mut buckets = Vector3D::init(rows, cols, depth); + buckets.fill(0); + Self { + buckets, + rows, + cols, + precision, + _hasher: PhantomData, + } + } + + /// Number of hash rows. + pub fn rows(&self) -> usize { + self.rows + } + + /// Number of columns per row. + pub fn cols(&self) -> usize { + self.cols + } + + /// HyperLogLog precision `p` (each bucket has `2^p` registers). + pub fn precision(&self) -> u32 { + self.precision + } + + /// Number of HLL registers per `(row, col)` bucket. + pub fn registers_per_bucket(&self) -> usize { + self.buckets.depth() + } + + /// Exposes the backing storage for inspection/testing. + pub fn as_storage(&self) -> &Vector3D { + &self.buckets + } + + /// Mutable access used internally for testing scenarios. + pub fn as_storage_mut(&mut self) -> &mut Vector3D { + &mut self.buckets + } + + /// Seed used for the per-bucket HLL register hash. + /// + /// Distinct from the per-row column-selection seeds (`0..rows`), so the + /// register hash is independent of column placement. + #[inline(always)] + fn hll_seed(&self) -> usize { + self.rows + } + + /// Computes the `(register index, rank)` pair for a value, shared by every + /// bucket the value lands in. + #[inline(always)] + fn register_and_rank(&self, value: &DataInput) -> (usize, u8) { + let p = self.precision; + let register_bits = 64 - p; + let p_mask = (1u64 << p) - 1; + let hll_hash = H::hash64_seeded(self.hll_seed(), value); + let index = ((hll_hash >> register_bits) & p_mask) as usize; + let rank = ((hll_hash << p) + p_mask).leading_zeros() as u8 + 1; + (index, rank) + } + + /// Inserts one observation: route to one column per row and record the value + /// in that bucket's HyperLogLog. + pub fn insert(&mut self, value: &DataInput) { + let cols = self.cols; + let (index, rank) = self.register_and_rank(value); + for r in 0..self.rows { + let col_hash = H::hash64_seeded(r, value); + let col = ((col_hash & LOWER_32_MASK) as usize) % cols; + let bucket = self.buckets.bucket_slice_mut(r, col); + if rank > bucket[index] { + bucket[index] = rank; + } + } + } + + /// Inserts each value in the slice. + pub fn insert_many(&mut self, values: &[DataInput]) { + for value in values { + self.insert(value); + } + } + + /// Estimates the number of distinct items sharing `value`'s buckets. + /// + /// Each of the `rows` buckets the value maps to estimates the distinct count + /// of all items routed there (the value plus collisions); the median across + /// rows suppresses collision over-counting. + pub fn estimate(&self, value: &DataInput) -> f64 { + let cols = self.cols; + let mut estimates = Vec::with_capacity(self.rows); + for r in 0..self.rows { + let col_hash = H::hash64_seeded(r, value); + let col = ((col_hash & LOWER_32_MASK) as usize) % cols; + estimates.push(estimate_bucket(self.buckets.bucket_slice(r, col))); + } + if estimates.is_empty() { + return 0.0; + } + estimates.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + let mid = estimates.len() / 2; + if estimates.len() % 2 == 1 { + estimates[mid] + } else { + (estimates[mid - 1] + estimates[mid]) / 2.0 + } + } + + /// Estimates the total number of distinct items in the stream. + /// + /// Within a single row, every item is routed to exactly one column, so the + /// columns partition the stream and the per-bucket distinct counts sum to the + /// total cardinality. The median of the per-row sums is returned for + /// stability across rows. + pub fn estimate_total_cardinality(&self) -> f64 { + let mut per_row = Vec::with_capacity(self.rows); + for r in 0..self.rows { + let mut row_sum = 0.0; + for c in 0..self.cols { + row_sum += estimate_bucket(self.buckets.bucket_slice(r, c)); + } + per_row.push(row_sum); + } + if per_row.is_empty() { + return 0.0; + } + per_row.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); + let mid = per_row.len() / 2; + if per_row.len() % 2 == 1 { + per_row[mid] + } else { + (per_row[mid - 1] + per_row[mid]) / 2.0 + } + } + + /// Merges another sketch by taking the element-wise register maximum. + /// + /// Both sketches must share the same grid dimensions and precision. + pub fn merge(&mut self, other: &Self) { + assert_eq!( + (self.rows, self.cols, self.precision), + (other.rows, other.cols, other.precision), + "dimension/precision mismatch while merging CountHll sketches" + ); + for (reg, other_reg) in self + .buckets + .as_mut_slice() + .iter_mut() + .zip(other.buckets.as_slice().iter().copied()) + { + if other_reg > *reg { + *reg = other_reg; + } + } + } + + /// Serializes the sketch into MessagePack bytes. + pub fn serialize_to_bytes(&self) -> Result, RmpEncodeError> { + to_vec_named(self) + } + + /// Deserializes a sketch from MessagePack bytes. + pub fn deserialize_from_bytes(bytes: &[u8]) -> Result { + from_slice(bytes) + } +} + +/// Classic HyperLogLog cardinality estimate over a single register slice. +/// +/// Mirrors [`crate::sketches::hll`]'s classic estimator, including the +/// small-range linear-counting and large-range corrections. +fn estimate_bucket(registers: &[u8]) -> f64 { + let m = registers.len() as f64; + let alpha_m = 0.7213 / (1.0 + 1.079 / m); + let mut z = 0.0; + for ®_val in registers { + z += 2f64.powi(-(reg_val as i32)); + } + let mut est = alpha_m * m * m / z; + if est <= m * 5.0 / 2.0 { + let zero_count = registers.iter().filter(|&®| reg == 0).count(); + if zero_count != 0 { + est = m * (m / zero_count as f64).ln(); + } + } else if est > 143_165_576.533 { + let correction_aux = i32::MAX as f64; + est = -correction_aux * (1.0 - est / correction_aux).ln(); + } + est +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::DataInput; + + fn distinct_keys(count: u64) -> Vec> { + (0..count).map(DataInput::U64).collect() + } + + #[test] + fn default_initializes_expected_dimensions() { + let sk = CountHll::default(); + assert_eq!(sk.rows(), DEFAULT_ROW_NUM); + assert_eq!(sk.cols(), DEFAULT_COL_NUM); + assert_eq!(sk.precision(), DEFAULT_PRECISION); + assert_eq!(sk.registers_per_bucket(), 1 << DEFAULT_PRECISION); + // Every register starts at zero. + assert!(sk.as_storage().as_slice().iter().all(|&r| r == 0)); + } + + #[test] + fn with_dimensions_uses_custom_sizes() { + let sk = CountHll::::with_dimensions(3, 17, 6); + assert_eq!(sk.rows(), 3); + assert_eq!(sk.cols(), 17); + assert_eq!(sk.precision(), 6); + assert_eq!(sk.registers_per_bucket(), 64); + assert_eq!(sk.as_storage().len(), 3 * 17 * 64); + } + + #[test] + fn repeated_key_counts_as_one_distinct() { + let mut sk = CountHll::::default(); + let key = DataInput::Str("alpha"); + for _ in 0..500 { + sk.insert(&key); + } + // Only one distinct item touched these buckets, so the distinct estimate + // should sit close to 1. + let est = sk.estimate(&key); + assert!(est < 3.0, "expected near-1 distinct estimate, got {est}"); + } + + #[test] + fn estimate_total_cardinality_tracks_distinct_count() { + let mut sk = CountHll::::default(); + let n = 4000u64; + for key in &distinct_keys(n) { + sk.insert(key); + } + let est = sk.estimate_total_cardinality(); + let truth = n as f64; + let rel_err = (est - truth).abs() / truth; + assert!( + rel_err < 0.25, + "total cardinality estimate {est} too far from {truth} (rel_err {rel_err})" + ); + } + + #[test] + fn merge_takes_register_max_and_unions_cardinality() { + let mut a = CountHll::::default(); + let mut b = CountHll::::default(); + for key in (0..2000u64).map(DataInput::U64) { + a.insert(&key); + } + for key in (2000..4000u64).map(DataInput::U64) { + b.insert(&key); + } + let a_card = a.estimate_total_cardinality(); + + a.merge(&b); + let merged = a.estimate_total_cardinality(); + + assert!( + merged > a_card, + "merged cardinality {merged} should exceed single-set {a_card}" + ); + let rel_err = (merged - 4000.0).abs() / 4000.0; + assert!( + rel_err < 0.25, + "merged cardinality {merged} too far from 4000 (rel_err {rel_err})" + ); + } + + #[test] + fn serialize_round_trip_preserves_estimates() { + let mut sk = CountHll::::with_dimensions(4, 32, 8); + for key in &distinct_keys(1500) { + sk.insert(key); + } + let bytes = sk.serialize_to_bytes().expect("serialize"); + let restored = CountHll::::deserialize_from_bytes(&bytes).expect("decode"); + + assert_eq!(sk.rows(), restored.rows()); + assert_eq!(sk.cols(), restored.cols()); + assert_eq!(sk.precision(), restored.precision()); + assert_eq!(sk.as_storage().as_slice(), restored.as_storage().as_slice()); + assert_eq!( + sk.estimate_total_cardinality(), + restored.estimate_total_cardinality() + ); + } +} diff --git a/src/sketches/mod.rs b/src/sketches/mod.rs index 38b7251..339627a 100644 --- a/src/sketches/mod.rs +++ b/src/sketches/mod.rs @@ -49,6 +49,10 @@ pub use coco::CocoBucket; pub mod countsketch; pub use countsketch::Count; +/// Count Sketch grid with a per-bucket HyperLogLog (distinct-count) sketch. +pub mod countsketch_hll; +pub use countsketch_hll::CountHll; + /// Hashing path markers for matrix-backed sketches. pub mod mode; pub use mode::{FastPath, RegularPath}; From dc7ba80de756d122451478895cc3376085fdad3f Mon Sep 17 00:00:00 2001 From: GordonYuanyc Date: Sat, 13 Jun 2026 03:11:08 -0400 Subject: [PATCH 2/2] docs(count-hll): fix broken intra-doc link to Vector3D cargo doc --no-deps -D warnings rejected [`Vector3D`] because rustdoc cannot resolve a path containing generics. Give the link an explicit target (crate::Vector3D) so the display text keeps the generic. Doc-comment only; no API change. Co-Authored-By: Claude Opus 4.8 --- src/sketches/countsketch_hll.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sketches/countsketch_hll.rs b/src/sketches/countsketch_hll.rs index eec720e..13bea97 100644 --- a/src/sketches/countsketch_hll.rs +++ b/src/sketches/countsketch_hll.rs @@ -12,7 +12,7 @@ //! exploiting the fact that, within a row, items are partitioned across //! columns, so the per-bucket distinct counts sum to the total. //! -//! Storage is a [`Vector3D`] of shape `rows x cols x (2^precision)`: the +//! Storage is a [`Vector3D`](crate::Vector3D) of shape `rows x cols x (2^precision)`: the //! third dimension is the HLL register array for each bucket. //! //! The HyperLogLog register/rank math mirrors [`crate::sketches::hll`] (classic