Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 206 additions & 3 deletions diskann-benchmark/src/backend/exhaustive/spherical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ mod imp {
use indicatif::{ProgressBar, ProgressStyle};
use rand::SeedableRng;
use rayon::iter::{IndexedParallelIterator, ParallelIterator};
use serde::Serialize;
use serde::{Deserialize, Serialize};

use crate::{
backend::exhaustive::algos::{self, LinearSearch},
Expand Down Expand Up @@ -266,7 +266,7 @@ mod imp {
}

/// Results from an end-to-end run of Spherical Quantization.
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct Results {
/// The time it takes to generate the base quantizer.
training_time: MicroSeconds,
Expand Down Expand Up @@ -306,7 +306,7 @@ mod imp {
}
}

#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
struct SearchResults {
num_threads: usize,
time: MicroSeconds,
Expand Down Expand Up @@ -506,4 +506,207 @@ mod imp {
.map(UnwrapErr::new)?)
}
}

#[cfg(test)]
mod tests {
use super::*;
use diskann_benchmark_runner::{
files::InputFile, output::Memory, utils::datatype::DataType,
};
use std::num::NonZeroUsize;

/// Maximum absolute difference allowed between a freshly computed recall
/// average and the stored baseline average before `spherical_1bit_regression`
/// fails.
const RECALL_TOLERANCE: f64 = 0.01;

/// Locates the `test_data/disk_index_search` directory that lives next to this
/// crate, i.e. under the parent of the crate's manifest directory.
///
/// # Panics
///
/// Panics if the manifest directory has no parent.
fn test_data_dir() -> std::path::PathBuf {
    let crate_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));

    // Start from the directory that contains the crate and descend from there.
    let mut dir = crate_root.parent().unwrap().to_path_buf();
    dir.push("test_data");
    dir.push("disk_index_search");
    dir
}

/// Path of the checked-in JSON baseline for the 1-bit regression test:
/// `<manifest dir>/test/generated/spherical_1bit_regression.json`.
fn baseline_path() -> std::path::PathBuf {
    let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    for component in ["test", "generated", "spherical_1bit_regression.json"] {
        path.push(component);
    }
    path
}

/// Builds the fixed input used by the 1-bit regression test.
///
/// All data files are resolved relative to `test_data_dir()`. Every thread
/// count and the bit width are set to one, and the seed is a hard-coded
/// constant.
fn make_1bit_input() -> inputs::exhaustive::Spherical {
    // Resolves a fixture file name against the shared test-data directory.
    let fixture = {
        let root = test_data_dir();
        move |name: &str| InputFile::new(root.join(name))
    };
    let one = NonZeroUsize::new(1).unwrap();

    let data = fixture("disk_index_siftsmall_learn_256pts_data.fbin");
    let queries = fixture("disk_index_sample_query_10pts.fbin");
    let groundtruth = fixture("disk_index_10pts_idx_uint32_truth_search_res.bin");

    let recalls = inputs::exhaustive::SearchValues {
        recall_k: vec![5, 10],
        recall_n: vec![5, 10],
    };
    let search = inputs::exhaustive::SearchPhase {
        queries,
        groundtruth,
        num_threads: one,
        recalls,
    };

    inputs::exhaustive::Spherical {
        data,
        data_type: DataType::Float32,
        distance: SimilarityMeasure::SquaredL2,
        compression_threads: one,
        search,
        query_layouts: vec![
            inputs::exhaustive::SphericalQuery::SameAsData,
            inputs::exhaustive::SphericalQuery::FourBitTransposed,
        ],
        seed: 7831252621480178695,
        transform_kind: inputs::exhaustive::TransformKind::PaddingHadamard(
            inputs::exhaustive::TargetDim::Same,
        ),
        num_bits: one,
        pre_scale: inputs::exhaustive::PreScale::Some(0.00390625),
    }
}

