Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
6b6c63c
first draft
arkrishn94 Apr 28, 2026
9f718a1
remove unnecessary parts
arkrishn94 Apr 28, 2026
55cafc2
Merge remote-tracking branch 'origin/main' into u/adkrishnan/flat-index
arkrishn94 Apr 28, 2026
f0a9dbd
rename file
arkrishn94 Apr 28, 2026
0672f3d
fmt
arkrishn94 Apr 28, 2026
3ac0e1b
split iterator to callback
arkrishn94 Apr 29, 2026
f887f2f
use distance unordered callback
arkrishn94 Apr 29, 2026
dc2281c
rfc update
arkrishn94 Apr 29, 2026
ee48f7d
rustfmt flat module
arkrishn94 Apr 29, 2026
3b2a1e0
Merge remote-tracking branch 'origin/main' into u/adkrishnan/flat-index
arkrishn94 Apr 29, 2026
7fd903e
fix clippy: replace doc-comment divider with regular comment
arkrishn94 Apr 29, 2026
8af5e00
small edits
arkrishn94 Apr 29, 2026
1dd1c72
renames and uplevel query computer
arkrishn94 May 4, 2026
3d24ef7
split buildquerycomputer, haselementref and distancesunordered
arkrishn94 May 5, 2026
0a63759
delete flatpostprocess, cleanup and docs
arkrishn94 May 5, 2026
1f48945
Merge remote-tracking branch 'origin/main' into u/adkrishnan/flat-index
arkrishn94 May 5, 2026
57d86e9
update rfc
arkrishn94 May 6, 2026
715150a
error types
arkrishn94 May 6, 2026
0627cc0
small doc fixes
arkrishn94 May 7, 2026
1d4e90b
iterator docs
arkrishn94 May 7, 2026
b4d9df0
strong pass on testing
arkrishn94 May 8, 2026
f338a83
minor cleanup docs
arkrishn94 May 8, 2026
2928404
Add HasQueryComputer trait
May 12, 2026
3620c6d
Revert "Add HasQueryComputer trait"
May 12, 2026
409d177
minor comments
arkrishn94 May 13, 2026
cf8847f
Merge remote-tracking branch 'origin/main' into u/adkrishnan/flat-index
arkrishn94 May 13, 2026
ffc70db
Merge branch 'u/adkrishnan/flat-index' of https://github.com/microsof…
arkrishn94 May 13, 2026
3a125f7
remove OnElementsUnordered
arkrishn94 May 14, 2026
e322565
remove OnElementsUnordered
arkrishn94 May 14, 2026
82ee77e
remove haselement and buildquerycomputer refactor
arkrishn94 May 18, 2026
728ad44
update rfc
arkrishn94 May 18, 2026
8f3b383
testing consolidate and docs
arkrishn94 May 18, 2026
6ae4b68
remove iterator and add Id hook
arkrishn94 May 20, 2026
2f27864
Merge remote-tracking branch 'origin/main' into u/adkrishnan/flat-index
arkrishn94 May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions diskann/src/flat/index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

//! [`FlatIndex`] — the index wrapper for a [`DataProvider`](crate::provider::DataProvider)
//! over which we do flat search.
use std::num::NonZeroUsize;

use diskann_utils::future::SendFuture;

use crate::{
ANNResult,
error::{ErrorExt, IntoANNResult},
flat::{DistancesUnordered, SearchStrategy},
graph::SearchOutputBuffer,
neighbor::{Neighbor, NeighborPriorityQueue, NeighborPriorityQueueIdType},
provider::DataProvider,
};

/// Counters reported by a single flat search.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SearchStats {
    /// Number of distance computations carried out over the course of the scan.
    pub cmps: u32,

    /// Number of results that were written into the output buffer.
    pub result_count: u32,
}

/// Flat-search index: a minimal wrapper that owns a [`DataProvider`].
#[derive(Debug)]
pub struct FlatIndex<P>
where
    P: DataProvider,
{
    /// The provider this index wraps.
    provider: P,
}

impl<P: DataProvider> FlatIndex<P> {
    /// Construct a new [`FlatIndex`] around `provider`.
    pub fn new(provider: P) -> Self {
        Self { provider }
    }

