Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions diskann-garnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@ added:

- `XB8`: When specifying vector input type, you can use `XB8` instead of `FP32`
to specify binary data in uint8 format, one byte per dimension.
- `SB8`: When specifying vector input type, you can use `SB8` instead of `FP32`
to specify binary data in int8 (signed) format, one byte per dimension.
- `XPREQ8`: This is a pseudo-quantizer that specifies the vector data will be
stored as full precision data in uint8 format.
- `Q8`: This is a pseudo-quantizer that specifies the vector data will be
stored as full precision data in int8 (signed) format.

Generally you will use `XB8` with `XPREQ8` to input and store uint8 vectors and
`FP32` with `NOQUANT` to input and store f32 vectors.
Generally you will use `XB8` with `XPREQ8` to input and store uint8 vectors,
`SB8` with `Q8` to input and store int8 vectors, and `FP32` with `NOQUANT` to
input and store f32 vectors.

Support for binary and scalar quantization is coming, along with support for
customizing the distance metric.
Expand Down
270 changes: 270 additions & 0 deletions diskann-garnet/src/ffi_recall_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,4 +366,274 @@ mod tests {
let recall = run_circle_recall(&store, 50, 10.0, 3);
assert!(recall >= 0.99, "circle r=10 recall too low: {recall:.4}");
}

// ── SB8 (signed int8) recall tests ─────────────────────────────────

/// Inserts one signed-int8 (`SB8`) vector into the index under a string external ID.
///
/// The ID is passed to the FFI `insert` entry point as its UTF-8 bytes, and the vector is
/// passed as the raw bytes of `vector` (one byte per `i8` element). The final byte-buffer
/// argument of `insert` is passed empty with length 0.
///
/// Returns whatever `insert` returns; callers in this file treat `true` as success.
fn insert_sb8_vector_str(
    ctx: &Context,
    index_ptr: *const c_void,
    eid: &str,
    vector: &[i8],
) -> bool {
    // Reinterpret the i8 elements as bytes without copying.
    let raw_vector: &[u8] = bytemuck::cast_slice(vector);
    let raw_id = eid.as_bytes();

    // SAFETY: both pointers come from slices that stay borrowed for the whole call, and each
    // length matches its slice. NOTE(review): the full contract of `insert` (and the validity
    // requirements on `index_ptr`) is defined outside this view — confirm.
    unsafe {
        insert(
            ctx.0,
            index_ptr,
            raw_id.as_ptr(),
            raw_id.len(),
            VectorValueType::SB8,
            raw_vector.as_ptr(),
            vector.len(),
            b"".as_ptr(),
            0,
        )
    }
}

/// Measures recall for SB8 vectors on an index created with `VectorQuantType::NoQuant`.
///
/// This is a thin wrapper: every argument is forwarded unchanged to
/// `run_sb8_recall_with_quant`, with the quantizer fixed to `NoQuant`, and that function's
/// recall value is returned as-is.
///
/// Original author's note: distances are computed in f32 space (SB8 values are converted to
/// f32 internally).
fn run_sb8_recall(
    store: &Store,
    dimensions: u32,
    metric: Metric,
    ids: &[String],
    vectors_i8: &[Vec<i8>],
    k: usize,
    distance_fn: fn(&[f32], &[f32]) -> f32,
) -> f64 {
    let quantizer = VectorQuantType::NoQuant;
    run_sb8_recall_with_quant(
        store, dimensions, metric, ids, vectors_i8, k, distance_fn, quantizer,
    )
}

