diff --git a/diskann/src/flat/index.rs b/diskann/src/flat/index.rs new file mode 100644 index 000000000..7aff8c130 --- /dev/null +++ b/diskann/src/flat/index.rs @@ -0,0 +1,223 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! [`FlatIndex`] — the index wrapper for a [`DataProvider`](crate::provider::DataProvider) +//! over which we do flat search. +use std::num::NonZeroUsize; + +use diskann_utils::future::SendFuture; + +use crate::{ + ANNResult, + error::{ErrorExt, IntoANNResult}, + flat::{DistancesUnordered, SearchStrategy}, + graph::SearchOutputBuffer, + neighbor::{Neighbor, NeighborPriorityQueue, NeighborPriorityQueueIdType}, + provider::DataProvider, +}; + +/// Statistics collected during a flat search. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct SearchStats { + /// The total number of distance computations performed during the scan. + pub cmps: u32, + + /// The total number of results written to the output buffer. + pub result_count: u32, +} + +/// A thin wrapper around a [`DataProvider`] used for flat search. +#[derive(Debug)] +pub struct FlatIndex { + /// The backing provider. + provider: P, +} + +impl FlatIndex

{ + /// Construct a new [`FlatIndex`] around `provider`. + pub fn new(provider: P) -> Self { + Self { provider } + } + + /// Borrow the underlying provider. + pub fn provider(&self) -> &P { + &self.provider + } + + /// Brute-force k-nearest-neighbor flat search. + /// + /// Streams every element produced by the strategy's visitor through the query + /// computer, keeps the best `k` candidates in a [`NeighborPriorityQueue`], and + /// writes the `(id, distance)` survivors into `output` in best-first order. + pub fn knn_search( + &self, + k: NonZeroUsize, + strategy: &S, + context: &P::Context, + query: T, + output: &mut OB, + ) -> impl SendFuture> + where + S: SearchStrategy, + S::Id: NeighborPriorityQueueIdType, + T: Send + Sync, + OB: SearchOutputBuffer + Send + ?Sized, + { + async move { + let mut visitor = strategy + .create_visitor(&self.provider, context) + .into_ann_result()?; + + let computer = strategy.build_query_computer(query).into_ann_result()?; + + let k = k.get(); + let mut queue = NeighborPriorityQueue::new(k); + let mut cmps: u32 = 0; + + visitor + .distances_unordered(&computer, |id, dist| { + cmps += 1; + queue.insert(Neighbor::new(id, dist)); + }) + .await + .escalate("flat scan must complete to produce correct k-NN results")?; + + let result_count = + output.extend(queue.iter().take(k).map(|n| (n.id, n.distance))) as u32; + + Ok(SearchStats { cmps, result_count }) + } + } +} + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use crate::flat::{ + FlatIndex, + test::{ + harness::KnnOracleRun, + provider::{self as flat_provider}, + }, + }; + use crate::graph::test::synthetic::Grid; + + fn fixture(grid: Grid, size: usize) -> (FlatIndex, usize) { + let provider = flat_provider::Provider::grid(grid, size).unwrap(); + let len = provider.len(); + (FlatIndex::new(provider), len) + } + + /// `knn_search` returns a `Send` future, and a shared `&FlatIndex` can serve + /// many concurrent searches on a multi-threaded runtime, each producing the + /// correct top-k independently. + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn multithreaded_knn_search() { + use std::sync::Arc; + + let (index, len) = fixture(Grid::Two, 4); + let index = Arc::new(index); + + // Mix of corner, axis-aligned, and off-grid queries; k spans 1..=len. + let cases: &[(&[f32], usize)] = &[ + (&[-1.0, -1.0], 1), + (&[1.0, 1.0], len), + (&[-1.0, 1.0], len / 2), + (&[1.0, -1.0], len - 1), + (&[0.0, 0.0], 3), + (&[3.0, 3.0], len), + (&[-2.0, 0.5], 2), + (&[0.5, -0.5], len), + ]; + + let mut set = tokio::task::JoinSet::new(); + for (query, k) in cases { + let index = Arc::clone(&index); + let query: Vec = query.to_vec(); + let k = *k; + set.spawn(async move { + let outcome = KnnOracleRun::run( + &index, + &flat_provider::Strategy::new(index.provider().dim()), + &query, + k, + ) + .await + .expect("knn_search failed"); + (query, k, outcome) + }); + } + + while let Some(joined) = set.join_next().await { + let (query, k, outcome) = joined.expect("task panicked"); + assert_eq!( + outcome.top_k, outcome.ground_truth, + "query = {query:?}, k = {k}: top-k must match brute force", + ); + assert_eq!(outcome.stats.cmps as usize, len); + assert_eq!(outcome.stats.result_count as usize, k.min(len)); + } + } + + //////////// + // Errors // + //////////// + + /// A transient error from the visitor's scan must escalate up through `knn_search`. + #[test] + fn transient_scan_error() { + // The flat scan touches every id, so any transient id is guaranteed to be hit. + for transient_ids in [&[0u32][..], &[3][..], &[1, 2, 5][..]] { + let strategy = + flat_provider::Strategy::with_transient(2, transient_ids.iter().copied()); + let (index, _) = fixture(Grid::Two, 3); + let err = KnnOracleRun::run_sync(&index, &strategy, &[1.0, 0.0], 4) + .expect_err("transient error during full scan must escalate"); + + let msg = format!("{err}"); + assert!( + transient_ids + .iter() + .any(|id| msg.contains(&format!("id {id}"))), + "transients = {transient_ids:?}: expected error to name one of the \ + transient ids, got: {msg}", + ); + } + } + + /// Run `knn_search` via the harness, assert it fails, and check the error + /// message contains `expected_msg`. + fn assert_search_error(strategy: &flat_provider::Strategy, query: &[f32], expected_msg: &str) { + let (index, _) = fixture(Grid::Two, 3); + let err = KnnOracleRun::run_sync(&index, strategy, query, 4) + .expect_err("expected knn_search to fail"); + + let msg = format!("{err}"); + assert!( + msg.contains(expected_msg), + "expected error containing {expected_msg:?}, got: {msg}", + ); + } + + #[test] + fn strategy_constructor_errors() { + // Strategy/provider expect dim=2, query has dim=3. + assert_search_error( + &flat_provider::Strategy::new(2), + &[0.0, 0.0, 0.0], + "dimension mismatch", + ); + + // Strategy expects dim=5, provider has dim=2. + assert_search_error( + &flat_provider::Strategy::new(5), + &[0.0, 0.0], + "dimension mismatch", + ); + } +} diff --git a/diskann/src/flat/mod.rs b/diskann/src/flat/mod.rs new file mode 100644 index 000000000..32c926f6e --- /dev/null +++ b/diskann/src/flat/mod.rs @@ -0,0 +1,32 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Sequential ("flat") search. +//! +//! This module is the streaming counterpart to the random-access +//! [`crate::provider::Accessor`] family. It is designed for backends whose natural access +//! pattern is a one-pass scan over their data -- for example append-only buffered stores or +//! on-disk shards streamed via I/O. +//! +//! # Architecture +//! +//! The module mirrors the layering used by graph search: +//! +//! | Graph (random access) | Flat (sequential) | Shared? | +//! | :------------------------------------ | :----------------------------------------- |:--------- | +//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes | +//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No | +//! | [`crate::graph::glue::ExpandBeam`] | [`DistancesUnordered`] | No | +//! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No | +//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | No | +//! +pub mod index; +pub mod strategy; + +pub use index::{FlatIndex, SearchStats}; +pub use strategy::{DistancesUnordered, SearchStrategy}; + +#[cfg(test)] +mod test; diff --git a/diskann/src/flat/strategy.rs b/diskann/src/flat/strategy.rs new file mode 100644 index 000000000..1634576dc --- /dev/null +++ b/diskann/src/flat/strategy.rs @@ -0,0 +1,354 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Core flat-search traits: [`DistancesUnordered`] and [`SearchStrategy`]. + +use std::fmt::Debug; + +use diskann_utils::future::SendFuture; +use diskann_vector::PreprocessedDistanceFunction; + +use crate::{ + error::{StandardError, ToRanked}, + provider::DataProvider, +}; + +/// Fused iterate-and-score primitive over the elements of a flat index. +/// +/// Implementations drive an entire scan over the underlying data, scoring each element +/// with the supplied computer `C` and invoking `f` with the resulting `(id, distance)` +/// pair. The associated [`Self::ElementRef`] is the reference shape on which `C` must +/// be able to compute distances. +pub trait DistancesUnordered: Send + Sync +where + C: for<'a> PreprocessedDistanceFunction, f32>, +{ + /// Lifetime is intentionally unconstrained so it can appear under HRTB without + /// inducing a `'static` bound on `Self`. + type ElementRef<'a>; + + /// Id type yielded by the underlying data backend, used to uniquely identify + /// each element passed to the closure of [`Self::distances_unordered`]. + type Id; + + /// The error type for [`Self::distances_unordered`]. + type Error: ToRanked + Debug + Send + Sync + 'static; + + /// Drive the entire scan, scoring each element with `computer` and invoking `f` + /// with the resulting `(id, distance)` pair. + fn distances_unordered( + &mut self, + computer: &C, + f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32); +} + +/// Per-call configuration that knows how to construct a per-query +/// [`DistancesUnordered`] visitor for a provider, and the [`Self::QueryComputer`] used +/// to score each element during the scan. +pub trait SearchStrategy: Send + Sync +where + P: DataProvider, +{ + /// The reference element shape on which [`Self::QueryComputer`] computes + /// distances. + type ElementRef<'a>; + + /// Id type yielded by the `Self::Visitor`. + type Id; + + /// The concrete query-computer type. + type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + + Send + + Sync + + 'static; + + /// The error type for [`Self::build_query_computer`]. + type QueryComputerError: StandardError; + + /// The visitor type produced by [`Self::create_visitor`]. + type Visitor<'a>: for<'b> DistancesUnordered< + Self::QueryComputer, + ElementRef<'b> = Self::ElementRef<'b>, + Id = Self::Id, + > + where + Self: 'a, + P: 'a; + + /// The error type for [`Self::create_visitor`]. + type Error: StandardError; + + /// Construct a fresh visitor over `provider` for the given request `context`. + fn create_visitor<'a>( + &'a self, + provider: &'a P, + context: &'a P::Context, + ) -> Result, Self::Error>; + + /// Construct the per-query computer. + fn build_query_computer( + &self, + query: T, + ) -> Result; +} + +#[cfg(test)] +mod tests { + //! Direct [`DistancesUnordered`] impls over a few in-memory fixtures: a + //! happy-path scanner over `&[f32]` elements, a scanner whose `ElementRef<'a>` + //! is a lifetime-carrying non-reference type, and a scanner that fails + //! mid-stream. + + use std::marker::PhantomData; + + use diskann_utils::future::SendFuture; + use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; + + use super::*; + use crate::{ANNError, always_escalate, error::Infallible, utils::VectorRepr}; + + /// Sample dataset shared by every test below. + fn sample_items() -> Vec<(u32, Vec)> { + vec![ + (10, vec![0.0, 0.0]), + (11, vec![1.0, 0.0]), + (12, vec![0.0, 2.0]), + ] + } + + ///////////////////////////// + // Scanner yielding slices // + ///////////////////////////// + + /// Scans `items` in order, scoring each with the supplied computer. + struct Scanner { + items: Vec<(u32, Vec)>, + } + + impl DistancesUnordered<::QueryDistance> for Scanner { + type ElementRef<'a> = &'a [f32]; + type Id = u32; + type Error = Infallible; + + fn distances_unordered( + &mut self, + computer: &::QueryDistance, + mut f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32), + { + async move { + for (id, v) in &self.items { + let dist = computer.evaluate_similarity(v.as_slice()); + f(*id, dist); + } + Ok(()) + } + } + } + + /// Direct [`DistancesUnordered`] impl yields the expected `(id, distance)` pairs. + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn distances_unordered_scanner() { + let query = vec![0.5_f32, 0.9]; + let computer = f32::query_distance(&query, Metric::L2); + + let expected: Vec<(u32, f32)> = sample_items() + .into_iter() + .map(|(id, v)| (id, computer.evaluate_similarity(v.as_slice()))) + .collect(); + + let mut scanner = Scanner { + items: sample_items(), + }; + + let mut seen: Vec<(u32, f32)> = Vec::new(); + scanner + .distances_unordered(&computer, |id, d| seen.push((id, d))) + .await + .unwrap(); + assert_eq!(seen, expected); + } + + /////////////////////////// + // Failing scanner // + /////////////////////////// + + /// Non-recoverable error type returned by [`Failing`]. + #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] + #[error("synthetic scan failure at id {0}")] + struct Boom(u32); + + always_escalate!(Boom); + + impl From for ANNError { + #[track_caller] + fn from(boom: Boom) -> ANNError { + ANNError::opaque(boom) + } + } + + /// Scans `items`, but returns `Err(Boom(id))` exactly once after `fail_after` + /// successful yields. + struct Failing { + items: Vec<(u32, Vec)>, + fail_after: usize, + } + + impl DistancesUnordered<::QueryDistance> for Failing { + type ElementRef<'a> = &'a [f32]; + type Id = u32; + type Error = Boom; + + fn distances_unordered( + &mut self, + computer: &::QueryDistance, + mut f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32), + { + async move { + for (i, (id, v)) in self.items.iter().enumerate() { + if i == self.fail_after { + return Err(Boom(*id)); + } + let dist = computer.evaluate_similarity(v.as_slice()); + f(*id, dist); + } + Ok(()) + } + } + } + + /// An error returned mid-scan propagates up, and the closure stops being invoked + /// at the failure point. + #[tokio::test] + async fn failures_midstream() { + let mut scanner = Failing { + items: sample_items(), + fail_after: 1, // Yield item 0 successfully, fail on item 1. + }; + + let query = vec![0.0_f32, 0.0]; + let computer = f32::query_distance(&query, Metric::L2); + + let mut seen: Vec = Vec::new(); + let err = scanner + .distances_unordered(&computer, |id, _d| seen.push(id)) + .await + .expect_err("Failing scanner must surface its error"); + + assert_eq!(err, Boom(11)); + assert_eq!( + seen, + vec![10], + "the closure must only see items yielded before the failure", + ); + } + + ///////////////////////////////////////////// + // Lifetime-carrying concrete `ElementRef` // + ///////////////////////////////////////////// + + struct View<'a> { + ptr: *const f32, + len: usize, + _phantom: PhantomData<&'a [f32]>, + } + + // SAFETY: `View<'a>` semantically carries a `&'a [f32]`, which is `Send + Sync`. + unsafe impl Send for View<'_> {} + unsafe impl Sync for View<'_> {} + + /// Computer that reconstructs a `&[f32]` from a [`View`]'s ptr+len and + /// computes inner product against a stored query. + struct ViewComputer { + query: Vec, + } + + impl<'a> PreprocessedDistanceFunction, f32> for ViewComputer { + fn evaluate_similarity(&self, v: View<'a>) -> f32 { + // SAFETY: `v.ptr` / `v.len` were produced from a `&'a [f32]` held by the + // scanner that owns the backing `Vec`; the phantom lifetime ties this view + // to that borrow, so the slice is valid for the duration of this call. + let s = unsafe { std::slice::from_raw_parts(v.ptr, v.len) }; + s.iter().zip(&self.query).map(|(a, b)| a * b).sum() + } + } + + /// Scans `rows`, yielding a [`View`] tied (via its phantom lifetime) to the + /// borrow of the underlying `Vec`. + struct ViewScanner { + rows: Vec<(u32, Vec)>, + } + + impl ViewScanner { + fn iter<'a>(&self) -> impl Iterator)> { + self.rows.iter().map(|(x, y)| { + ( + *x, + View { + ptr: y.as_ptr(), + len: y.len(), + _phantom: PhantomData, + }, + ) + }) + } + } + + impl DistancesUnordered for ViewScanner { + type ElementRef<'a> = View<'a>; + type Id = u32; + type Error = Infallible; + + fn distances_unordered( + &mut self, + computer: &ViewComputer, + mut f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32), + { + async move { + for (id, v) in self.iter() { + f(id, computer.evaluate_similarity(v)); + } + Ok(()) + } + } + } + + #[tokio::test] + async fn distances_unordered_lifetime_carrying_element_ref() { + let mut scanner = ViewScanner { + rows: vec![ + (10, vec![1.0, 0.0]), + (11, vec![0.5, 0.5]), + (12, vec![0.0, 2.0]), + ], + }; + let computer = ViewComputer { + query: vec![1.0, 3.0], + }; + let expected: Vec<(u32, f32)> = vec![ + (10, 1.0 * 1.0 + 0.0 * 3.0), + (11, 0.5 * 1.0 + 0.5 * 3.0), + (12, 0.0 * 1.0 + 2.0 * 3.0), + ]; + + let mut seen: Vec<(u32, f32)> = Vec::new(); + scanner + .distances_unordered(&computer, |id, d| seen.push((id, d))) + .await + .unwrap(); + assert_eq!(seen, expected); + } +} diff --git a/diskann/src/flat/test/cases/flat_knn_search.rs b/diskann/src/flat/test/cases/flat_knn_search.rs new file mode 100644 index 000000000..fde7cea81 --- /dev/null +++ b/diskann/src/flat/test/cases/flat_knn_search.rs @@ -0,0 +1,206 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Baseline-cached regression sweep for [`crate::flat::FlatIndex::knn_search`]. +//! +//! Bbuilds a fresh index per parameter combination, runs `knn_search` through the +//! [`crate::flat::test::harness`], snapshots the result + statistics into +//! [`FlatKnnBaseline`], and compares the entire batch against the JSON committed under +//! `diskann/test/generated/flat/test/cases/flat_knn_search/`. + +use crate::{ + flat::{ + FlatIndex, + test::{ + harness, + provider::{self as flat_provider, ElementCounter, Strategy}, + }, + }, + graph::test::synthetic::Grid, + test::{ + TestPath, TestRoot, + cmp::{assert_eq_verbose, verbose_eq}, + get_or_save_test_results, + }, +}; + +fn root() -> TestRoot { + TestRoot::new("flat/test/cases/flat_knn_search") +} + +/// `k` values exercised for every `(grid, query)` combination. +const KS: [usize; 3] = [1, 4, 10]; + +/// One row of the baseline JSON: a single `(grid, size, query, k)` execution of +/// `FlatIndex::knn_search` plus the brute-force ground truth, search stats, and +/// per-row provider metrics. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct FlatKnnBaseline { + /// Free-form description of what this row exercises. + description: String, + + /// The query vector. + query: Vec, + + /// The dimensionality of the underlying grid. + grid_dims: usize, + + /// The side length of the underlying grid. + grid_size: usize, + + /// The requested `k`. + k: usize, + + /// Sorted distance multiset of the top-`k` returned by `knn_search`. + /// We store the distance multiset rather than `(id, distance)` pairs because + /// the priority queue may evict different *ids* on a boundary distance tie + /// (the queue's tie-breaking is heap-internal, not id-based) ΓÇö but the + /// multiset of distances is invariant. + top_k_distances: Vec, + + /// Brute-force ground-truth top-`k` `(id, distance)` (sorted by `(distance asc, + /// id asc)`). The brute-force pass enumerates ids in ascending order, so on a + /// tie this prefers the smaller id and gives a canonical answer for the JSON. + ground_truth: Vec<(u32, f32)>, + + /// `cmps` reported by `knn_search`. Must equal `provider.len()`. + comparisons: usize, + + /// `result_count` reported by `knn_search`. Must equal `min(k, provider.len())`. + result_count: usize, + + /// Per-provider metrics observed for this row (see [`Metrics`]). + metrics: ElementCounter, +} + +verbose_eq!(FlatKnnBaseline { + description, + query, + grid_dims, + grid_size, + k, + top_k_distances, + ground_truth, + comparisons, + result_count, + metrics, +}); + +/// Run `knn_search` + brute-force oracle against a *shared* `index`, assert the +/// cross-row invariants, and produce the baseline row. The per-row provider metrics +/// captured into the baseline are the *delta* observed during this row, which keeps +/// the snapshot independent of how many rows preceded it. +fn run_row( + index: &FlatIndex, + grid_dim: usize, + grid_size: usize, + query: &[f32], + k: usize, + desc: &str, +) -> FlatKnnBaseline { + let len = index.provider().len(); + let metrics_before = index.provider().metrics(); + + let outcome = + harness::KnnOracleRun::run_sync(index, &Strategy::new(index.provider().dim()), query, k) + .unwrap(); + let stats = outcome.stats; + + assert_eq!( + stats.cmps as usize, len, + "flat scan must touch every element exactly once", + ); + assert_eq!( + stats.result_count as usize, + k.min(len), + "result_count must equal min(k, provider.len())", + ); + + let gt_distances: Vec = outcome.ground_truth.iter().map(|(_, d)| *d).collect(); + assert_eq!( + outcome.top_k_distances, gt_distances, + "flat scan top-k distance multiset must agree with brute force", + ); + + let metrics_after = index.provider().metrics(); + let metrics = ElementCounter { + count: metrics_after.count - metrics_before.count, + }; + // `get_element` is incremented only by the [`Visitor`] used during `knn_search`; + // the brute-force oracle iterates `Provider::items()` directly and does not touch + // the visitor, so we expect exactly one scan's worth of increments per row. + assert_eq!( + metrics.count, len, + "expected exactly one scan (from knn_search) to increment get_element", + ); + + FlatKnnBaseline { + description: desc.to_string(), + query: query.to_vec(), + grid_dims: grid_dim, + grid_size, + k, + top_k_distances: outcome.top_k_distances, + ground_truth: outcome.ground_truth, + comparisons: stats.cmps as usize, + result_count: stats.result_count as usize, + metrics, + } +} + +/// Sweep [`KS`] ├ù `queries` for the given `(grid, size)` and snapshot the results. +fn _flat_knn_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { + let dim: usize = grid.dim().into(); + + // Build the provider and index once, mirroring the production pattern where a + // single index serves many queries. + let provider = flat_provider::Provider::grid(grid, size).unwrap(); + let len = provider.len(); + assert_eq!( + len, + size.pow(dim as u32), + "flat::test::Provider::grid should produce size^dim rows", + ); + let index = FlatIndex::new(provider); + + let queries: [(Vec, &str); 2] = [ + ( + vec![-1.0; dim], + "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + ), + ( + vec![(size - 1) as f32; dim], + "All `size-1`: query coincides with the last grid corner.", + ), + ]; + + let index_ref = &index; + let results: Vec = queries + .iter() + .flat_map(|(q, desc)| { + KS.iter() + .map(move |&k| run_row(index_ref, dim, size, q, k, desc)) + }) + .collect(); + + let name = parent.push(format!("search_{dim}_{size}")); + let expected = get_or_save_test_results(&name, &results); + assert_eq_verbose!(expected, results); +} + +#[test] +fn flat_knn_search_1_100() { + _flat_knn_search(Grid::One, 100, root().path()); +} + +#[test] +fn flat_knn_search_2_5() { + _flat_knn_search(Grid::Two, 5, root().path()); +} + +#[test] +fn flat_knn_search_3_4() { + _flat_knn_search(Grid::Three, 4, root().path()); +} diff --git a/diskann/src/flat/test/cases/mod.rs b/diskann/src/flat/test/cases/mod.rs new file mode 100644 index 000000000..ffaf5b91c --- /dev/null +++ b/diskann/src/flat/test/cases/mod.rs @@ -0,0 +1,6 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +mod flat_knn_search; diff --git a/diskann/src/flat/test/harness.rs b/diskann/src/flat/test/harness.rs new file mode 100644 index 000000000..43a19b32d --- /dev/null +++ b/diskann/src/flat/test/harness.rs @@ -0,0 +1,134 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Reusable execution harness for [`crate::flat::FlatIndex`] tests. +//! +//! Use [`KnnOracleRun::run`] to drive `knn_search` and pair the result with the +//! brute-force ground truth. + +use std::{cmp::Ordering, num::NonZeroUsize}; + +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; + +use crate::{ + ANNResult, + flat::{ + FlatIndex, SearchStats, + test::provider::{Provider, Strategy}, + }, + neighbor::{BackInserter, Neighbor}, + test::tokio::current_thread_runtime, + utils::VectorRepr, +}; + +/// Result of running [`FlatIndex::knn_search`] under the harness alongside a +/// brute-force ground-truth oracle. +#[derive(Debug, Clone)] +pub(crate) struct KnnOracleRun { + /// Top-`k` `(id, distance)` pairs. + /// Re-sorted from the heap output so equality checks are deterministic on ties. + pub top_k: Vec<(u32, f32)>, + /// `top_k.iter().map(|(_, d)| d).collect()`. + pub top_k_distances: Vec, + /// Statistics returned by `knn_search` (cmps, result_count). + pub stats: SearchStats, + /// Brute-force ground-truth top-`k` `(id, distance)` pairs in `(distance asc, + /// id asc)` order. + pub ground_truth: Vec<(u32, f32)>, +} + +impl KnnOracleRun { + /// Run [`FlatIndex::knn_search`] once, blocking on a fresh single-threaded + /// runtime, and pair the result with the brute-force ground truth. + pub fn run_sync( + index: &FlatIndex, + strategy: &Strategy, + query: &[f32], + k: usize, + ) -> ANNResult { + current_thread_runtime().block_on(Self::run(index, strategy, query, k)) + } + + /// Async variant of [`KnnOracleRun::run_sync`]. Use this from tests that already + /// have a Tokio runtime (e.g. `#[tokio::test]`) or that need to drive + /// `knn_search` concurrently across tasks. + pub async fn run( + index: &FlatIndex, + strategy: &Strategy, + query: &[f32], + k: usize, + ) -> ANNResult { + let context = crate::flat::test::provider::Context::new(); + let mut buf = vec![Neighbor::::default(); k]; + + let stats = index + .knn_search( + NonZeroUsize::new(k).expect("flat::test::harness requires k > 0"), + strategy, + &context, + query, + &mut BackInserter::new(buf.as_mut_slice()), + ) + .await?; + + let top_k = top_k_sorted(&buf, stats.result_count as usize); + let top_k_distances = top_k.iter().map(|(_, d)| *d).collect(); + let ground_truth = brute_force_topk(index.provider(), Metric::L2, query, k); + + Ok(Self { + top_k, + top_k_distances, + stats, + ground_truth, + }) + } +} + +/// Compute the brute-force top-`k` `(id, distance)` pairs over every element of +/// `provider` under `metric`. Iterates [`Provider::items`] directly and scores with +/// a fresh [`f32::query_distance`] computer, so the oracle is independent of the +/// [`crate::flat::test::provider::Visitor`] under test. Ties are broken by ascending +/// id for determinism. +pub(crate) fn brute_force_topk( + provider: &Provider, + metric: Metric, + query: &[f32], + k: usize, +) -> Vec<(u32, f32)> { + let computer = f32::query_distance(query, metric); + + let mut neighbors: Vec> = provider + .items() + .row_iter() + .enumerate() + .map(|(id, element)| Neighbor::new(id as u32, computer.evaluate_similarity(element))) + .collect(); + + sort_neighbors(&mut neighbors); + neighbors + .into_iter() + .take(k) + .map(|n| n.as_tuple()) + .collect() +} + +/// Take the first `result_count` neighbors and return them in `(distance asc, id asc)` +/// order. +fn top_k_sorted(buf: &[Neighbor], result_count: usize) -> Vec<(u32, f32)> { + let mut neighbors: Vec> = buf.iter().copied().take(result_count).collect(); + sort_neighbors(&mut neighbors); + neighbors.into_iter().map(|n| n.as_tuple()).collect() +} + +/// Sort a slice of [`Neighbor`] by `(distance asc, id asc)`. NaN distances are +/// treated as equal (test data should not produce NaN). +fn sort_neighbors(neighbors: &mut [Neighbor]) { + neighbors.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(Ordering::Equal) + .then(a.id.cmp(&b.id)) + }); +} diff --git a/diskann/src/flat/test/mod.rs b/diskann/src/flat/test/mod.rs new file mode 100644 index 000000000..8cbd01622 --- /dev/null +++ b/diskann/src/flat/test/mod.rs @@ -0,0 +1,10 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Test fixtures and helpers for the flat module. +pub(crate) mod harness; +pub(crate) mod provider; + +mod cases; diff --git a/diskann/src/flat/test/provider.rs b/diskann/src/flat/test/provider.rs new file mode 100644 index 000000000..6dfda8453 --- /dev/null +++ b/diskann/src/flat/test/provider.rs @@ -0,0 +1,449 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Self-contained test provider for the flat-search module. + +use std::{ + borrow::Cow, + collections::HashSet, + fmt::{self, Debug}, + future::Future, + sync::Arc, +}; + +use diskann_utils::{future::SendFuture, views::Matrix}; +use diskann_vector::{PreprocessedDistanceFunction, distance::Metric}; +use thiserror::Error; + +use crate::{ + ANNError, always_escalate, + error::{RankedError, ToRanked, TransientError}, + flat::{DistancesUnordered, SearchStrategy}, + graph::test::synthetic::Grid, + internal::counter::{Counter, LocalCounter}, + provider::{self, ExecutionContext, HasId, NoopGuard}, + utils::VectorRepr, +}; + +/// Error conditions for [`Provider::new`]. +#[derive(Debug, Error)] +pub enum ProviderError { + #[error("flat::test::Provider needs at least one item")] + Empty, + #[error("flat::test::Provider items must have non-zero dimension")] + ZeroDimension, +} + +impl From for ANNError { + #[track_caller] + fn from(err: ProviderError) -> ANNError { + ANNError::opaque(err) + } +} + +////////////// +// Provider // +////////////// + +/// In-memory test provider for flat search. +#[derive(Debug)] +pub struct Provider { + items: Matrix, + get_element: Counter, +} + +impl Provider { + /// Construct a provider from a matrix of vectors. + /// + /// # Errors + /// + /// Returns an error if the matrix is empty or has zero-width columns. + pub fn new(items: Matrix) -> Result { + if items.nrows() == 0 { + return Err(ProviderError::Empty); + } + if items.ncols() == 0 { + return Err(ProviderError::ZeroDimension); + } + Ok(Self { + items, + get_element: Counter::new(), + }) + } + + /// Build a provider over the row vectors of [`Grid::data`]. IDs are `0..n` in + /// row-major order (last coordinate varies fastest). + /// + /// Unlike the graph-side `Provider::grid`, this does *not* add a separate + /// start-point row — flat search has no notion of one. + /// + /// # Errors + /// + /// Propagates errors from [`Self::new`]. + pub fn grid(grid: Grid, size: usize) -> Result { + Self::new(grid.data(size)) + } + + /// Number of vectors in the provider. + pub fn len(&self) -> usize { + self.items.nrows() + } + + /// Dimension of each vector in the provider. + pub fn dim(&self) -> usize { + self.items.ncols() + } + + /// Snapshot of the per-provider counters. + pub fn metrics(&self) -> ElementCounter { + ElementCounter { + count: self.get_element.value(), + } + } + + /// Expose the items for brute force. + pub fn items(&self) -> &Matrix { + &self.items + } +} + +/// Counters tracked by [`Provider`]. +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(serde::Serialize, serde::Deserialize))] +pub struct ElementCounter { + /// The number of times any [`Visitor`] yielded an element. + pub count: usize, +} + +#[cfg(test)] +crate::test::cmp::verbose_eq!(ElementCounter { count }); + +///////////// +// Context // +///////////// + +/// Per-search execution context. No spawn/clone tracking — flat search runs on +/// the calling task and never spawns. +#[derive(Debug, Clone, Default)] +pub struct Context; + +impl Context { + pub fn new() -> Self { + Self + } +} + +impl ExecutionContext for Context { + fn wrap_spawn(&self, f: F) -> impl Future + Send + 'static + where + F: Future + Send + 'static, + { + f + } +} + +///////////////////// +// Errors / Guards // +///////////////////// + +/// Critical id-validation error: the requested id is out of range. +#[derive(Debug, Clone, Copy, Error, PartialEq, Eq)] +#[error("flat::test::Provider has no id {0}")] +pub struct InvalidId(pub u32); + +always_escalate!(InvalidId); + +impl From for ANNError { + #[track_caller] + fn from(err: InvalidId) -> ANNError { + ANNError::opaque(err) + } +} + +/// Transient access error injected by [`Visitor::flaky`]. +/// +/// Matches the shape of `graph::test::TransientAccessError`: panics in `Drop` if it +/// is dropped without being acknowledged or escalated. This guards against accidental +/// silent suppression of the error in the test code itself. +#[must_use] +#[derive(Debug)] +pub struct TransientGetError { + id: u32, + handled: bool, +} + +impl TransientGetError { + fn new(id: u32) -> Self { + Self { id, handled: false } + } +} + +impl fmt::Display for TransientGetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "transient failure retrieving id {}", self.id) + } +} + +impl std::error::Error for TransientGetError {} + +impl Drop for TransientGetError { + fn drop(&mut self) { + assert!( + self.handled, + "dropped an unhandled TransientGetError for id {}", + self.id, + ); + } +} + +impl TransientError for TransientGetError { + fn acknowledge(mut self, _why: D) + where + D: fmt::Display, + { + self.handled = true; + } + + fn escalate(mut self, _why: D) -> InvalidId + where + D: fmt::Display, + { + self.handled = true; + InvalidId(self.id) + } +} + +/// Two-tier error for [`Visitor::distances_unordered`]: a critical [`InvalidId`] +/// or a recoverable [`TransientGetError`]. +#[derive(Debug)] +pub enum AccessError { + InvalidId(InvalidId), + Transient(TransientGetError), +} + +impl fmt::Display for AccessError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidId(e) => fmt::Display::fmt(e, f), + Self::Transient(e) => fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for AccessError {} + +impl ToRanked for AccessError { + type Transient = TransientGetError; + type Error = InvalidId; + + fn to_ranked(self) -> RankedError { + match self { + Self::InvalidId(e) => RankedError::Error(e), + Self::Transient(e) => RankedError::Transient(e), + } + } + + fn from_transient(transient: TransientGetError) -> Self { + Self::Transient(transient) + } + + fn from_error(error: InvalidId) -> Self { + Self::InvalidId(error) + } +} + +////////////////// +// DataProvider // +////////////////// + +impl provider::DataProvider for Provider { + type Context = Context; + type InternalId = u32; + type ExternalId = u32; + type Error = InvalidId; + type Guard = NoopGuard; + + fn to_internal_id(&self, _ctx: &Context, gid: &u32) -> Result { + if (*gid as usize) < self.items.nrows() { + Ok(*gid) + } else { + Err(InvalidId(*gid)) + } + } + + fn to_external_id(&self, _ctx: &Context, id: u32) -> Result { + if (id as usize) < self.items.nrows() { + Ok(id) + } else { + Err(InvalidId(id)) + } + } +} + +///////////// +// Visitor // +///////////// + +/// Per-search visitor over a [`Provider`]. Analog of `graph::test::Accessor`: holds +/// the `'a` borrow of the provider, accumulates a local `get_element` counter that +/// flushes back on drop, and optionally injects transient errors for a configurable +/// set of ids. +pub struct Visitor<'a> { + provider: &'a Provider, + transient_ids: Option>>, + get_element: LocalCounter<'a>, +} + +impl<'a> Visitor<'a> { + /// Construct a visitor with no fault injection. + pub fn new(provider: &'a Provider) -> Self { + Self { + provider, + transient_ids: None, + get_element: provider.get_element.local(), + } + } + + /// Construct a visitor that returns a [`TransientGetError`] for any id in + /// `transient_ids`. Other ids behave normally. + pub fn flaky(provider: &'a Provider, transient_ids: Cow<'a, HashSet>) -> Self { + Self { + provider, + transient_ids: Some(transient_ids), + get_element: provider.get_element.local(), + } + } +} + +impl Debug for Visitor<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Visitor") + .field("provider", &self.provider) + .field("transient_ids", &self.transient_ids) + .finish_non_exhaustive() + } +} + +impl HasId for Visitor<'_> { + type Id = u32; +} + +impl DistancesUnordered<::QueryDistance> for Visitor<'_> { + type ElementRef<'a> = &'a [f32]; + type Id = ::Id; + type Error = AccessError; + + fn distances_unordered( + &mut self, + computer: &::QueryDistance, + mut f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32), + { + async move { + for (i, vector) in self.provider.items.row_iter().enumerate() { + let id = i as u32; + if let Some(ids) = &self.transient_ids + && ids.contains(&id) + { + return Err(AccessError::Transient(TransientGetError::new(id))); + } + self.get_element.increment(); + let dist = computer.evaluate_similarity(vector); + f(id, dist); + } + Ok(()) + } + } +} + +////////////// +// Strategy // +////////////// + +/// Error from [`Strategy::create_visitor`] or [`Strategy::build_query_computer`] +/// when dimensions don't match. +#[derive(Debug, Clone, Error)] +#[error("dimension mismatch: strategy expects {expected}, got {actual}")] +pub struct StrategyError { + pub expected: usize, + pub actual: usize, +} + +impl From for ANNError { + #[track_caller] + fn from(err: StrategyError) -> ANNError { + ANNError::opaque(err) + } +} + +/// Factory of [`Visitor`]s that validates dimensions and optionally injects +/// transient errors into the scan. +#[derive(Clone, Debug)] +pub struct Strategy { + dim: usize, + transient_ids: Option>>, +} + +impl Strategy { + /// Construct a strategy expecting vectors of dimension `dim`. + pub fn new(dim: usize) -> Self { + Self { + dim, + transient_ids: None, + } + } + + /// Construct a strategy whose visitors return a transient error on `get_element` + /// for every id in `transient_ids`. + pub fn with_transient(dim: usize, transient_ids: impl IntoIterator) -> Self { + Self { + dim, + transient_ids: Some(Arc::new(transient_ids.into_iter().collect())), + } + } +} + +impl SearchStrategy for Strategy { + type ElementRef<'a> = &'a [f32]; + type Id = u32; + type QueryComputer = ::QueryDistance; + type QueryComputerError = StrategyError; + type Visitor<'a> = Visitor<'a>; + type Error = StrategyError; + + fn create_visitor<'a>( + &'a self, + provider: &'a Provider, + _context: &'a Context, + ) -> Result, Self::Error> { + let actual = provider.dim(); + if actual != self.dim { + return Err(StrategyError { + expected: self.dim, + actual, + }); + } + let visitor = match &self.transient_ids { + Some(ids) => Visitor::flaky(provider, Cow::Borrowed(ids)), + None => Visitor::new(provider), + }; + Ok(visitor) + } + + fn build_query_computer( + &self, + from: &[f32], + ) -> Result { + if from.len() != self.dim { + return Err(StrategyError { + expected: self.dim, + actual: from.len(), + }); + } + Ok(f32::query_distance(from, Metric::L2)) + } +} diff --git a/diskann/src/lib.rs b/diskann/src/lib.rs index 71cb3ed41..9c1f6ac76 100644 --- a/diskann/src/lib.rs +++ b/diskann/src/lib.rs @@ -13,6 +13,7 @@ pub mod utils; pub(crate) mod internal; // Index Implementations +pub mod flat; pub mod graph; // Top level exports. diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json new file mode 100644 index 000000000..0ab9aee6a --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_1_100.json @@ -0,0 +1,264 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_1_100", + "payload": [ + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ] + ], + "k": 1, + "metrics": { + "count": 100 + }, + "query": [ + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 1.0 + ] + }, + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ], + [ + 1, + 4.0 + ], + [ + 2, + 9.0 + ], + [ + 3, + 16.0 + ] + ], + "k": 4, + "metrics": { + "count": 100 + }, + "query": [ + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 1.0, + 4.0, + 9.0, + 16.0 + ] + }, + { + "comparisons": 100, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 0, + 1.0 + ], + [ + 1, + 4.0 + ], + [ + 2, + 9.0 + ], + [ + 3, + 16.0 + ], + [ + 4, + 25.0 + ], + [ + 5, + 36.0 + ], + [ + 6, + 49.0 + ], + [ + 7, + 64.0 + ], + [ + 8, + 81.0 + ], + [ + 9, + 100.0 + ] + ], + "k": 10, + "metrics": { + "count": 100 + }, + "query": [ + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 1.0, + 4.0, + 9.0, + 16.0, + 25.0, + 36.0, + 49.0, + 64.0, + 81.0, + 100.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ] + ], + "k": 1, + "metrics": { + "count": 100 + }, + "query": [ + 99.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ], + [ + 98, + 1.0 + ], + [ + 97, + 4.0 + ], + [ + 96, + 9.0 + ] + ], + "k": 4, + "metrics": { + "count": 100 + }, + "query": [ + 99.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 4.0, + 9.0 + ] + }, + { + "comparisons": 100, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 1, + "grid_size": 100, + "ground_truth": [ + [ + 99, + 0.0 + ], + [ + 98, + 1.0 + ], + [ + 97, + 4.0 + ], + [ + 96, + 9.0 + ], + [ + 95, + 16.0 + ], + [ + 94, + 25.0 + ], + [ + 93, + 36.0 + ], + [ + 92, + 49.0 + ], + [ + 91, + 64.0 + ], + [ + 90, + 81.0 + ] + ], + "k": 10, + "metrics": { + "count": 100 + }, + "query": [ + 99.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 4.0, + 9.0, + 16.0, + 25.0, + 36.0, + 49.0, + 64.0, + 81.0 + ] + } + ] +} \ No newline at end of file diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json new file mode 100644 index 000000000..be49e6cf6 --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_2_5.json @@ -0,0 +1,270 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_2_5", + "payload": [ + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ] + ], + "k": 1, + "metrics": { + "count": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 2.0 + ] + }, + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ], + [ + 1, + 5.0 + ], + [ + 5, + 5.0 + ], + [ + 6, + 8.0 + ] + ], + "k": 4, + "metrics": { + "count": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 2.0, + 5.0, + 5.0, + 8.0 + ] + }, + { + "comparisons": 25, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 0, + 2.0 + ], + [ + 1, + 5.0 + ], + [ + 5, + 5.0 + ], + [ + 6, + 8.0 + ], + [ + 2, + 10.0 + ], + [ + 10, + 10.0 + ], + [ + 7, + 13.0 + ], + [ + 11, + 13.0 + ], + [ + 3, + 17.0 + ], + [ + 15, + 17.0 + ] + ], + "k": 10, + "metrics": { + "count": 25 + }, + "query": [ + -1.0, + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 2.0, + 5.0, + 5.0, + 8.0, + 10.0, + 10.0, + 13.0, + 13.0, + 17.0, + 17.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ] + ], + "k": 1, + "metrics": { + "count": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ], + [ + 19, + 1.0 + ], + [ + 23, + 1.0 + ], + [ + 18, + 2.0 + ] + ], + "k": 4, + "metrics": { + "count": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 2.0 + ] + }, + { + "comparisons": 25, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 2, + "grid_size": 5, + "ground_truth": [ + [ + 24, + 0.0 + ], + [ + 19, + 1.0 + ], + [ + 23, + 1.0 + ], + [ + 18, + 2.0 + ], + [ + 14, + 4.0 + ], + [ + 22, + 4.0 + ], + [ + 13, + 5.0 + ], + [ + 17, + 5.0 + ], + [ + 12, + 8.0 + ], + [ + 9, + 9.0 + ] + ], + "k": 10, + "metrics": { + "count": 25 + }, + "query": [ + 4.0, + 4.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 2.0, + 4.0, + 4.0, + 5.0, + 5.0, + 8.0, + 9.0 + ] + } + ] +} \ No newline at end of file diff --git a/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json b/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json new file mode 100644 index 000000000..1e3c997bc --- /dev/null +++ b/diskann/test/generated/flat/test/cases/flat_knn_search/search_3_4.json @@ -0,0 +1,276 @@ +{ + "file": "diskann/src/flat/test/cases/flat_knn_search.rs", + "test": "flat/test/cases/flat_knn_search/search_3_4", + "payload": [ + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ] + ], + "k": 1, + "metrics": { + "count": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 1, + "top_k_distances": [ + 3.0 + ] + }, + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ], + [ + 1, + 6.0 + ], + [ + 4, + 6.0 + ], + [ + 16, + 6.0 + ] + ], + "k": 4, + "metrics": { + "count": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 4, + "top_k_distances": [ + 3.0, + 6.0, + 6.0, + 6.0 + ] + }, + { + "comparisons": 64, + "description": "All -1: nearest is the all-zeros corner; result_count = min(k, len).", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 0, + 3.0 + ], + [ + 1, + 6.0 + ], + [ + 4, + 6.0 + ], + [ + 16, + 6.0 + ], + [ + 5, + 9.0 + ], + [ + 17, + 9.0 + ], + [ + 20, + 9.0 + ], + [ + 2, + 11.0 + ], + [ + 8, + 11.0 + ], + [ + 32, + 11.0 + ] + ], + "k": 10, + "metrics": { + "count": 64 + }, + "query": [ + -1.0, + -1.0, + -1.0 + ], + "result_count": 10, + "top_k_distances": [ + 3.0, + 6.0, + 6.0, + 6.0, + 9.0, + 9.0, + 9.0, + 11.0, + 11.0, + 11.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ] + ], + "k": 1, + "metrics": { + "count": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 1, + "top_k_distances": [ + 0.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ], + [ + 47, + 1.0 + ], + [ + 59, + 1.0 + ], + [ + 62, + 1.0 + ] + ], + "k": 4, + "metrics": { + "count": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 4, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 1.0 + ] + }, + { + "comparisons": 64, + "description": "All `size-1`: query coincides with the last grid corner.", + "grid_dims": 3, + "grid_size": 4, + "ground_truth": [ + [ + 63, + 0.0 + ], + [ + 47, + 1.0 + ], + [ + 59, + 1.0 + ], + [ + 62, + 1.0 + ], + [ + 43, + 2.0 + ], + [ + 46, + 2.0 + ], + [ + 58, + 2.0 + ], + [ + 42, + 3.0 + ], + [ + 31, + 4.0 + ], + [ + 55, + 4.0 + ] + ], + "k": 10, + "metrics": { + "count": 64 + }, + "query": [ + 3.0, + 3.0, + 3.0 + ], + "result_count": 10, + "top_k_distances": [ + 0.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 3.0, + 4.0, + 4.0 + ] + } + ] +} \ No newline at end of file diff --git a/rfcs/00983-flat-search.md b/rfcs/00983-flat-search.md new file mode 100644 index 000000000..3d56e2db8 --- /dev/null +++ b/rfcs/00983-flat-search.md @@ -0,0 +1,214 @@ +# Flat Search + +| | | +|------------------|--------------------------------| +| **Authors** | Aditya Krishnan, Alex Razumov, Dongliang Wu | +| **Created** | 2026-04-24 | +| **Updated** | 2026-05-20 | + +## 1. Motivation + +### 1.1 Background + +DiskANN today exposes a single abstraction family centered on the +[`crate::provider::Accessor`] trait. Accessors are random access by design since the graph greedy search algorithm needs to decide which ids to fetch and the accessor materializes the corresponding elements (vectors, quantized vectors and neighbor lists) on demand. This is the right contract for graph search, where neighborhood expansion is inherently random-access against the [`crate::provider::DataProvider`]. + +A growing class of consumers diverge from our current pattern of use by accesssing their index **sequentially**. Some consumers build their index in an "append-only" fashion and require that they walk the index in a sequential, fixed order, relying on iteration position to enforce versioning / deduplication invariants. + +### 1.2 Problem Statement + +The problem-statement here is simple: provide first-class support for sequential, one-pass scans over a data backend without stuffing the algorithm or the backend through the `Accessor` trait surface. + +### 1.3 Goals + +1. Define a fused iterate-and-score primitive — `flat::DistancesUnordered` — that + mirrors the role `Accessor` plays for graph search but exposes a sequential + scan-and-score operation instead of random access. +2. Provide flat-search algorithm implementations built on the new primitives, so consumers can use this against their own providers / backends. +3. (Near-future) Expose support for diferent distance computers and post-processing like re-ranking _out-of-the-box_ without having to reimplement these for the flat search path. + +## 2. Proposal + +The only shared surface between graph and flat search is the `DataProvider` (for id-mapping / context). + +The module exposes three layers: + +| Layer | Trait | Role | +|-------|-------|------| +| Backend | `DistancesUnordered` | Scan-and-score primitive | +| Factory | `SearchStrategy` | Per-query visitor + computer construction | +| Algorithm | `FlatIndex::knn_search` | Brute-force top-k | + +### 2.1 `DistancesUnordered` — the core scanning trait + +The single required trait for flat search. It is generic over a **computer type** `C` +rather than a query type — the algorithm supplies a pre-built computer and the visitor +drives the scan. + +```rust +pub trait DistancesUnordered: Send + Sync +where + C: for<'a> PreprocessedDistanceFunction, f32>, +{ + type ElementRef<'a>; + type Id; + type Error: ToRanked + Debug + Send + Sync + 'static; + + fn distances_unordered( + &mut self, + computer: &C, + f: F, + ) -> impl SendFuture> + where + F: Send + FnMut(Self::Id, f32); +} +``` + +Key differences from the graph-side `Accessor` path: + +- No random access — the visitor drives the entire scan internally. +- `ElementRef<'a>` and `Id` live on `DistancesUnordered` itself, decoupling the + scan-and-score primitive from `HasId` and from any provider-specific id type. A + visitor is free to yield ids derived from but not equal to its provider's + `InternalId`. We expect this constraint to go away once we're able to clean up the `VectorId` trait + and its restrictive bounds - i.e. expects id to be scalar-like. + +### 2.2 `SearchStrategy` — per-query factory + +The strategy owns both visitor construction and query-computer construction: + +```rust +pub trait SearchStrategy: Send + Sync +where + P: DataProvider, +{ + type ElementRef<'a>; + type Id; + type QueryComputer: for<'a> PreprocessedDistanceFunction, f32> + + Send + Sync + 'static; + type QueryComputerError: StandardError; + + type Visitor<'a>: for<'b> DistancesUnordered< + Self::QueryComputer, + ElementRef<'b> = Self::ElementRef<'b>, + Id = Self::Id, + > + where Self: 'a, P: 'a; + + type Error: StandardError; + + fn create_visitor<'a>( + &'a self, + provider: &'a P, + context: &'a P::Context, + ) -> Result, Self::Error>; + + fn build_query_computer( + &self, + query: T, + ) -> Result; +} +``` + +`build_query_computer` lives on the **strategy**, not the visitor. This keeps the +visitor free of any distance-computation trait bounds — it only needs to implement +`DistancesUnordered` for the strategy's computer type. + +### 2.3 `FlatIndex::knn_search` + +`FlatIndex

` is a thin `'static` wrapper around a `DataProvider`. The `knn_search` +method is the brute-force top-k algorithm: + +```rust +impl FlatIndex

{ + pub fn knn_search( + &self, + k: NonZeroUsize, + strategy: &S, + context: &P::Context, + query: T, + output: &mut OB, + ) -> impl SendFuture> + where + S: SearchStrategy, + S::Id: NeighborPriorityQueueIdType, + T: Send + Sync, + OB: SearchOutputBuffer + Send + ?Sized; +} +``` + +Algorithm: + +1. `strategy.create_visitor(&provider, context)` — acquire the scanning visitor. +2. `strategy.build_query_computer(query)` — preprocess the query into a computer. +3. `visitor.distances_unordered(&computer, |id, dist| queue.insert(...))` — full scan. +4. Drain the priority queue into `output` in best-first order. + +**No post-processing parameter (yet).** Currently `knn_search` writes +`(S::Id, f32)` directly into the `SearchOutputBuffer`. Once the graph-search +trait refactor in [PR #1076](https://github.com/microsoft/DiskANN/pull/1076) +lands, `knn_search` will accept an optional `SearchPostProcess` parameter +(the same trait graph search uses), enabling id remapping, re-ranking, and +other transformations as a composable layer. + +#### Call-chain diagram + +```text + Graph Flat + ───── ──── + + DiskANNIndex::search FlatIndex::knn_search + │ │ + ▼ ▼ + graph::glue::SearchStrategy flat::SearchStrategy + ::search_accessor ::create_visitor + │ ::build_query_computer + ▼ │ + Accessor + BuildQueryComputer ▼ + → QueryComputer DistancesUnordered + │ ::distances_unordered(&computer, f) + ▼ │ + ExpandBeam::expand_beam │ + (greedy beam, random access) │ + │ │ + ▼ ▼ + NeighborPriorityQueue NeighborPriorityQueue + │ │ + ▼ ▼ + SearchPostProcess SearchPostProcess (planned, PR #1076) + → SearchOutputBuffer → SearchOutputBuffer +``` + +## Trade-offs + +### No built-in post-processing (temporary) + +`knn_search` currently writes `(InternalId, f32)` directly. Once the graph-search +trait refactor in [PR #1076](https://github.com/microsoft/DiskANN/pull/1076) lands +and stabilizes a shared `SearchPostProcess` trait, `knn_search` will gain an optional +post-processor parameter matching the graph-search signature. Until then, callers that +need id remapping or re-ranking compose it externally. + +### Reusing `DataProvider` + +The design requires implementations to provide `InternalId` / `ExternalId` conversions. +This is arguably too restrictive for some flat-index consumers, but avoids introducing a +second provider trait. + +### Expand `ElementRef` and `QueryComputer` to support batched distance computation? + +The design for `DistancesUnordered` assumes the computer acts on single vectors. An alternative is to allow the computer to work +over batches, enabling (potentially) better cache utilization. Backends that need this can implement `DistancesUnordered` +directly with an optimized bulk loop. Some refactoring for the bounds on `DistancesUnordered` is needed here. + +### Intra-query parallelism + +`DistancesUnordered` requires `&mut self`, precluding internal parallelism within a single +scan. A parallel variant would need a different trait shape (e.g. splitting the scan across +shards). This is left for future work. + +## Future Work +- **Post-processing support** — once [PR #1076](https://github.com/microsoft/DiskANN/pull/1076) lands, add a `SearchPostProcess` parameter to `knn_search` so flat search can share the same id-remapping / re-ranking infrastructure as graph search. +- Support for other flat-search algorithms like filtered, range, and diverse flat algorithms as additional methods on `FlatIndex`. +- Index build — this is just one part of the picture; more work needs to be done around how this fits in with any traits / interface we need for index build. +