    /// Borrow the underlying provider.
    pub fn provider(&self) -> &P {
        &self.provider
    }

    /// Brute-force k-nearest-neighbor flat search.
    ///
    /// Streams every element produced by the strategy's visitor through the query
    /// computer, keeps the best `k` candidates in a [`NeighborPriorityQueue`], and
    /// writes the `(id, distance)` survivors into `output` in best-first order.
    ///
    /// The returned [`SearchStats`] reports how many times the scan callback ran
    /// (`cmps`) and the count returned by `output.extend` (`result_count`). Both
    /// counters saturate at `u32::MAX` instead of overflowing or truncating.
    ///
    /// # Errors
    ///
    /// Returns an error if the strategy fails to create its visitor or to build the
    /// query computer, or if the scan itself reports an error. Scan errors are
    /// escalated because a partial scan cannot yield correct k-NN results.
    pub fn knn_search<S, T, OB>(
        &self,
        k: NonZeroUsize,
        strategy: &S,
        context: &P::Context,
        query: T,
        output: &mut OB,
    ) -> impl SendFuture<ANNResult<SearchStats>>
    where
        S: SearchStrategy<P, T>,
        S::Id: NeighborPriorityQueueIdType,
        T: Send + Sync,
        OB: SearchOutputBuffer<S::Id> + Send + ?Sized,
    {
        async move {
            let mut visitor = strategy
                .create_visitor(&self.provider, context)
                .into_ann_result()?;

            let computer = strategy.build_query_computer(query).into_ann_result()?;

            let k = k.get();
            let mut queue = NeighborPriorityQueue::new(k);
            let mut cmps: u32 = 0;

            visitor
                .distances_unordered(&computer, |id, dist| {
                    // A scan may visit more than `u32::MAX` elements; pin the counter
                    // rather than panicking (debug) or wrapping (release).
                    cmps = cmps.saturating_add(1);
                    queue.insert(Neighbor::new(id, dist));
                })
                .await
                .escalate("flat scan must complete to produce correct k-NN results")?;

            let written = output.extend(queue.iter().take(k).map(|n| (n.id, n.distance)));

            // Checked narrowing: clamp instead of silently truncating with `as`.
            let result_count = u32::try_from(written).unwrap_or(u32::MAX);

            Ok(SearchStats { cmps, result_count })
        }
    }
}

///////////
// Tests //
///////////

#[cfg(test)]
mod tests {
    use crate::flat::{
        FlatIndex,
        test::{
            harness::KnnOracleRun,
            provider::{self as flat_provider},
        },
    };
    use crate::graph::test::synthetic::Grid;

    /// Build a [`FlatIndex`] over a grid-backed test provider and return it together
    /// with the length the provider reports.
    fn fixture(grid: Grid, size: usize) -> (FlatIndex<flat_provider::Provider>, usize) {
        let grid_provider = flat_provider::Provider::grid(grid, size).unwrap();
        let point_count = grid_provider.len();
        (FlatIndex::new(grid_provider), point_count)
    }

    /// The future returned by `knn_search` is `Send`, so one shared `FlatIndex` can
    /// back many searches running concurrently on a multi-threaded runtime; every
    /// search must independently yield the correct top-k.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn multithreaded_knn_search() {
        use std::sync::Arc;

        let (index, len) = fixture(Grid::Two, 4);
        let shared = Arc::new(index);

        // Mix of corner, axis-aligned, and off-grid queries; k spans 1..=len.
        let cases: &[(&[f32], usize)] = &[
            (&[-1.0, -1.0], 1),
            (&[1.0, 1.0], len),
            (&[-1.0, 1.0], len / 2),
            (&[1.0, -1.0], len - 1),
            (&[0.0, 0.0], 3),
            (&[3.0, 3.0], len),
            (&[-2.0, 0.5], 2),
            (&[0.5, -0.5], len),
        ];

        // Launch one task per case; each task owns its own handle to the index.
        let mut tasks = tokio::task::JoinSet::new();
        for &(coords, k) in cases {
            let index = Arc::clone(&shared);
            let query = coords.to_vec();
            tasks.spawn(async move {
                let strategy = flat_provider::Strategy::new(index.provider().dim());
                let outcome = KnnOracleRun::run(&index, &strategy, &query, k)
                    .await
                    .expect("knn_search failed");
                (query, k, outcome)
            });
        }