/// Builds an index with the given quantizer, inserts every SB8 vector, queries the index with
/// each inserted vector, and returns the achieved recall.
///
/// What happens, in order:
/// 1. `store` is cleared and an index is created through `create_index` with `dimensions`,
///    `quant_type` and `metric` (build parameters are fixed: `l_build = 100`,
///    `max_degree = 32`, no dimension reduction).
/// 2. `vectors_i8[i]` is inserted under the external ID `ids[i]`.
/// 3. Each vector is used as an SB8 query via `search_vector`; the returned IDs are compared
///    against `brute_force_knn` over an `f32` copy of the data using `distance_fn`, and the
///    overlap is counted with `distance_based_intersection`.
/// 4. The index is dropped.
///
/// Original author's note: brute-force comparison uses f32 space; index distance computation
/// depends on the quant type (f32 for NoQuant, i8 for Q8).
///
/// # Returns
/// `total_matches / total_expected` summed over all queries. With no input vectors this is
/// `0.0 / 0.0`, i.e. NaN.
///
/// # Panics
/// Panics if `ids` and `vectors_i8` differ in length, if index creation returns null, if any
/// insert reports failure, or if any search returns a negative count.
#[allow(clippy::too_many_arguments)] // test helper exposing every knob the recall tests vary
fn run_sb8_recall_with_quant(
    store: &Store,
    dimensions: u32,
    metric: Metric,
    ids: &[String],
    vectors_i8: &[Vec<i8>],
    k: usize,
    distance_fn: fn(&[f32], &[f32]) -> f32,
    quant_type: VectorQuantType,
) -> f64 {
    // Every vector needs exactly one ID; otherwise some vectors would be queried without ever
    // having been inserted (or indexing would panic with a less helpful message).
    assert_eq!(
        ids.len(),
        vectors_i8.len(),
        "ids and vectors_i8 must have the same length"
    );

    store.clear();
    let callbacks = store.callbacks();
    let ctx = Context(0);

    let reduce_dimensions = 0;
    let l_build = 100;
    let max_degree = 32;
    // SAFETY: only plain values and the store's callbacks are passed.
    // NOTE(review): the contract of `create_index` is defined outside this view — confirm.
    let index_ptr = unsafe {
        create_index(
            ctx.0,
            dimensions,
            reduce_dimensions,
            quant_type,
            metric as i32,
            l_build,
            max_degree,
            callbacks.read_callback(),
            callbacks.write_callback(),
            callbacks.delete_callback(),
            callbacks.rmw_callback(),
        )
    };
    assert!(!index_ptr.is_null());

    // Per-result space reserved in the output ID buffer: a `u32`-sized prefix plus the longest
    // ID. Presumably the prefix is the ID length — confirm against `parse_string_ids`.
    // Loop-invariant, so computed once.
    let max_id_len = ids.iter().map(|id| id.len()).max().unwrap_or(0);
    let max_id_size = mem::size_of::<u32>() + max_id_len;

    // Convert i8 vectors to f32 for brute-force comparison
    let vectors_f32: Vec<Vec<f32>> = vectors_i8
        .iter()
        .map(|v| v.iter().map(|&x| x as f32).collect())
        .collect();

    for (eid, vector) in ids.iter().zip(vectors_i8) {
        assert!(
            insert_sb8_vector_str(&ctx, index_ptr, eid, vector),
            "insert failed for eid={}",
            eid
        );
    }

    let mut total_matches = 0usize;
    let mut total_expected = 0usize;

    let delta = 2.0_f32;
    let search_exploration_factor = 200_u32;
    let max_filtering_effort = 0_usize;
    let continuation = ptr::null_mut();

    for (query_i8, query_f32) in vectors_i8.iter().zip(&vectors_f32) {
        let query_bytes: &[u8] = bytemuck::cast_slice(query_i8);
        // Fresh zeroed buffers per query so results never leak between iterations.
        let mut output_id_buffer = vec![0u8; k * max_id_size];
        let mut output_dists = vec![0f32; k];

        // SAFETY: `index_ptr` was checked non-null above and is not dropped until after this
        // loop; every pointer comes from a live slice/Vec and is paired with that buffer's own
        // length. NOTE(review): the contract of `search_vector` is defined outside this view.
        let count = unsafe {
            search_vector(
                ctx.0,
                index_ptr,
                VectorValueType::SB8,
                query_bytes.as_ptr(),
                query_i8.len(),
                delta,
                search_exploration_factor,
                ptr::null(),
                0,
                max_filtering_effort,
                output_id_buffer.as_mut_ptr(),
                output_id_buffer.len(),
                output_dists.as_mut_ptr(),
                output_dists.len(),
                continuation,
            )
        };
        assert!(count >= 0, "search failed");

        let result_ids = parse_string_ids(&output_id_buffer, count as usize);
        let expected_ids = brute_force_knn(ids, &vectors_f32, query_f32, k, distance_fn);
        let matches = distance_based_intersection(
            &vectors_f32,
            ids,
            query_f32,
            &expected_ids,
            &result_ids,
            distance_fn,
        );
        total_matches += matches;
        total_expected += expected_ids.len();
    }

    // SAFETY: `index_ptr` came from `create_index` above and is not used after this call.
    unsafe { drop_index(ctx.0, index_ptr) };

    total_matches as f64 / total_expected as f64
}