/// Serializable baseline for regression comparison.
///
/// This is the subset of a `Results` value that `Baseline::from_results`
/// copies out. It is what gets written to, and read back from, the JSON
/// baseline file.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct Baseline {
    /// Copied from `Results::quantized_dim`; compared for exact equality.
    quantized_dim: usize,
    /// Copied from `Results::quantized_bytes`; compared for exact equality.
    quantized_bytes: usize,
    /// Copied from `Results::original_dim`; compared for exact equality.
    original_dim: usize,
    /// One entry per (query_layout, recall_k, recall_n) combination.
    recalls: Vec<RecallBaseline>,
}

/// A single recall measurement stored in a [`Baseline`], tagged with the query
/// layout of the search result it was taken from.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct RecallBaseline {
    /// Query layout of the search result this entry came from.
    layout: inputs::exhaustive::SphericalQuery,
    /// Copied from the recall entry's `recall_k`; compared exactly.
    recall_k: usize,
    /// Copied from the recall entry's `recall_n`; compared exactly.
    recall_n: usize,
    /// Copied from the recall entry's `num_queries`; compared exactly.
    num_queries: usize,
    /// Copied from the recall entry's `average`; compared within
    /// `RECALL_TOLERANCE`.
    average: f64,
    /// Copied from the recall entry's `minimum`; compared exactly.
    minimum: usize,
    /// Copied from the recall entry's `maximum`; compared exactly.
    maximum: usize,
}

impl Baseline {
    /// Flattens `results` into a [`Baseline`].
    ///
    /// The three dimensional fields are copied directly. Every recall entry of
    /// every search result becomes one [`RecallBaseline`], in iteration order,
    /// tagged with the layout of the search result it belongs to.
    fn from_results(results: &Results) -> Self {
        let recalls = results
            .search_results
            .iter()
            .flat_map(|search| {
                search.recalls.iter().map(move |metrics| RecallBaseline {
                    layout: search.layout,
                    recall_k: metrics.recall_k,
                    recall_n: metrics.recall_n,
                    num_queries: metrics.num_queries,
                    average: metrics.average,
                    minimum: metrics.minimum,
                    maximum: metrics.maximum,
                })
            })
            .collect();

        Self {
            quantized_dim: results.quantized_dim,
            quantized_bytes: results.quantized_bytes,
            original_dim: results.original_dim,
            recalls,
        }
    }
}

/// Writes `baseline` as pretty-printed JSON to `baseline_path()`, creating the
/// parent directory if it does not exist yet.
///
/// Output goes through a `BufWriter` and is flushed explicitly, so a failed
/// write is reported instead of being silently dropped when the writer goes
/// out of scope.
///
/// # Panics
///
/// Panics, naming the offending path, if the directory or file cannot be
/// created or if serialization or flushing fails.
fn save_baseline(baseline: &Baseline) {
    use std::io::Write;

    let path = baseline_path();
    let dir = path
        .parent()
        .expect("baseline path always has a parent directory");
    std::fs::create_dir_all(dir).unwrap_or_else(|err| {
        panic!(
            "Could not create baseline directory {}: {}",
            dir.display(),
            err,
        )
    });

    let file = std::fs::File::create(&path).unwrap_or_else(|err| {
        panic!("Could not create baseline {}: {}", path.display(), err)
    });

    // serde_json emits many small writes; buffer them rather than hitting the
    // file once per token.
    let mut writer = std::io::BufWriter::new(file);
    serde_json::to_writer_pretty(&mut writer, baseline).unwrap_or_else(|err| {
        panic!("Could not serialize baseline {}: {}", path.display(), err)
    });
    writer.flush().unwrap_or_else(|err| {
        panic!("Could not flush baseline {}: {}", path.display(), err)
    });
}

