From 0cec9062537d4fd54a5dce2cd009a09f69036a86 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 12:12:25 -0700 Subject: [PATCH] Remove flat search. --- .../src/search/provider/disk_provider.rs | 72 ++++++++++---- diskann/src/graph/glue.rs | 19 ---- diskann/src/graph/index.rs | 98 +------------------ diskann/src/graph/test/cases/index.rs | 81 --------------- diskann/src/graph/test/provider.rs | 7 -- 5 files changed, 56 insertions(+), 221 deletions(-) diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 33938caea..1dd518781 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -7,7 +7,6 @@ use std::{ collections::HashMap, future::Future, num::NonZeroUsize, - ops::Range, sync::{ atomic::{AtomicU64, AtomicUsize}, Arc, @@ -21,13 +20,12 @@ use diskann::{ graph::{ self, glue::{ - self, DefaultPostProcessor, ExpandBeam, IdIterator, SearchExt, SearchPostProcess, - SearchStrategy, + self, DefaultPostProcessor, ExpandBeam, SearchExt, SearchPostProcess, SearchStrategy, }, search::Knn, search_output_buffer, AdjacencyList, DiskANNIndex, }, - neighbor::Neighbor, + neighbor::{Neighbor, NeighborPriorityQueue}, provider::{ Accessor, BuildQueryComputer, DataProvider, DefaultContext, DelegateNeighbor, HasId, NeighborAccessor, NoopGuard, @@ -715,16 +713,6 @@ where } } -impl IdIterator> for DiskAccessor<'_, Data, VP> -where - Data: GraphDataType, - VP: VertexProvider, -{ - async fn id_iterator(&mut self) -> Result, ANNError> { - Ok(0..self.provider.num_points as u32) - } -} - impl<'a, 'b, Data, VP> DelegateNeighbor<'a> for DiskAccessor<'b, Data, VP> where Data: GraphDataType, @@ -916,6 +904,55 @@ where } } + /// Perform a brute-force linear scan of all points in the index, returning the + /// nearest neighbors that pass `vector_filter`. + /// + /// The top `neighbors_before_reranking` candidates from the quantized scan will be + /// provided to full-precision reranking. + async fn flat_search( + &self, + strategy: &DiskSearchStrategy<'_, Data, ProviderFactory>, + query: &[Data::VectorDataType], + vector_filter: &(dyn Fn(&u32) -> bool + Send + Sync), + neighbors_before_reranking: usize, + output: &mut OB, + ) -> ANNResult + where + OB: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + Send, + { + let provider = self.index.provider(); + let mut accessor = strategy + .search_accessor(provider, &DefaultContext) + .into_ann_result()?; + let computer = accessor.build_query_computer(query).into_ann_result()?; + + let mut best = NeighborPriorityQueue::new(neighbors_before_reranking); + let mut cmps = 0u32; + + let num_points = provider.num_points as u32; + for id in 0..num_points { + if vector_filter(&id) { + let element = accessor.get_element(id).await.into_ann_result()?; + let dist = computer.evaluate_similarity(element); + best.insert(Neighbor::new(id, dist)); + cmps += 1; + } + } + + let result_count = strategy + .default_post_processor() + .post_process(&mut accessor, query, &computer, best.iter(), output) + .await + .into_ann_result()?; + + Ok(graph::index::SearchStats { + cmps, + hops: 0, + result_count: result_count as u32, + range_search_second_round: false, + }) + } + /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. pub fn search( @@ -993,12 +1030,11 @@ where let k = k_value; let l = search_list_size as usize; let stats = if is_flat_search { - self.runtime.block_on(self.index.flat_search( + self.runtime.block_on(self.flat_search( &strategy, - &DefaultContext, - strategy.query, + query, vector_filter, - &Knn::new(k, l, beam_width)?, + l, &mut result_output_buffer, ))? } else { diff --git a/diskann/src/graph/glue.rs b/diskann/src/graph/glue.rs index cb098b85c..5e691d515 100644 --- a/diskann/src/graph/glue.rs +++ b/diskann/src/graph/glue.rs @@ -91,7 +91,6 @@ use crate::{ Accessor, AsNeighbor, AsNeighborMut, BuildDistanceComputer, BuildQueryComputer, DataProvider, HasId, NeighborAccessor, }, - utils::VectorId, }; /// A trait to override search constraints such as early termination based on constraints @@ -819,24 +818,6 @@ where ) -> impl Future> + Send; } -/// Provides asynchronous access to an iterator over vector IDs. -/// -/// This trait defines a method to asynchronously retrieve an iterator over vector IDs. -/// -/// # Type Parameters -/// -/// - `I`: The iterator type returned by the accessor. It must implement `Iterator` with items of type implementing `VectorId`. -/// -/// # Errors -/// -/// Returns an [`ANNError`] if the iterator cannot be retrieved successfully. -pub trait IdIterator -where - I: Iterator, -{ - fn id_iterator(&mut self) -> impl std::future::Future>; -} - /////////// // Tests // /////////// diff --git a/diskann/src/graph/index.rs b/diskann/src/graph/index.rs index d696135a0..38e30dbc2 100644 --- a/diskann/src/graph/index.rs +++ b/diskann/src/graph/index.rs @@ -25,12 +25,11 @@ use tokio::task::JoinSet; use super::{ AdjacencyList, Config, ConsolidateKind, InplaceDeleteMethod, Search, glue::{ - self, Batch, ExpandBeam, IdIterator, InplaceDeleteStrategy, InsertStrategy, - MultiInsertStrategy, PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, + self, Batch, ExpandBeam, InplaceDeleteStrategy, InsertStrategy, MultiInsertStrategy, + PruneStrategy, SearchExt, SearchPostProcess, SearchStrategy, }, internal::{BackedgeBuffer, SortedNeighbors, prune}, search::{ - Knn, record::{NoopSearchRecord, SearchRecord, VisitedSearchRecord}, scratch::{self, PriorityQueueConfiguration, SearchScratch, SearchScratchParams}, }, @@ -2183,99 +2182,6 @@ where search_params.search(self, strategy, processor, context, query, output) } - /// Performs a brute-force flat search over the points matching a provided filter function. - /// - /// This method executes a linear scan through all points in the index, applying the provided - /// `vector_filter` to select candidate points. It computes the similarity between the query - /// vector and each candidate, returning the top results according to the provided search parameters. - /// - /// # Arguments - /// - /// * `strategy` - The search strategy to use for accessing and processing elements. - /// * `context` - The context to pass through to providers. - /// * `query` - The query vector for which nearest neighbors are sought. - /// * `vector_filter` - A predicate function used to filter candidate vectors based on their external IDs. - /// * `search_params` - Parameters controlling the search behavior, such as search depth (`l_value`). - /// * `output` - A mutable buffer to store the search results. Must be pre-allocated by the caller. - /// - /// # Returns - /// - /// Returns search statistics including the number of distance computations performed. - /// - /// # Errors - /// - /// Returns an error if there is a failure accessing elements or if the provided parameters are invalid. - /// - /// # Notes - /// - /// This method is computationally expensive for large datasets, as it does not leverage the graph structure - /// and instead performs a linear scan of all filtered points. - pub async fn flat_search<'a, S, T, O, OB, I>( - &'a self, - strategy: &'a S, - context: &'a DP::Context, - query: T, - vector_filter: &(dyn Fn(&DP::ExternalId) -> bool + Send + Sync), - search_params: &Knn, - output: &mut OB, - ) -> ANNResult - where - T: Copy + Send, - S: glue::DefaultSearchStrategy: IdIterator>, - I: Iterator::InternalId>, - O: Send, - OB: search_output_buffer::SearchOutputBuffer + Send, - { - let mut accessor = strategy - .search_accessor(&self.data_provider, context) - .into_ann_result()?; - let computer = accessor.build_query_computer(query).into_ann_result()?; - - let mut scratch = { - let num_start_points = accessor.starting_points().await?.len(); - self.search_scratch(search_params.l_value().get(), num_start_points) - }; - - let id_iterator = accessor.id_iterator().await?; - for id in id_iterator { - let external_id = self - .data_provider - .to_external_id(context, id) - .escalate("external id should be found")?; - - if vector_filter(&external_id) { - scratch.visited.insert(id); - let element = accessor - .get_element(id) - .await - .escalate("matched point retrieval must succeed")?; - let dist = computer.evaluate_similarity(element.reborrow()); - scratch.best.insert(Neighbor::new(id, dist)); - scratch.cmps += 1; - } - } - - let result_count = strategy - .default_post_processor() - .post_process( - &mut accessor, - query, - &computer, - scratch.best.iter().take(search_params.l_value().get()), - output, - ) - .send() - .await - .into_ann_result()?; - - Ok(SearchStats { - cmps: scratch.cmps, - hops: scratch.hops, - result_count: result_count as u32, - range_search_second_round: false, - }) - } - ////////////////// // Paged Search // ////////////////// diff --git a/diskann/src/graph/test/cases/index.rs b/diskann/src/graph/test/cases/index.rs index 5d91de7bb..540398c39 100644 --- a/diskann/src/graph/test/cases/index.rs +++ b/diskann/src/graph/test/cases/index.rs @@ -281,84 +281,3 @@ async fn test_drop_deleted_neighbors_noop() { .unwrap(); assert_eq!(result, graph::ConsolidateKind::Complete); } - -#[tokio::test(flavor = "current_thread")] -async fn test_flat_search_basic() { - use crate::graph::search::Knn; - use crate::graph::search_output_buffer::IdDistance; - - let adjacency_list = generate_2d_square_adjacency_list(); - let index = setup_2d_square(adjacency_list, 4); - let strategy = test_provider::Strategy::new(); - let ctx = test_provider::Context::new(); - - // Query near origin — node 0 at (0,0) is closest. - // l_value must cover all 5 points (4 data + 1 start) so the working set - // doesn't drop any before the post-processor runs. - let query = [0.1_f32, 0.1]; - let params = Knn::new(4, 5, None).unwrap(); - - let mut ids = [0u32; 4]; - let mut distances = [0.0f32; 4]; - let mut output = IdDistance::new(&mut ids, &mut distances); - - let stats = index - .flat_search( - &strategy, - &ctx, - query.as_slice(), - &|_| true, - ¶ms, - &mut output, - ) - .await - .unwrap(); - - // FilterStartPoints removes the start node, leaving 4 data nodes. - assert_eq!(stats.result_count, 4); - let results: std::collections::HashSet = - ids[..stats.result_count as usize].iter().copied().collect(); - for id in 0..4u32 { - assert!(results.contains(&id), "data node {id} should be in results"); - } -} - -#[tokio::test(flavor = "current_thread")] -async fn test_flat_search_with_filter() { - use crate::graph::search::Knn; - use crate::graph::search_output_buffer::IdDistance; - - let adjacency_list = generate_2d_square_adjacency_list(); - let index = setup_2d_square(adjacency_list, 4); - let strategy = test_provider::Strategy::new(); - let ctx = test_provider::Context::new(); - - // Query near origin, but filter out node 0. - let query = [0.1_f32, 0.1]; - let params = Knn::new(2, 4, None).unwrap(); - - let mut ids = [0u32; 2]; - let mut distances = [0.0f32; 2]; - let mut output = IdDistance::new(&mut ids, &mut distances); - - let stats = index - .flat_search( - &strategy, - &ctx, - query.as_slice(), - &|ext_id: &u32| *ext_id != 0, - ¶ms, - &mut output, - ) - .await - .unwrap(); - - assert_eq!(stats.result_count, 2); - assert!( - !ids[..stats.result_count as usize].contains(&0), - "node 0 should be filtered out" - ); - // Nodes 1, 2, 3 remain — closest two to (0.1, 0.1) are 1 (1,0) and 2 (0,1). - assert!(ids.contains(&1), "node 1 should be present"); - assert!(ids.contains(&2), "node 2 should be present"); -} diff --git a/diskann/src/graph/test/provider.rs b/diskann/src/graph/test/provider.rs index 051ec6fc7..afa80b38c 100644 --- a/diskann/src/graph/test/provider.rs +++ b/diskann/src/graph/test/provider.rs @@ -1111,13 +1111,6 @@ impl glue::SearchExt for Accessor<'_> { impl glue::ExpandBeam<&[f32]> for Accessor<'_> {} -impl glue::IdIterator> for Accessor<'_> { - async fn id_iterator(&mut self) -> Result, ANNError> { - let ids: Vec = self.provider.terms.iter().map(|r| *r.key()).collect(); - Ok(ids.into_iter()) - } -} - #[derive(Debug, Clone)] pub struct Strategy { // Set this flag to enable reuse within the [`workingset::Map`]. For multi-threaded