/// Generates `grid_size ^ dimensions` SB8 vectors on a signed integer grid centered around 0.
///
/// Each coordinate takes the values `-(grid_size / 2) ..= grid_size - 1 - grid_size / 2`
/// (e.g. `-5..=4` for `grid_size = 10`, `-3..=3` for `grid_size = 7`). Vectors are emitted in
/// row-major order: the last dimension varies fastest.
///
/// # Returns
/// `(ids, vectors)` of equal length, where the ID of the `i`-th vector (0-based) is
/// `sb8_grid_{i + 1 zero-padded to 8 digits}_dim{dimensions}`.
///
/// # Panics
/// Panics if `grid_size > 256` (the centered coordinates would not fit in `i8`) or if
/// `grid_size ^ dimensions` overflows `usize`.
fn generate_sb8_grid_vectors(
    dimensions: usize,
    grid_size: usize,
) -> (Vec<String>, Vec<Vec<i8>>) {
    assert!(
        grid_size <= 256,
        "grid_size {grid_size} exceeds the 256 distinct values an i8 coordinate can hold"
    );
    // `dimensions as u32` can only truncate for dimension counts above u32::MAX, which no
    // realistic grid reaches; any realistic overflow is caught by `checked_pow`.
    let total = grid_size
        .checked_pow(dimensions as u32)
        .expect("grid_size ^ dimensions overflows usize");
    // Centering is done in i16 so that grid sizes in 129..=256 neither wrap nor overflow
    // before the final narrowing to i8.
    let offset = (grid_size / 2) as i16;
    let mut ids = Vec::with_capacity(total);
    let mut vectors = Vec::with_capacity(total);

    for i in 0..total {
        let mut vec = vec![0i8; dimensions];
        let mut pos = i;
        // Decompose `i` in base `grid_size`, least-significant digit into the last dimension.
        for slot in vec.iter_mut().rev() {
            let centered = (pos % grid_size) as i16 - offset;
            // With grid_size <= 256, `centered` is always within -128..=127.
            *slot = centered as i8;
            pos /= grid_size;
        }
        ids.push(format!("sb8_grid_{:08}_dim{}", i + 1, dimensions));
        vectors.push(vec);
    }

    (ids, vectors)
}

/// Runs the SB8 recall measurement over a generated signed grid using the L2 metric.
///
/// The grid comes from `generate_sb8_grid_vectors(dimensions, grid_size)`; the recall itself is
/// computed by `run_sb8_recall` with `Metric::L2` and `distance_fn::<SquaredL2>` as the
/// brute-force distance. Returns that recall value unchanged.
fn run_sb8_grid_recall(store: &Store, dimensions: u32, grid_size: usize, k: usize) -> f64 {
    let (grid_ids, grid_vectors) = generate_sb8_grid_vectors(dimensions as usize, grid_size);
    let metric = Metric::L2;
    run_sb8_recall(
        store,
        dimensions,
        metric,
        &grid_ids,
        &grid_vectors,
        k,
        distance_fn::<SquaredL2>,
    )
}