/// Reads and deserializes the baseline stored at `baseline_path()`.
///
/// # Panics
///
/// Panics, naming the path and explaining how to regenerate the file, if the
/// baseline cannot be opened or its contents cannot be deserialized into a
/// [`Baseline`].
fn load_baseline() -> Baseline {
    let path = baseline_path();
    let file = std::fs::File::open(&path).unwrap_or_else(|err| {
        panic!(
            "Could not open baseline {}: {}. If this is a new test, \
             run with DISKANN_TEST=overwrite to generate the baseline.",
            path.display(),
            err,
        )
    });

    // A stale or hand-edited baseline should point at the file, not just
    // surface a bare serde error.
    serde_json::from_reader(std::io::BufReader::new(file)).unwrap_or_else(|err| {
        panic!(
            "Could not parse baseline {}: {}. If the baseline format changed, \
             run with DISKANN_TEST=overwrite to regenerate it.",
            path.display(),
            err,
        )
    })
}

/// Reports whether the `DISKANN_TEST` environment variable is set to exactly
/// `overwrite`.
///
/// An unset variable, a non-Unicode value, or any other value yields `false`.
fn is_overwrite_mode() -> bool {
    matches!(std::env::var("DISKANN_TEST").as_deref(), Ok("overwrite"))
}

/// Regression test for 1-bit exhaustive spherical quantization.
///
/// Runs `SphericalQ::<1>` on the fixed input from `make_1bit_input` and
/// compares the outcome with the stored baseline. Dimensions, entry counts,
/// and per-entry integer fields must match exactly; recall averages may differ
/// by at most `RECALL_TOLERANCE`.
///
/// When `DISKANN_TEST=overwrite` is set, the baseline file is rewritten from
/// the current run and no comparison is made. To regenerate the baseline:
/// `DISKANN_TEST=overwrite cargo test -p diskann-benchmark --features spherical-quantization -- spherical_1bit_regression`
#[test]
fn spherical_1bit_regression() {
    let config = make_1bit_input();
    let mut sink = Memory::new();
    let outcome = SphericalQ::<1>.run(&config, &mut sink).unwrap();
    let observed = Baseline::from_results(&outcome);

    if is_overwrite_mode() {
        save_baseline(&observed);
        println!("Baseline written to {}", baseline_path().display());
        return;
    }

    let reference = load_baseline();

    // Exact checks for structural/dimensional properties, in a fixed order.
    let exact_fields = [
        ("quantized_dim", observed.quantized_dim, reference.quantized_dim),
        (
            "quantized_bytes",
            observed.quantized_bytes,
            reference.quantized_bytes,
        ),
        ("original_dim", observed.original_dim, reference.original_dim),
    ];
    for (field, got, want) in exact_fields {
        assert_eq!(
            got, want,
            "{} changed: got {}, expected {}",
            field, got, want,
        );
    }

    // The pairwise comparison below is only meaningful when both sides have
    // the same number of entries.
    let (got_len, want_len) = (observed.recalls.len(), reference.recalls.len());
    assert_eq!(
        got_len, want_len,
        "number of recall entries changed: got {}, expected {}",
        got_len, want_len,
    );

    for (got, want) in observed.recalls.iter().zip(&reference.recalls) {
        assert_eq!(got.layout, want.layout, "query layout mismatch");
        assert_eq!(got.recall_k, want.recall_k, "recall_k mismatch");
        assert_eq!(got.recall_n, want.recall_n, "recall_n mismatch");
        assert_eq!(got.num_queries, want.num_queries, "num_queries mismatch");
        assert_eq!(got.minimum, want.minimum, "minimum recall changed");
        assert_eq!(got.maximum, want.maximum, "maximum recall changed");

        // Averages are floating point, so they get a tolerance instead of an
        // exact comparison.
        let delta = (got.average - want.average).abs();
        assert!(
            delta <= RECALL_TOLERANCE,
            "recall average regression for {:?} (k={}, n={}): \
             got {:.6}, expected {:.6}, delta {:.6} exceeds tolerance {}",
            got.layout,
            got.recall_k,
            got.recall_n,
            got.average,
            want.average,
            delta,
            RECALL_TOLERANCE,
        );
    }
}
}
}
4 changes: 2 additions & 2 deletions diskann-benchmark/src/utils/recall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

use diskann_benchmark_core as benchmark_core;

