diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index a7c15b4161..f118146fba 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -1,11 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * reserved. SPDX-License-Identifier: Apache-2.0 */ #pragma once #include "../../../core/nvtx.hpp" #include "../../../preprocessing/quantize/vpq_build-ext.cuh" +#include "../../ivf_pq/ivf_pq_fp16_overflow.cuh" #include "graph_core.cuh" #include @@ -1664,6 +1665,24 @@ void build_knn_graph( RAFT_LOG_DEBUG("# Building IVF-PQ index %s", model_name.c_str()); auto index = cuvs::neighbors::ivf_pq::build(res, pq.build_params, dataset); + // Empirically detect FP16 distance overflow on the just-built index: run a small FP16 probe + // search and downgrade the internal/coarse dtypes to FP32 if any distance comes back non-finite. + // This observes the actual computation, so it is agnostic of the selected distance type. + if (pq.search_params.internal_distance_dtype == CUDA_R_16F || + pq.search_params.coarse_search_dtype == CUDA_R_16F) { + { + const bool fp16_overflow = cuvs::neighbors::ivf_pq::helpers::detect_fp16_overflow( + res, index, pq.search_params, dataset); + if (fp16_overflow) { + RAFT_LOG_WARN( + "IVF-PQ FP16 distance produced non-finite results on a probe search for this dataset -> " + "switching 'internal_distance_dtype' and 'coarse_search_dtype' to FP32"); + pq.search_params.internal_distance_dtype = CUDA_R_32F; + pq.search_params.coarse_search_dtype = CUDA_R_32F; + } + } + } + // // search top (k + 1) neighbors // @@ -2217,6 +2236,7 @@ index build( knn_build_params = cagra::graph_build_params::ivf_pq_params(dataset.extents(), params.metric); } } + RAFT_EXPECTS( params.metric != cuvs::distance::DistanceType::BitwiseHamming || std::holds_alternative( diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_fp16_overflow.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_fp16_overflow.cuh new file mode 100644 index 0000000000..4cd239af5d --- /dev/null +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_fp16_overflow.cuh @@ -0,0 +1,84 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../detail/ann_utils.cuh" // cuvs::spatial::knn::detail::utils::mapping + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace cuvs::neighbors::ivf_pq::helpers { + +/** + * @brief Detect whether FP16 internal distance dtypes overflow for this dataset during search. + * + * Runs a small probe search against an already-built IVF-PQ index with current distance types, + * and reports whether any returned distance is non-finite (inf/NaN). + */ +template +bool detect_fp16_overflow( + raft::resources const& handle, + const cuvs::neighbors::ivf_pq::index& index, + cuvs::neighbors::ivf_pq::search_params search_params, + raft::mdspan, raft::row_major, Accessor> dataset) +{ + const int64_t n_rows = dataset.extent(0); + if (n_rows == 0) { return false; } + + auto stream = raft::resource::get_cuda_stream(handle); + const int64_t dim = dataset.extent(1); + + constexpr int64_t kMaxSampleQueries = 128; + constexpr uint32_t kProbeTopK = 32; + const int64_t n_sample = std::min(n_rows, kMaxSampleQueries); + const uint32_t top_k = std::min(static_cast(n_rows), kProbeTopK); + + auto mr = raft::resource::get_workspace_resource_ref(handle); + auto queries = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_sample, dim)); + raft::copy(queries.data_handle(), dataset.data_handle(), n_sample * dim, stream); + + auto neighbors = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_sample, top_k)); + auto distances = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_sample, top_k)); + + cuvs::neighbors::ivf_pq::search(handle, + search_params, + index, + raft::make_const_mdspan(queries.view()), + neighbors.view(), + distances.view()); + + const int64_t count = n_sample * static_cast(top_k); + auto is_non_finite_op = [] __device__(float v) { return isnan(v) || isinf(v); }; + const bool any_non_finite = thrust::any_of(raft::resource::get_thrust_policy(handle), + distances.data_handle(), + distances.data_handle() + count, + is_non_finite_op); + raft::resource::sync_stream(handle); + return any_non_finite; +} + +} // namespace cuvs::neighbors::ivf_pq::helpers