/// SB8 input on a 1-D grid of 100 points: top-3 recall must be at least 0.99.
#[test]
fn sb8_grid_l2_recall_1d_100() {
    let (dimensions, grid_size, k) = (1, 100, 3);
    let store = Store;
    let recall = run_sb8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "SB8 1D grid recall too low: {recall:.4}");
}

/// SB8 input on a 2-D 10x10 grid (100 points): top-3 recall must be at least 0.99.
#[test]
fn sb8_grid_l2_recall_2d_10() {
    let (dimensions, grid_size, k) = (2, 10, 3);
    let store = Store;
    let recall = run_sb8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "SB8 2D grid recall too low: {recall:.4}");
}

/// SB8 input on a 3-D grid with 7 values per axis (343 points): top-3 recall must be at
/// least 0.99.
#[test]
fn sb8_grid_l2_recall_3d_7() {
    let (dimensions, grid_size, k) = (3, 7, 3);
    let store = Store;
    let recall = run_sb8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "SB8 3D grid recall too low: {recall:.4}");
}

/// SB8 input on a 4-D grid with 5 values per axis (625 points): top-3 recall must be at
/// least 0.99.
#[test]
fn sb8_grid_l2_recall_4d_5() {
    let (dimensions, grid_size, k) = (4, 5, 3);
    let store = Store;
    let recall = run_sb8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "SB8 4D grid recall too low: {recall:.4}");
}

// ── Q8 (native int8 index) recall tests ─────────────────────────────

/// Runs the SB8 recall measurement over a generated signed grid on a `Q8` index with the L2
/// metric.
///
/// The grid comes from `generate_sb8_grid_vectors(dimensions, grid_size)`; the recall is
/// computed by `run_sb8_recall_with_quant` with `Metric::L2`, `VectorQuantType::Q8`, and
/// `distance_fn::<SquaredL2>` as the brute-force distance. Returns that recall value unchanged.
fn run_q8_grid_recall(store: &Store, dimensions: u32, grid_size: usize, k: usize) -> f64 {
    let (grid_ids, grid_vectors) = generate_sb8_grid_vectors(dimensions as usize, grid_size);
    let metric = Metric::L2;
    let quantizer = VectorQuantType::Q8;
    run_sb8_recall_with_quant(
        store,
        dimensions,
        metric,
        &grid_ids,
        &grid_vectors,
        k,
        distance_fn::<SquaredL2>,
        quantizer,
    )
}

/// Q8 index, 1-D grid of 100 points: top-3 recall must be at least 0.99.
#[test]
fn q8_grid_l2_recall_1d_100() {
    let (dimensions, grid_size, k) = (1, 100, 3);
    let store = Store;
    let recall = run_q8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "Q8 1D grid recall too low: {recall:.4}");
}

/// Q8 index, 2-D 10x10 grid (100 points): top-3 recall must be at least 0.99.
#[test]
fn q8_grid_l2_recall_2d_10() {
    let (dimensions, grid_size, k) = (2, 10, 3);
    let store = Store;
    let recall = run_q8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "Q8 2D grid recall too low: {recall:.4}");
}

/// Q8 index, 3-D grid with 7 values per axis (343 points): top-3 recall must be at least 0.99.
#[test]
fn q8_grid_l2_recall_3d_7() {
    let (dimensions, grid_size, k) = (3, 7, 3);
    let store = Store;
    let recall = run_q8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "Q8 3D grid recall too low: {recall:.4}");
}

/// Q8 index, 4-D grid with 5 values per axis (625 points): top-3 recall must be at least 0.99.
#[test]
fn q8_grid_l2_recall_4d_5() {
    let (dimensions, grid_size, k) = (4, 5, 3);
    let store = Store;
    let recall = run_q8_grid_recall(&store, dimensions, grid_size, k);
    assert!(recall >= 0.99, "Q8 4D grid recall too low: {recall:.4}");
}
}
Loading