Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once

#include "../../../core/nvtx.hpp"
#include "../../../preprocessing/quantize/vpq_build-ext.cuh"
#include "../../ivf_pq/ivf_pq_fp16_overflow.cuh"
#include "graph_core.cuh"

#include <raft/core/copy.cuh>
Expand Down Expand Up @@ -1664,6 +1665,24 @@ void build_knn_graph(
RAFT_LOG_DEBUG("# Building IVF-PQ index %s", model_name.c_str());
auto index = cuvs::neighbors::ivf_pq::build(res, pq.build_params, dataset);

// Empirically detect FP16 distance overflow on the just-built index: run a small FP16 probe
// search and downgrade the internal/coarse dtypes to FP32 if any distance comes back non-finite.
// This observes the actual computation, so it is agnostic of the selected distance type.
if (pq.search_params.internal_distance_dtype == CUDA_R_16F ||
pq.search_params.coarse_search_dtype == CUDA_R_16F) {
{
const bool fp16_overflow = cuvs::neighbors::ivf_pq::helpers::detect_fp16_overflow(
res, index, pq.search_params, dataset);
if (fp16_overflow) {
RAFT_LOG_WARN(
"IVF-PQ FP16 distance produced non-finite results on a probe search for this dataset -> "
"switching 'internal_distance_dtype' and 'coarse_search_dtype' to FP32");
pq.search_params.internal_distance_dtype = CUDA_R_32F;
pq.search_params.coarse_search_dtype = CUDA_R_32F;
}
}
}

//
// search top (k + 1) neighbors
//
Expand Down Expand Up @@ -2217,6 +2236,7 @@ index<T, IdxT> build(
knn_build_params = cagra::graph_build_params::ivf_pq_params(dataset.extents(), params.metric);
}
}

RAFT_EXPECTS(
params.metric != cuvs::distance::DistanceType::BitwiseHamming ||
std::holds_alternative<cagra::graph_build_params::iterative_search_params>(
Expand Down
84 changes: 84 additions & 0 deletions cpp/src/neighbors/ivf_pq/ivf_pq_fp16_overflow.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include "../detail/ann_utils.cuh" // cuvs::spatial::knn::detail::utils::mapping

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map_reduce.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/cudart_utils.hpp>

#include <thrust/logical.h>

#include <algorithm>
#include <cstdint>

namespace cuvs::neighbors::ivf_pq::helpers {

/**
 * @brief Detect whether FP16 internal distance dtypes overflow for this dataset during search.
 *
 * Runs a small probe search against an already-built IVF-PQ index with the supplied search
 * parameters, and reports whether any returned distance is non-finite (inf/NaN).
 *
 * The probe queries are the first `min(n_rows, 128)` rows of `dataset`, and up to 32 neighbors
 * are requested per query, so the check is a sample-based heuristic rather than a guarantee
 * for every possible query.
 *
 * @tparam DataT    element type of the dataset (also used as the query type of the probe search)
 * @tparam Accessor mdspan accessor policy of `dataset`
 *
 * @param[in] handle        raft resources providing the stream, workspace memory resource and
 *                          thrust policy used by the probe
 * @param[in] index         the already-built IVF-PQ index to search
 * @param[in] search_params search parameters forwarded unchanged to the probe search
 *                          (taken by value)
 * @param[in] dataset       row-major matrix [n_rows, dim] from which the probe queries are taken
 *
 * @return true if at least one probe distance is NaN or infinite; false otherwise, including
 *         when the dataset has no rows.
 */
template <typename DataT, typename Accessor>
bool detect_fp16_overflow(
  raft::resources const& handle,
  const cuvs::neighbors::ivf_pq::index<int64_t>& index,
  cuvs::neighbors::ivf_pq::search_params search_params,
  raft::mdspan<const DataT, raft::matrix_extent<int64_t>, raft::row_major, Accessor> dataset)
{
  const int64_t n_rows = dataset.extent(0);
  // Nothing to probe with: report "no overflow" rather than launching an empty search.
  if (n_rows == 0) { return false; }

  auto stream       = raft::resource::get_cuda_stream(handle);
  const int64_t dim = dataset.extent(1);

  constexpr int64_t kMaxSampleQueries = 128;
  constexpr uint32_t kProbeTopK       = 32;
  const int64_t n_sample              = std::min<int64_t>(n_rows, kMaxSampleQueries);
  // Clamp in 64-bit *before* narrowing. Casting n_rows to uint32_t first would wrap for
  // datasets with more than 2^32 - 1 rows (e.g. n_rows == 2^32 would yield top_k == 0).
  const uint32_t top_k =
    static_cast<uint32_t>(std::min<int64_t>(n_rows, static_cast<int64_t>(kProbeTopK)));

  auto mr = raft::resource::get_workspace_resource_ref(handle);
  auto queries =
    raft::make_device_mdarray<DataT>(handle, mr, raft::make_extents<int64_t>(n_sample, dim));
  // The leading n_sample rows are taken as probe queries.
  // NOTE(review): assumes raft::copy accepts whichever memory space `Accessor` exposes
  // (host or device) - confirm against the callers.
  raft::copy(queries.data_handle(), dataset.data_handle(), n_sample * dim, stream);

  auto neighbors =
    raft::make_device_mdarray<int64_t>(handle, mr, raft::make_extents<int64_t>(n_sample, top_k));
  auto distances =
    raft::make_device_mdarray<float>(handle, mr, raft::make_extents<int64_t>(n_sample, top_k));

  cuvs::neighbors::ivf_pq::search(handle,
                                  search_params,
                                  index,
                                  raft::make_const_mdspan(queries.view()),
                                  neighbors.view(),
                                  distances.view());

  // Scan every returned distance; a single NaN/inf is enough to flag overflow.
  const int64_t count       = n_sample * static_cast<int64_t>(top_k);
  auto is_non_finite_op     = [] __device__(float v) { return isnan(v) || isinf(v); };
  const bool any_non_finite = thrust::any_of(raft::resource::get_thrust_policy(handle),
                                             distances.data_handle(),
                                             distances.data_handle() + count,
                                             is_non_finite_op);
  // Make sure all work queued on the handle's stream has finished before the temporary
  // buffers go out of scope.
  raft::resource::sync_stream(handle);
  return any_non_finite;
}

} // namespace cuvs::neighbors::ivf_pq::helpers
Loading