From 17844b9b075a8356b172af8caabd66c7547628e1 Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Wed, 20 May 2026 22:12:46 -0700 Subject: [PATCH 1/2] Added test for K means --- diskann-disk/src/utils/kmeans.rs | 70 ++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/diskann-disk/src/utils/kmeans.rs b/diskann-disk/src/utils/kmeans.rs index 0c7e895ec..61f6aa6a3 100644 --- a/diskann-disk/src/utils/kmeans.rs +++ b/diskann-disk/src/utils/kmeans.rs @@ -561,6 +561,76 @@ mod kmeans_test { .contains("Error: Cancellation requested by caller.")); } + #[test] + fn k_means_clustering_produces_valid_clusters() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + let max_reps = 10; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = vec![0.0; num_centers * dim]; + let pool = create_thread_pool_for_test(); + + let (closest_docs, closest_center, residual) = k_means_clustering( + &data, + num_points, + dim, + &mut centers, + num_centers, + max_reps, + &mut create_rnd_in_tests(), + &mut false, + pool.as_ref(), + ) + .unwrap(); + + // Every point must be assigned to a valid center. + assert_eq!(closest_center.len(), num_points); + for &c in &closest_center { + assert!((c as usize) < num_centers); + } + + // closest_docs must have one entry per center, and contain all point indices exactly once. + assert_eq!(closest_docs.len(), num_centers); + let mut all_points: Vec = closest_docs + .iter() + .flat_map(|v| v.iter().copied()) + .collect(); + all_points.sort(); + assert_eq!(all_points, (0..num_points).collect::>()); + + // Residual must be non-negative. + assert!(residual >= 0.0); + } + + #[test] + fn k_means_clustering_returns_err_when_canceled() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + let max_reps = 5; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = vec![0.0; num_centers * dim]; + let pool = create_thread_pool_for_test(); + + let err = k_means_clustering( + &data, + num_points, + dim, + &mut centers, + num_centers, + max_reps, + &mut create_rnd_in_tests(), + &mut true, // cancellation requested + pool.as_ref(), + ) + .unwrap_err(); + + assert_eq!(err.kind(), ANNErrorKind::PQError); + } + #[test] fn selecting_random_pivots_test() { let dim = 2; From 25c279a63f49c6e541fe79ca6286abca8b443b1c Mon Sep 17 00:00:00 2001 From: "Alex Razumov (from Dev Box)" Date: Wed, 20 May 2026 22:41:05 -0700 Subject: [PATCH 2/2] Added spherical regression test --- .../src/backend/exhaustive/spherical.rs | 209 +++++++++++++++++- diskann-benchmark/src/utils/recall.rs | 4 +- .../generated/spherical_1bit_regression.json | 61 +++++ 3 files changed, 269 insertions(+), 5 deletions(-) create mode 100644 diskann-benchmark/test/generated/spherical_1bit_regression.json diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index a4df5b702..f3d0d58f7 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -49,7 +49,7 @@ mod imp { use indicatif::{ProgressBar, ProgressStyle}; use rand::SeedableRng; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; - use serde::Serialize; + use serde::{Deserialize, Serialize}; use crate::{ backend::exhaustive::algos::{self, LinearSearch}, @@ -266,7 +266,7 @@ mod imp { } /// Results from an end-to-end run of Spherical Quantization. - #[derive(Debug, Serialize)] + #[derive(Debug, Serialize, Deserialize)] pub(super) struct Results { /// The time it takes to generate the base quantizer. training_time: MicroSeconds, @@ -306,7 +306,7 @@ mod imp { } } - #[derive(Debug, Serialize)] + #[derive(Debug, Serialize, Deserialize)] struct SearchResults { num_threads: usize, time: MicroSeconds, @@ -506,4 +506,207 @@ mod imp { .map(UnwrapErr::new)?) } } + + #[cfg(test)] + mod tests { + use super::*; + use diskann_benchmark_runner::{ + files::InputFile, output::Memory, utils::datatype::DataType, + }; + use std::num::NonZeroUsize; + + const RECALL_TOLERANCE: f64 = 0.01; + + fn test_data_dir() -> std::path::PathBuf { + let manifest = env!("CARGO_MANIFEST_DIR"); + std::path::Path::new(manifest) + .parent() + .unwrap() + .join("test_data") + .join("disk_index_search") + } + + fn baseline_path() -> std::path::PathBuf { + let manifest = env!("CARGO_MANIFEST_DIR"); + std::path::Path::new(manifest) + .join("test") + .join("generated") + .join("spherical_1bit_regression.json") + } + + fn make_1bit_input() -> inputs::exhaustive::Spherical { + let data_dir = test_data_dir(); + inputs::exhaustive::Spherical { + data: InputFile::new(data_dir.join("disk_index_siftsmall_learn_256pts_data.fbin")), + data_type: DataType::Float32, + distance: SimilarityMeasure::SquaredL2, + compression_threads: NonZeroUsize::new(1).unwrap(), + search: inputs::exhaustive::SearchPhase { + queries: InputFile::new(data_dir.join("disk_index_sample_query_10pts.fbin")), + groundtruth: InputFile::new( + data_dir.join("disk_index_10pts_idx_uint32_truth_search_res.bin"), + ), + num_threads: NonZeroUsize::new(1).unwrap(), + recalls: inputs::exhaustive::SearchValues { + recall_k: vec![5, 10], + recall_n: vec![5, 10], + }, + }, + query_layouts: vec![ + inputs::exhaustive::SphericalQuery::SameAsData, + inputs::exhaustive::SphericalQuery::FourBitTransposed, + ], + seed: 7831252621480178695, + transform_kind: inputs::exhaustive::TransformKind::PaddingHadamard( + inputs::exhaustive::TargetDim::Same, + ), + num_bits: NonZeroUsize::new(1).unwrap(), + pre_scale: inputs::exhaustive::PreScale::Some(0.00390625), + } + } + + /// Serializable baseline for regression comparison. + #[derive(Debug, serde::Serialize, serde::Deserialize)] + struct Baseline { + quantized_dim: usize, + quantized_bytes: usize, + original_dim: usize, + /// One entry per (query_layout, recall_k, recall_n) combination. + recalls: Vec, + } + + #[derive(Debug, serde::Serialize, serde::Deserialize)] + struct RecallBaseline { + layout: inputs::exhaustive::SphericalQuery, + recall_k: usize, + recall_n: usize, + num_queries: usize, + average: f64, + minimum: usize, + maximum: usize, + } + + impl Baseline { + fn from_results(results: &Results) -> Self { + let mut recalls = Vec::new(); + for sr in &results.search_results { + for r in &sr.recalls { + recalls.push(RecallBaseline { + layout: sr.layout, + recall_k: r.recall_k, + recall_n: r.recall_n, + num_queries: r.num_queries, + average: r.average, + minimum: r.minimum, + maximum: r.maximum, + }); + } + } + Self { + quantized_dim: results.quantized_dim, + quantized_bytes: results.quantized_bytes, + original_dim: results.original_dim, + recalls, + } + } + } + + fn save_baseline(baseline: &Baseline) { + let path = baseline_path(); + let dir = path.parent().unwrap(); + std::fs::create_dir_all(dir).unwrap(); + let file = std::fs::File::create(&path).unwrap(); + serde_json::to_writer_pretty(file, baseline).unwrap(); + } + + fn load_baseline() -> Baseline { + let path = baseline_path(); + let file = std::fs::File::open(&path).unwrap_or_else(|err| { + panic!( + "Could not open baseline {}: {}. If this is a new test, \ + run with DISKANN_TEST=overwrite to generate the baseline.", + path.display(), + err, + ) + }); + serde_json::from_reader(std::io::BufReader::new(file)).unwrap() + } + + fn is_overwrite_mode() -> bool { + std::env::var("DISKANN_TEST") + .map(|v| v == "overwrite") + .unwrap_or(false) + } + + /// Regression test for 1-bit exhaustive spherical quantization. + /// + /// This test exercises the full pipeline (train → compress → search) with a + /// fixed seed and checks that deterministic outputs remain stable. + /// + /// To regenerate the baseline: `DISKANN_TEST=overwrite cargo test -p diskann-benchmark --features spherical-quantization -- spherical_1bit_regression` + #[test] + fn spherical_1bit_regression() { + let input = make_1bit_input(); + let mut output = Memory::new(); + let results = SphericalQ::<1>.run(&input, &mut output).unwrap(); + let current = Baseline::from_results(&results); + + if is_overwrite_mode() { + save_baseline(¤t); + println!("Baseline written to {}", baseline_path().display()); + return; + } + + let expected = load_baseline(); + + // Exact checks for structural/dimensional properties. + assert_eq!( + current.quantized_dim, expected.quantized_dim, + "quantized_dim changed: got {}, expected {}", + current.quantized_dim, expected.quantized_dim, + ); + assert_eq!( + current.quantized_bytes, expected.quantized_bytes, + "quantized_bytes changed: got {}, expected {}", + current.quantized_bytes, expected.quantized_bytes, + ); + assert_eq!( + current.original_dim, expected.original_dim, + "original_dim changed: got {}, expected {}", + current.original_dim, expected.original_dim, + ); + + // Recall checks with tolerance. + assert_eq!( + current.recalls.len(), + expected.recalls.len(), + "number of recall entries changed: got {}, expected {}", + current.recalls.len(), + expected.recalls.len(), + ); + + for (cur, exp) in current.recalls.iter().zip(expected.recalls.iter()) { + assert_eq!(cur.layout, exp.layout, "query layout mismatch"); + assert_eq!(cur.recall_k, exp.recall_k, "recall_k mismatch"); + assert_eq!(cur.recall_n, exp.recall_n, "recall_n mismatch"); + assert_eq!(cur.num_queries, exp.num_queries, "num_queries mismatch"); + assert_eq!(cur.minimum, exp.minimum, "minimum recall changed"); + assert_eq!(cur.maximum, exp.maximum, "maximum recall changed"); + + let delta = (cur.average - exp.average).abs(); + assert!( + delta <= RECALL_TOLERANCE, + "recall average regression for {:?} (k={}, n={}): \ + got {:.6}, expected {:.6}, delta {:.6} exceeds tolerance {}", + cur.layout, + cur.recall_k, + cur.recall_n, + cur.average, + exp.average, + delta, + RECALL_TOLERANCE, + ); + } + } + } } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..a7ec82691 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -5,9 +5,9 @@ use diskann_benchmark_core as benchmark_core; -use serde::Serialize; +use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[non_exhaustive] pub(crate) struct RecallMetrics { /// The `k` value for `k-recall-at-n`. diff --git a/diskann-benchmark/test/generated/spherical_1bit_regression.json b/diskann-benchmark/test/generated/spherical_1bit_regression.json new file mode 100644 index 000000000..efc4db3a3 --- /dev/null +++ b/diskann-benchmark/test/generated/spherical_1bit_regression.json @@ -0,0 +1,61 @@ +{ + "quantized_dim": 128, + "quantized_bytes": 22, + "original_dim": 128, + "recalls": [ + { + "layout": "same_as_data", + "recall_k": 5, + "recall_n": 5, + "num_queries": 10, + "average": 0.44, + "minimum": 0, + "maximum": 4 + }, + { + "layout": "same_as_data", + "recall_k": 5, + "recall_n": 10, + "num_queries": 10, + "average": 0.6, + "minimum": 1, + "maximum": 5 + }, + { + "layout": "same_as_data", + "recall_k": 10, + "recall_n": 10, + "num_queries": 10, + "average": 0.49, + "minimum": 2, + "maximum": 8 + }, + { + "layout": "four_bit_transposed", + "recall_k": 5, + "recall_n": 5, + "num_queries": 10, + "average": 0.52, + "minimum": 1, + "maximum": 4 + }, + { + "layout": "four_bit_transposed", + "recall_k": 5, + "recall_n": 10, + "num_queries": 10, + "average": 0.78, + "minimum": 2, + "maximum": 5 + }, + { + "layout": "four_bit_transposed", + "recall_k": 10, + "recall_n": 10, + "num_queries": 10, + "average": 0.62, + "minimum": 3, + "maximum": 9 + } + ] +} \ No newline at end of file