use serde::Serialize;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub(crate) struct RecallMetrics {
/// The `k` value for `k-recall-at-n`.
Expand Down
61 changes: 61 additions & 0 deletions diskann-benchmark/test/generated/spherical_1bit_regression.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{
"quantized_dim": 128,
"quantized_bytes": 22,
"original_dim": 128,
"recalls": [
{
"layout": "same_as_data",
"recall_k": 5,
"recall_n": 5,
"num_queries": 10,
"average": 0.44,
"minimum": 0,
"maximum": 4
},
{
"layout": "same_as_data",
"recall_k": 5,
"recall_n": 10,
"num_queries": 10,
"average": 0.6,
"minimum": 1,
"maximum": 5
},
{
"layout": "same_as_data",
"recall_k": 10,
"recall_n": 10,
"num_queries": 10,
"average": 0.49,
"minimum": 2,
"maximum": 8
},
{
"layout": "four_bit_transposed",
"recall_k": 5,
"recall_n": 5,
"num_queries": 10,
"average": 0.52,
"minimum": 1,
"maximum": 4
},
{
"layout": "four_bit_transposed",
"recall_k": 5,
"recall_n": 10,
"num_queries": 10,
"average": 0.78,
"minimum": 2,
"maximum": 5
},
{
"layout": "four_bit_transposed",
"recall_k": 10,
"recall_n": 10,
"num_queries": 10,
"average": 0.62,
"minimum": 3,
"maximum": 9
}
]
}
70 changes: 70 additions & 0 deletions diskann-disk/src/utils/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,76 @@ mod kmeans_test {
.contains("Error: Cancellation requested by caller."));
}

/// A successful `k_means_clustering` run must assign every point to an
/// in-range center, list every point index exactly once across the per-center
/// buckets, and report a non-negative residual.
#[test]
fn k_means_clustering_produces_valid_clusters() {
    let (dim, num_points, num_centers, max_reps) = (2, 10, 3, 10);

    let data: Vec<f32> = (1..=num_points * dim).map(|x| x as f32).collect();
    let mut centers = vec![0.0; num_centers * dim];
    let pool = create_thread_pool_for_test();
    let mut rng = create_rnd_in_tests();
    let mut cancel_requested = false;

    let (closest_docs, closest_center, residual) = k_means_clustering(
        &data,
        num_points,
        dim,
        &mut centers,
        num_centers,
        max_reps,
        &mut rng,
        &mut cancel_requested,
        pool.as_ref(),
    )
    .unwrap();

    // One assignment per point, each naming an existing center.
    assert_eq!(closest_center.len(), num_points);
    assert!(closest_center
        .iter()
        .all(|&center| (center as usize) < num_centers));

    // One bucket per center; together the buckets hold each point index once.
    assert_eq!(closest_docs.len(), num_centers);
    let mut assigned: Vec<usize> = closest_docs.iter().flatten().copied().collect();
    assigned.sort_unstable();
    let every_point: Vec<usize> = (0..num_points).collect();
    assert_eq!(assigned, every_point);

    // Residual must be non-negative.
    assert!(residual >= 0.0);
}

/// When the cancellation flag is already set on entry, `k_means_clustering`
/// must return an error whose kind is `ANNErrorKind::PQError`.
#[test]
fn k_means_clustering_returns_err_when_canceled() {
    let (dim, num_points, num_centers, max_reps) = (2, 10, 3, 5);

    let data: Vec<f32> = (1..=num_points * dim).map(|x| x as f32).collect();
    let mut centers = vec![0.0; num_centers * dim];
    let pool = create_thread_pool_for_test();
    let mut rng = create_rnd_in_tests();

    // Cancellation is requested before the call starts.
    let mut cancel_requested = true;

    let outcome = k_means_clustering(
        &data,
        num_points,
        dim,
        &mut centers,
        num_centers,
        max_reps,
        &mut rng,
        &mut cancel_requested,
        pool.as_ref(),
    );

    let err = outcome.unwrap_err();
    assert_eq!(err.kind(), ANNErrorKind::PQError);
}

#[test]
fn selecting_random_pivots_test() {
let dim = 2;
Expand Down
Loading