Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 62 additions & 62 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions diskann-garnet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ bytemuck.workspace = true
crossbeam = "0.8.4"
dashmap = { workspace = true, features = ["inline"] }
diskann.workspace = true
diskann-quantization.workspace = true
diskann-quantization = { workspace = true, features = ["flatbuffers"] }
diskann-providers.workspace = true
diskann-utils.workspace = true
diskann-vector.workspace = true
foldhash = "0.2.0"
rand.workspace = true
thiserror.workspace = true
tokio.workspace = true
diskann-utils.workspace = true
tokio = { workspace = true, features = ["sync"] }
2 changes: 1 addition & 1 deletion diskann-garnet/diskann-garnet.nuspec
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<package>
<metadata>
<id>diskann-garnet</id>
<version>1.0.27</version>
<version>1.0.28</version>
<readme>docs/README.md</readme>
<authors>Microsoft</authors>
<projectUrl>https://github.com/microsoft/DiskANN</projectUrl>
Expand Down
31 changes: 22 additions & 9 deletions diskann-garnet/src/dyn_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
SearchResults,
garnet::{Context, GarnetId},
labels::GarnetQueryLabelProvider,
provider::GarnetProvider,
provider::{DynamicQuantization, GarnetProvider},
};
use diskann::{
ANNResult,
Expand All @@ -16,14 +16,13 @@ use diskann::{
utils::VectorRepr,
};
use diskann_providers::{
index::wrapped_async::DiskANNIndex,
model::graph::provider::{async_::common::FullPrecision, layers::BetaFilter},
index::wrapped_async::DiskANNIndex, model::graph::provider::layers::BetaFilter,
};
use std::sync::Arc;

/// Type-erased version of `DiskANNIndex<GarnetProvider>`.
/// All vector data is passed as untyped byte slices.
pub trait DynIndex: Send + Sync {
pub(crate) trait DynIndex: Send + Sync {
fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()>;

fn set_attributes(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()>;
Expand Down Expand Up @@ -55,6 +54,10 @@ pub trait DynIndex: Send + Sync {
fn internal_id_exists(&self, context: &Context, id: u32) -> bool;

fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool;

fn train_quantizer(&self, context: &Context) -> bool;

fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize);
}

impl<T: VectorRepr> DynIndex for DiskANNIndex<GarnetProvider<T>> {
Expand All @@ -63,7 +66,7 @@ impl<T: VectorRepr> DynIndex for DiskANNIndex<GarnetProvider<T>> {
/// The data slice here must be aligned to `T` or this will panic.
fn insert(&self, context: &Context, id: &GarnetId, data: &[u8]) -> ANNResult<()> {
self.insert(
FullPrecision,
DynamicQuantization,
context,
id,
bytemuck::cast_slice::<u8, T>(data),
Expand All @@ -87,10 +90,10 @@ impl<T: VectorRepr> DynIndex for DiskANNIndex<GarnetProvider<T>> {
) -> ANNResult<SearchStats> {
let query = bytemuck::cast_slice::<u8, T>(data);
if let Some((labels, beta)) = filter {
let beta_filter = BetaFilter::new(FullPrecision, Arc::new(labels.clone()), beta);
let beta_filter = BetaFilter::new(DynamicQuantization, Arc::new(labels.clone()), beta);
self.search(*params, &beta_filter, context, query, output)
} else {
self.search(*params, &FullPrecision, context, query, output)
self.search(*params, &DynamicQuantization, context, query, output)
}
}

Expand All @@ -104,14 +107,14 @@ impl<T: VectorRepr> DynIndex for DiskANNIndex<GarnetProvider<T>> {
) -> ANNResult<SearchStats> {
// Look up internal ID
let iid = self.inner.provider().to_internal_id(context, id)?;
let data = self.inner.provider().get_vector(context, iid)?;
let data = self.inner.provider().get_full_vector(context, iid)?;
let data_bytes = bytemuck::cast_slice::<T, u8>(&data);
self.search_vector(context, data_bytes, params, filter, output)
}

fn remove(&self, context: &Context, id: &GarnetId) -> ANNResult<()> {
self.inplace_delete(
FullPrecision,
DynamicQuantization,
context,
id,
3,
Expand All @@ -137,4 +140,14 @@ impl<T: VectorRepr> DynIndex for DiskANNIndex<GarnetProvider<T>> {
fn external_id_exists(&self, context: &Context, id: &GarnetId) -> bool {
self.inner.provider().vector_id_exists(context, id)
}

fn train_quantizer(&self, context: &Context) -> bool {
self.inner.provider().train_quantizer(context)
}

fn backfill_quant_vectors(&self, context: &Context, task_idx: usize, task_count: usize) {
self.inner
.provider()
.backfill_quant_vectors(context, task_idx, task_count);
}
}
18 changes: 8 additions & 10 deletions diskann-garnet/src/ffi_recall_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ mod tests {
use diskann_vector::distance::{Cosine, Metric, SquaredL2};

use crate::{
VectorQuantType, VectorValueType, create_index, drop_index, garnet::Context, insert,
search_vector, test_utils::Store,
VectorQuantType, create_index, drop_index, garnet::Context, insert, search_vector,
test_utils::Store,
};

/// Helper to insert a vector with a string external ID and FP32 data.
Expand All @@ -25,16 +25,15 @@ mod tests {
let vector_bytes: &[u8] = bytemuck::cast_slice(vector);
unsafe {
insert(
ctx.0,
ctx.get(),
index_ptr,
id_bytes.as_ptr(),
id_bytes.len(),
VectorValueType::FP32,
vector_bytes.as_ptr(),
vector.len(),
b"".as_ptr(),
0,
)
) > 0
}
}

Expand Down Expand Up @@ -156,14 +155,14 @@ mod tests {
) -> f64 {
store.clear();
let callbacks = store.callbacks();
let ctx = Context(0);
let ctx = Context::new(0);

let reduce_dimensions = 0;
let l_build = 100;
let max_degree = 32;
let index_ptr = unsafe {
create_index(
ctx.0,
ctx.get(),
dimensions,
reduce_dimensions,
VectorQuantType::NoQuant,
Expand Down Expand Up @@ -204,9 +203,8 @@ mod tests {

let count = unsafe {
search_vector(
ctx.0,
ctx.get(),
index_ptr,
VectorValueType::FP32,
query_bytes.as_ptr(),
vec.len(),
delta,
Expand Down Expand Up @@ -237,7 +235,7 @@ mod tests {
total_expected += expected_ids.len();
}

unsafe { drop_index(ctx.0, index_ptr) };
unsafe { drop_index(ctx.get(), index_ptr) };

total_matches as f64 / total_expected as f64
}
Expand Down
Loading
Loading