        // Tasks finish in arbitrary order, so each result carries its own inputs.
        while let Some(finished) = tasks.join_next().await {
            let (query, k, outcome) = finished.expect("task panicked");
            assert_eq!(
                outcome.top_k, outcome.ground_truth,
                "query = {query:?}, k = {k}: top-k must match brute force",
            );
            assert_eq!(outcome.stats.cmps as usize, len);
            assert_eq!(outcome.stats.result_count as usize, k.min(len));
        }
    }

    ////////////
    // Errors //
    ////////////

    /// A transient failure raised while the visitor scans has to surface as an error
    /// from `knn_search`.
    #[test]
    fn transient_scan_error() {
        // The flat scan touches every id, so any transient id is guaranteed to be hit.
        let scenarios: [&[u32]; 3] = [&[0], &[3], &[1, 2, 5]];
        for transient_ids in scenarios {
            let strategy =
                flat_provider::Strategy::with_transient(2, transient_ids.iter().copied());
            let (index, _) = fixture(Grid::Two, 3);
            let failure = KnnOracleRun::run_sync(&index, &strategy, &[1.0, 0.0], 4)
                .expect_err("transient error during full scan must escalate");

            let msg = failure.to_string();
            let names_a_transient = transient_ids
                .iter()
                .any(|id| msg.contains(&format!("id {id}")));
            assert!(
                names_a_transient,
                "transients = {transient_ids:?}: expected error to name one of the \
                 transient ids, got: {msg}",
            );
        }
    }

    /// Drive `knn_search` through the harness, require that it fails, and verify that
    /// the rendered error contains `expected_msg`.
    fn assert_search_error(strategy: &flat_provider::Strategy, query: &[f32], expected_msg: &str) {
        let (index, _) = fixture(Grid::Two, 3);
        let failure = KnnOracleRun::run_sync(&index, strategy, query, 4)
            .expect_err("expected knn_search to fail");

        let msg = failure.to_string();
        assert!(
            msg.contains(expected_msg),
            "expected error containing {expected_msg:?}, got: {msg}",
        );
    }

    #[test]
    fn strategy_constructor_errors() {
        // Strategy/provider expect dim=2, query has dim=3.
        let matching_dim = flat_provider::Strategy::new(2);
        assert_search_error(&matching_dim, &[0.0, 0.0, 0.0], "dimension mismatch");

        // Strategy expects dim=5, provider has dim=2.
        let oversized_dim = flat_provider::Strategy::new(5);
        assert_search_error(&oversized_dim, &[0.0, 0.0], "dimension mismatch");
    }
}
Comment thread
arkrishn94 marked this conversation as resolved.
32 changes: 32 additions & 0 deletions diskann/src/flat/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

//! Sequential ("flat") search.
//!
//! This module is the streaming counterpart to the random-access
//! [`crate::provider::Accessor`] family. It is designed for backends whose natural access
//! pattern is a one-pass scan over their data -- for example append-only buffered stores or
//! on-disk shards streamed via I/O.
//!
//! # Architecture
//!
//! The module mirrors the layering used by graph search:
//!
//! | Graph (random access) | Flat (sequential) | Shared? |
//! | :------------------------------------ | :----------------------------------------- |:--------- |
//! | [`crate::provider::DataProvider`] | [`crate::provider::DataProvider`] | Yes |
//! | [`crate::graph::DiskANNIndex`] | [`FlatIndex`] | No |
//! | [`crate::graph::glue::ExpandBeam`] | [`DistancesUnordered`] | No |
//! | [`crate::graph::glue::SearchStrategy`] | [`SearchStrategy`] | No |
//! | [`crate::graph::Search`] | [`FlatIndex::knn_search`] | No |
//!
pub mod index;
pub mod strategy;

pub use index::{FlatIndex, SearchStats};
pub use strategy::{DistancesUnordered, SearchStrategy};

#[cfg(test)]
mod test;
Loading
Loading