From 41b0713736eb5a8062cc9ab7979ecf9aa933f83e Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 17 Jun 2026 20:58:30 +0000 Subject: [PATCH 01/10] add example --- conda/recipes/libcuvs/recipe.yaml | 5 + cpp/tests/CMakeLists.txt | 2 + cpp/tests/cutile/CMakeLists.txt | 23 ++++ cpp/tests/cutile/cutile_vector_add.cu | 128 ++++++++++++++++++ cpp/tests/cutile/export_vector_add_cubin.py | 101 ++++++++++++++ cpp/tests/cutile/generate_cutile_cubins.cmake | 90 ++++++++++++ cpp/tests/cutile/vector_add_kernel.py | 17 +++ dependencies.yaml | 3 + 8 files changed, 369 insertions(+) create mode 100644 cpp/tests/cutile/CMakeLists.txt create mode 100644 cpp/tests/cutile/cutile_vector_add.cu create mode 100644 cpp/tests/cutile/export_vector_add_cubin.py create mode 100644 cpp/tests/cutile/generate_cutile_cubins.cmake create mode 100644 cpp/tests/cutile/vector_add_kernel.py diff --git a/conda/recipes/libcuvs/recipe.yaml b/conda/recipes/libcuvs/recipe.yaml index aa7a37db44..93f31f8cf2 100644 --- a/conda/recipes/libcuvs/recipe.yaml +++ b/conda/recipes/libcuvs/recipe.yaml @@ -80,6 +80,7 @@ cache: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -117,6 +118,7 @@ outputs: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -179,6 +181,7 @@ outputs: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -240,6 +243,7 @@ outputs: - cuda-cudart-dev - cuda-nvrtc-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -299,6 +303,7 @@ outputs: - openblas # required by some CPU algos in benchmarks - cuda-cudart-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 9b96f94bf0..ba6ed6e0e7 100644 --- a/cpp/tests/CMakeLists.txt +++ 
b/cpp/tests/CMakeLists.txt @@ -386,6 +386,8 @@ ConfigureTest( PERCENT 100 ) +add_subdirectory(cutile) + # ################################################################################################## # Install tests #################################################################################### # ################################################################################################## diff --git a/cpp/tests/cutile/CMakeLists.txt b/cpp/tests/cutile/CMakeLists.txt new file mode 100644 index 0000000000..989c8137d0 --- /dev/null +++ b/cpp/tests/cutile/CMakeLists.txt @@ -0,0 +1,23 @@ +# ============================================================================= +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on +# ============================================================================= + +include("${CMAKE_CURRENT_LIST_DIR}/generate_cutile_cubins.cmake") + +generate_cutile_vector_add_cubins(CUTILE_GENERATED_INCLUDE_DIR) + +ConfigureTest( + NAME CUTILE_VECTOR_ADD_TEST + PATH "${CMAKE_CURRENT_LIST_DIR}/cutile_vector_add.cu" + GPUS 1 + PERCENT 100 +) + +add_dependencies(CUTILE_VECTOR_ADD_TEST cutile_vector_add_cubins) + +target_include_directories( + CUTILE_VECTOR_ADD_TEST PRIVATE "${CUTILE_GENERATED_INCLUDE_DIR}" +) diff --git a/cpp/tests/cutile/cutile_vector_add.cu b/cpp/tests/cutile/cutile_vector_add.cu new file mode 100644 index 0000000000..77a5e51311 --- /dev/null +++ b/cpp/tests/cutile/cutile_vector_add.cu @@ -0,0 +1,128 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../test_utils.cuh" + +#include "vector_add_kernel_symbol.h" +#include "vector_add_sm_100_cubin.h" +#include "vector_add_sm_120_cubin.h" +#include "vector_add_sm_80_cubin.h" +#include "vector_add_sm_86_cubin.h" +#include "vector_add_sm_90_cubin.h" + +#include + +#include + +namespace cuvs { +namespace { + +struct EmbeddedCubin { + int cc_major; + int cc_minor; + const unsigned char* data; + size_t size; +}; + +// Lookup table for cubins built at configure time (see export_vector_add_cubin.py). +constexpr EmbeddedCubin kEmbeddedCubins[] = { + {8, 0, vector_add_sm_80_cubin, sizeof(vector_add_sm_80_cubin)}, + {8, 6, vector_add_sm_86_cubin, sizeof(vector_add_sm_86_cubin)}, + {9, 0, vector_add_sm_90_cubin, sizeof(vector_add_sm_90_cubin)}, + {10, 0, vector_add_sm_100_cubin, sizeof(vector_add_sm_100_cubin)}, + {12, 0, vector_add_sm_120_cubin, sizeof(vector_add_sm_120_cubin)}, +}; + +const EmbeddedCubin* find_embedded_cubin(int cc_major, int cc_minor) +{ + for (const auto& entry : kEmbeddedCubins) { + if (entry.cc_major == cc_major && entry.cc_minor == cc_minor) { return &entry; } + } + // Fall back to a cubin for the same major version (e.g. minor SKUs within a generation). 
+ for (const auto& entry : kEmbeddedCubins) { + if (entry.cc_major == cc_major) { return &entry; } + } + return nullptr; +} + +class CutileVectorAddTest : public ::testing::Test { + protected: + void SetUp() override + { + int device = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY( + cudaDeviceGetAttribute(&cc_major_, cudaDevAttrComputeCapabilityMajor, device)); + RAFT_CUDA_TRY( + cudaDeviceGetAttribute(&cc_minor_, cudaDevAttrComputeCapabilityMinor, device)); + } + + int cc_major_{}; + int cc_minor_{}; +}; + +} // namespace + +TEST_F(CutileVectorAddTest, EmbeddedCubinVectorAdd) +{ + const EmbeddedCubin* cubin = find_embedded_cubin(cc_major_, cc_minor_); + ASSERT_NE(cubin, nullptr) + << "No embedded cuTile cubin for compute capability " << cc_major_ << "." << cc_minor_; + + cudaLibrary_t library{}; + ASSERT_EQ(cudaSuccess, + cudaLibraryLoadData( + &library, cubin->data, nullptr, nullptr, 0, nullptr, nullptr, 0)) + << "cudaLibraryLoadData failed: " << cudaGetErrorString(cudaGetLastError()); + + cudaKernel_t kernel{}; + ASSERT_EQ(cudaSuccess, + cudaLibraryGetKernel(&kernel, library, CUTILE_VECTOR_ADD_KERNEL_SYMBOL)) + << "cudaLibraryGetKernel failed: " << cudaGetErrorString(cudaGetLastError()); + + constexpr int kN = 1024; + constexpr int kTile = 256; + constexpr int kGridDim = (kN + kTile - 1) / kTile; + + float *d_a = nullptr, *d_b = nullptr, *d_c = nullptr; + RAFT_CUDA_TRY(cudaMalloc(&d_a, kN * sizeof(float))); + RAFT_CUDA_TRY(cudaMalloc(&d_b, kN * sizeof(float))); + RAFT_CUDA_TRY(cudaMalloc(&d_c, kN * sizeof(float))); + + std::vector h_a(kN), h_b(kN); + for (int i = 0; i < kN; ++i) { + h_a[i] = static_cast(i); + h_b[i] = static_cast(i * 2); + } + RAFT_CUDA_TRY(cudaMemcpy(d_a, h_a.data(), kN * sizeof(float), cudaMemcpyHostToDevice)); + RAFT_CUDA_TRY(cudaMemcpy(d_b, h_b.data(), kN * sizeof(float), cudaMemcpyHostToDevice)); + RAFT_CUDA_TRY(cudaMemset(d_c, 0, kN * sizeof(float))); + + int64_t shape = kN; + int64_t stride = 1; + void* kernel_args[] = { + 
&d_a, &shape, &stride, &d_b, &shape, &stride, &d_c, &shape, &stride, + }; + + dim3 grid(kGridDim); + dim3 block(1); + ASSERT_EQ(cudaSuccess, cudaLaunchKernel(kernel, grid, block, kernel_args, 0, 0)) + << "cudaLaunchKernel failed: " << cudaGetErrorString(cudaGetLastError()); + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + std::vector h_c(kN); + RAFT_CUDA_TRY(cudaMemcpy(h_c.data(), d_c, kN * sizeof(float), cudaMemcpyDeviceToHost)); + + for (int i = 0; i < kN; ++i) { + ASSERT_FLOAT_EQ(h_a[i] + h_b[i], h_c[i]) << "@" << i; + } + + RAFT_CUDA_TRY(cudaFree(d_a)); + RAFT_CUDA_TRY(cudaFree(d_b)); + RAFT_CUDA_TRY(cudaFree(d_c)); + RAFT_CUDA_TRY(cudaLibraryUnload(library)); +} + +} // namespace cuvs diff --git a/cpp/tests/cutile/export_vector_add_cubin.py b/cpp/tests/cutile/export_vector_add_cubin.py new file mode 100644 index 0000000000..bf40a4ad80 --- /dev/null +++ b/cpp/tests/cutile/export_vector_add_cubin.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +"""Export the cuTile vector-add kernel to a cubin for a single GPU target.""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +import cuda.tile as ct +from cuda.tile.compilation import ( + ArrayConstraint, + CallingConvention, + ConstantConstraint, + KernelSignature, + export_kernel, +) + +from vector_add_kernel import TILE_SIZE, vector_add + +# cuTile / tileiras gpu_code values used at build time. 
These correspond to the +# cuvs library CUDA 13 real targets as follows (tileiras has no sm_*a/sm_*f names): +# sm_80 -> 80-real +# sm_86 -> 86-real +# sm_90 -> 90a-real +# sm_100 -> 100f-real +# sm_120 -> 120a-real +SUPPORTED_GPU_CODES = ("sm_80", "sm_86", "sm_90", "sm_100", "sm_120") + + +def _kernel_signature() -> KernelSignature: + array = ArrayConstraint( + ct.float32, + 1, + index_dtype=ct.int64, + stride_lower_bound_incl=0, + alias_groups=(), + may_alias_internally=False, + stride_constant=(1,), + ) + return KernelSignature( + parameters=[array, array, array, ConstantConstraint(TILE_SIZE)], + calling_convention=CallingConvention.cutile_python_v1(), + ).with_mangled_symbol("vector_add") + + +def export_cubin(output_file: Path, gpu_code: str, symbol_header: Path | None) -> str: + if gpu_code not in SUPPORTED_GPU_CODES: + raise ValueError( + f"Unsupported gpu_code {gpu_code!r}; expected one of {SUPPORTED_GPU_CODES}" + ) + + signature = _kernel_signature() + export_kernel( + vector_add, + signatures=[signature], + output_file=str(output_file), + gpu_code=gpu_code, + output_format="cubin", + ) + + if symbol_header is not None: + symbol_header.write_text( + "\n".join( + [ + "// Generated by export_vector_add_cubin.py; do not edit.", + "#pragma once", + f'#define CUTILE_VECTOR_ADD_KERNEL_SYMBOL "{signature.symbol}"', + "", + ] + ) + ) + + return signature.symbol + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("output_file", type=Path, help="Output cubin path") + parser.add_argument( + "--gpu-code", + required=True, + choices=SUPPORTED_GPU_CODES, + help="tileiras / export_kernel target (e.g. 
sm_120)", + ) + parser.add_argument( + "--symbol-header", + type=Path, + default=None, + help="Optional header that defines CUTILE_VECTOR_ADD_KERNEL_SYMBOL", + ) + args = parser.parse_args() + + symbol = export_cubin(args.output_file, args.gpu_code, args.symbol_header) + print(symbol) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cpp/tests/cutile/generate_cutile_cubins.cmake b/cpp/tests/cutile/generate_cutile_cubins.cmake new file mode 100644 index 0000000000..3425b03028 --- /dev/null +++ b/cpp/tests/cutile/generate_cutile_cubins.cmake @@ -0,0 +1,90 @@ +# ============================================================================= +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on +# ============================================================================= + +include_guard(GLOBAL) + +# Build-time cuTile cubin targets. Maps to cuvs CUDA 13 -real library arches (75-real omitted). +set(CUTILE_VECTOR_ADD_GPU_CODES sm_80 sm_86 sm_90 sm_100 sm_120) + +function(generate_cutile_vector_add_cubins output_include_dir_var) + find_package(Python3 REQUIRED COMPONENTS Interpreter) + find_package(CUDAToolkit REQUIRED) + + find_program( + CUTILE_BIN2C + NAMES bin2c + PATHS ${CUDAToolkit_BIN_DIR} + REQUIRED + ) + + execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import cuda.tile" + RESULT_VARIABLE _cutile_import_result + OUTPUT_QUIET + ERROR_QUIET + ) + if(NOT _cutile_import_result EQUAL 0) + message( + FATAL_ERROR + "cuda.tile (cuTile Python) is required to build CUTILE_VECTOR_ADD_TEST. " + "Install it in the active Python environment, e.g. pip install cuda-tile[tileiras]." 
+ ) + endif() + + set(_cutile_source_dir "${CMAKE_CURRENT_FUNCTION_LIST_DIR}") + set(_cutile_binary_dir "${CMAKE_CURRENT_BINARY_DIR}/cutile_generated") + file(MAKE_DIRECTORY "${_cutile_binary_dir}") + + set(_symbol_header "${_cutile_binary_dir}/vector_add_kernel_symbol.h") + set(_first_gpu_code TRUE) + + foreach(_gpu_code IN LISTS CUTILE_VECTOR_ADD_GPU_CODES) + set(_cubin_file "${_cutile_binary_dir}/vector_add_${_gpu_code}.cubin") + set(_cubin_header "${_cutile_binary_dir}/vector_add_${_gpu_code}_cubin.h") + + if(_first_gpu_code) + set(_symbol_arg --symbol-header "${_symbol_header}") + set(_cubin_outputs "${_cubin_file}" "${_symbol_header}") + set(_first_gpu_code FALSE) + else() + set(_symbol_arg) + set(_cubin_outputs "${_cubin_file}") + endif() + + add_custom_command( + OUTPUT ${_cubin_outputs} + COMMAND + "${Python3_EXECUTABLE}" "${_cutile_source_dir}/export_vector_add_cubin.py" + "${_cubin_file}" --gpu-code "${_gpu_code}" ${_symbol_arg} + DEPENDS "${_cutile_source_dir}/export_vector_add_cubin.py" + "${_cutile_source_dir}/vector_add_kernel.py" + COMMENT "Exporting cuTile vector_add cubin for ${_gpu_code}" + VERBATIM + ) + + add_custom_command( + OUTPUT "${_cubin_header}" + COMMAND "${CUTILE_BIN2C}" --const --name "vector_add_${_gpu_code}_cubin" --static + "${_cubin_file}" > "${_cubin_header}" + DEPENDS "${_cubin_file}" + COMMENT "Embedding vector_add ${_gpu_code} cubin via bin2c" + VERBATIM + ) + + list(APPEND _generated_headers "${_cubin_header}") + endforeach() + + add_custom_target( + cutile_vector_add_cubins + DEPENDS "${_symbol_header}" ${_generated_headers} + ) + + set(${output_include_dir_var} + "${_cutile_binary_dir}" + PARENT_SCOPE + ) +endfunction() diff --git a/cpp/tests/cutile/vector_add_kernel.py b/cpp/tests/cutile/vector_add_kernel.py new file mode 100644 index 0000000000..46b7a607c6 --- /dev/null +++ b/cpp/tests/cutile/vector_add_kernel.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 +"""cuTile Python vector-add kernel used by the embedded-cubin example test.""" + +from __future__ import annotations + +import cuda.tile as ct + +TILE_SIZE = 256 + + +@ct.kernel +def vector_add(a, b, c, TILE_SIZE: ct.Constant): + bid = ct.bid(0) + ta = ct.load(a, bid, TILE_SIZE) + tb = ct.load(b, bid, TILE_SIZE) + ct.store(c, bid, ta + tb) diff --git a/dependencies.yaml b/dependencies.yaml index 744e4d9227..756041e60c 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -395,6 +395,7 @@ dependencies: - cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api + - cutile-python - libcublas-dev - libcurand-dev - libcusolver-dev @@ -430,12 +431,14 @@ dependencies: packages: - &ctk_cu13 cuda-toolkit[cublas,curand,cusolver,cusparse,nvrtc]==13.* - &nvjitlink_cu13 nvidia-nvjitlink>=13.0,<14 + - &cutile_cu13 cuda-tile[tileiras] # if no matching matrix selectors passed, list the CUDA 13 requirement # (just as a source of documentation, as this populates pyproject.toml in source control) - matrix: packages: - *ctk_cu13 - *nvjitlink_cu13 + - *cutile_cu13 depends_on_cudart: common: - output_types: conda From b10c02ca5a5ef094ca892c6d32f6f14d6d63447f Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 24 Jun 2026 16:34:11 +0000 Subject: [PATCH 02/10] initial integration --- cpp/CMakeLists.txt | 50 ++- .../modules/generate_cutile_kernels.cmake | 315 ++++++++++++++++++ cpp/cmake/modules/register_cubin.cpp.in | 22 ++ cpp/cmake/modules/register_tileir.cpp.in | 22 ++ .../cuvs/detail/jit_lto/AlgorithmPlanner.hpp | 63 +++- .../cuvs/detail/jit_lto/FragmentEntry.hpp | 63 ++++ .../cuvs/detail/jit_lto/cutile_arch_tags.hpp | 52 +++ .../cuvs/detail/jit_lto/cutile_module.hpp | 75 +++++ .../fused_distance_nn/fused_1nn_fragments.hpp | 21 ++ .../cuvs/detail/jit_lto/tileir_compat.hpp | 99 ++++++ cpp/src/detail/jit_lto/AlgorithmPlanner.cpp | 103 ++---- .../detail/jit_lto/LTOAlgorithmPlanner.cpp | 76 +++++ .../detail/jit_lto/TileAlgorithmPlanner.cpp | 38 +++ 
cpp/src/distance/detail/fused_distance_nn.cuh | 15 + .../cutile/export_fused_1nn.py | 136 ++++++++ .../cutile/fused_1nn_cutile_cubin_matrix.json | 40 +++ .../fused_1nn_cutile_tileir_matrix.json | 20 ++ .../cutile/fused_1nn_kernel.py | 68 ++++ .../cutile/fused_1nn_planner.hpp | 60 ++++ .../cutile/fused_1nn_tile.cu | 173 ++++++++++ .../cutile/fused_1nn_tile.hpp | 55 +++ .../pairwise_matrix_planner.hpp | 4 +- .../jit_lto_kernels/cagra_planner_base.hpp | 4 +- .../interleaved_scan_planner.hpp | 4 +- .../compute_similarity_planner.hpp | 4 +- .../detail/jit_lto_kernels/scan_planner.hpp | 4 +- cpp/tests/CMakeLists.txt | 3 +- cpp/tests/cutile/cutile_vector_add.cu | 176 ++++++++-- cpp/tests/cutile/export_vector_add_cubin.py | 58 +++- cpp/tests/cutile/generate_cutile_cubins.cmake | 27 ++ cpp/tests/neighbors/distance_nn.cu | 1 + cpp/tests/neighbors/distance_nn_helper.cuh | 45 ++- 32 files changed, 1743 insertions(+), 153 deletions(-) create mode 100644 cpp/cmake/modules/generate_cutile_kernels.cmake create mode 100644 cpp/cmake/modules/register_cubin.cpp.in create mode 100644 cpp/cmake/modules/register_tileir.cpp.in create mode 100644 cpp/include/cuvs/detail/jit_lto/cutile_arch_tags.hpp create mode 100644 cpp/include/cuvs/detail/jit_lto/cutile_module.hpp create mode 100644 cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp create mode 100644 cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp create mode 100644 cpp/src/detail/jit_lto/LTOAlgorithmPlanner.cpp create mode 100644 cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp create mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py create mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json create mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json create mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py create mode 100644 
cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp create mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu create mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 227c2906cc..cc6e1975b3 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -957,6 +957,47 @@ if(NOT BUILD_CPU_ONLY) OUTPUT_FILE_FORMAT "${CMAKE_CURRENT_BINARY_DIR}/src/distance/detail/pairwise_matrix/dispatch_rbf_inst_data_@data_abbrev@_acc_@acc_abbrev@_out_@out_abbrev@_index_@index_abbrev@_op_@op_abbrev@.cu" ) + + include(cmake/modules/generate_cutile_kernels.cmake) + set(fused_1nn_cutile_dir + "${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/fused_distance_nn/cutile") + set(cutile_fused_1nn_generated_dir + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/fused_1nn/cutile") + generate_cutile_cubin_kernels( + cutile_fused_1nn_files + KERNEL_DIR "${fused_1nn_cutile_dir}" + KERNEL_BASENAME "fused_1nn" + KERNEL_PYTHON "fused_1nn_kernel.py" + EXPORT_SCRIPT "export_fused_1nn.py" + OUTPUT_DIRECTORY "${cutile_fused_1nn_generated_dir}" + MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_cubin_matrix.json" + FRAGMENT_TAG_FORMAT + "cuvs::distance::detail::fragment_tag_fused_1nn_cubin" + FRAGMENT_TAG_HEADER_FILES + "" + "" + "" + ) + generate_cutile_tileir_kernels( + cutile_fused_1nn_files + KERNEL_DIR "${fused_1nn_cutile_dir}" + KERNEL_BASENAME "fused_1nn" + KERNEL_PYTHON "fused_1nn_kernel.py" + EXPORT_SCRIPT "export_fused_1nn.py" + OUTPUT_DIRECTORY "${cutile_fused_1nn_generated_dir}" + MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_tileir_matrix.json" + FRAGMENT_TAG_FORMAT + "cuvs::distance::detail::fragment_tag_fused_1nn_tileir" + FRAGMENT_TAG_HEADER_FILES + "" + "" + ) + if(NOT DEFINED CUVS_CUTILE_ENABLED) + set(CUVS_CUTILE_ENABLED 0) + endif() + target_compile_definitions( + cuvs_cpp_headers INTERFACE 
CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED} + ) generate_inst_matrix( cagra_build_inst_files MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/cagra_build_matrix.json" @@ -1147,6 +1188,8 @@ if(NOT BUILD_CPU_ONLY) src/util/host_memory.cpp src/detail/jit_lto/AlgorithmLauncher.cpp src/detail/jit_lto/AlgorithmPlanner.cpp + src/detail/jit_lto/LTOAlgorithmPlanner.cpp + src/detail/jit_lto/TileAlgorithmPlanner.cpp src/detail/jit_lto/FragmentEntry.cpp src/detail/jit_lto/nvjitlink_checker.cpp src/detail/jit_lto/NVRTCLTOFragmentCompiler.cpp @@ -1234,6 +1277,8 @@ if(NOT BUILD_CPU_ONLY) src/stats/trustworthiness_score.cu ${CUVS_MG_ALGOS} ${jit_lto_files} + ${cutile_fused_1nn_files} + $<$:src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu> ) set_target_properties( @@ -1257,6 +1302,7 @@ if(NOT BUILD_CPU_ONLY) target_compile_definitions( cuvs_objs PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> $<$:NVTX_ENABLED> + CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED} ) target_link_libraries( @@ -1274,7 +1320,9 @@ if(NOT BUILD_CPU_ONLY) PUBLIC "$" "$" INTERFACE "$" - PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src" "${CMAKE_CURRENT_BINARY_DIR}/src" + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src" + "${CMAKE_CURRENT_BINARY_DIR}/src" + "${cutile_fused_1nn_generated_dir}" ) # Endian detection diff --git a/cpp/cmake/modules/generate_cutile_kernels.cmake b/cpp/cmake/modules/generate_cutile_kernels.cmake new file mode 100644 index 0000000000..7b9c2521c4 --- /dev/null +++ b/cpp/cmake/modules/generate_cutile_kernels.cmake @@ -0,0 +1,315 @@ +# ============================================================================= +# cmake-format: off +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+# SPDX-License-Identifier: Apache-2.0 +# cmake-format: on +# ============================================================================= + +include_guard(GLOBAL) + +include(${CMAKE_CURRENT_LIST_DIR}/compute_matrix_product.cmake) + +function(generate_cutile_kernels_stub) + set(CUVS_CUTILE_ENABLED 0 PARENT_SCOPE) +endfunction() + +function(_cutile_fragment_tag_header_files output_var) + set(${output_var} "") + foreach(_header IN LISTS ARGN) + if(NOT _header MATCHES "^(\".*\"|<.*>)$") + set(_header "\"${_header}\"") + endif() + string(APPEND ${output_var} "#include ${_header}\n") + endforeach() + set(${output_var} + "${${output_var}}" + PARENT_SCOPE + ) +endfunction() + +function(_cutile_kernels_setup) + set(options) + set(one_value MATRIX_JSON_FILE OUTPUT_DIRECTORY) + set(multi_value) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + find_package(Python3 REQUIRED COMPONENTS Interpreter) + find_package(CUDAToolkit REQUIRED) + + if(CUDAToolkit_VERSION VERSION_LESS 13.0) + message( + STATUS + "cuTile embedded kernels require CUDA 13.0+; skipping cuTile generation (found ${CUDAToolkit_VERSION})." + ) + set(_CUTILE_SETUP_OK + FALSE + PARENT_SCOPE + ) + return() + endif() + + find_program( + CUTILE_BIN2C + NAMES bin2c + PATHS ${CUDAToolkit_BIN_DIR} + REQUIRED + ) + + execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import cuda.tile" + RESULT_VARIABLE _cutile_import_result + OUTPUT_QUIET + ERROR_QUIET + ) + if(NOT _cutile_import_result EQUAL 0) + message( + FATAL_ERROR + "cuda.tile (cuTile Python) is required to build cuTile embedded kernels. " + "Install it in the active Python environment, e.g. pip install cuda-tile[tileiras]." 
+ ) + endif() + + set_property( + DIRECTORY + PROPERTY CMAKE_CONFIGURE_DEPENDS "${_CUTILE_MATRIX_JSON_FILE}" + APPEND + ) + + file(MAKE_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}") + + set(_CUTILE_SETUP_OK + TRUE + PARENT_SCOPE + ) +endfunction() + +function(process_cutile_cubin_matrix_entry source_list_var) + set(options) + set(one_value + KERNEL_DIR + KERNEL_BASENAME + KERNEL_PYTHON + EXPORT_SCRIPT + OUTPUT_DIRECTORY + FRAGMENT_TAG_FORMAT + MATRIX_JSON_ENTRY + ) + set(multi_value FRAGMENT_TAG_HEADER_FILES) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + populate_matrix_variables("${_CUTILE_MATRIX_JSON_ENTRY}") + _cutile_fragment_tag_header_files( + fragment_tag_header_files ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} + ) + + string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT}" fragment_tag @ONLY) + + set(_artifact_basename "${_CUTILE_KERNEL_BASENAME}_${data_type}_${gpu_code}") + set(_cubin_file "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_basename}.cubin") + set(_cubin_header "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_basename}_cubin.h") + set(_cubin_cpp "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_basename}_cubin.cpp") + set(cubin_header_file "${_artifact_basename}_cubin.h") + + add_custom_command( + OUTPUT "${_cubin_file}" + COMMAND + "${Python3_EXECUTABLE}" "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" "${_cubin_file}" + --format cubin --data-type "${data_type}" --gpu-code "${gpu_code}" + DEPENDS "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" + "${_CUTILE_KERNEL_DIR}/${_CUTILE_KERNEL_PYTHON}" + COMMENT "Exporting cuTile ${_CUTILE_KERNEL_BASENAME} cubin ${data_type} ${gpu_code}" + VERBATIM + ) + + add_custom_command( + OUTPUT "${_cubin_header}" + COMMAND "${CUTILE_BIN2C}" --const --name embedded_cubin --static "${_cubin_file}" + > "${_cubin_header}" + DEPENDS "${_cubin_file}" + VERBATIM + ) + + configure_file( + "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/register_cubin.cpp.in" "${_cubin_cpp}" @ONLY + ) + list(APPEND ${source_list_var} 
"${_cubin_header}" "${_cubin_cpp}") + set(${source_list_var} + "${${source_list_var}}" + PARENT_SCOPE + ) +endfunction() + +function(process_cutile_tileir_matrix_entry source_list_var) + set(options) + set(one_value + KERNEL_DIR + KERNEL_BASENAME + KERNEL_PYTHON + EXPORT_SCRIPT + OUTPUT_DIRECTORY + FRAGMENT_TAG_FORMAT + MATRIX_JSON_ENTRY + ) + set(multi_value FRAGMENT_TAG_HEADER_FILES) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + populate_matrix_variables("${_CUTILE_MATRIX_JSON_ENTRY}") + _cutile_fragment_tag_header_files( + fragment_tag_header_files ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} + ) + + string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT}" fragment_tag @ONLY) + set(_tileir_file "${_CUTILE_OUTPUT_DIRECTORY}/${_CUTILE_KERNEL_BASENAME}_${data_type}.tilebc") + set(_tileir_header "${_CUTILE_OUTPUT_DIRECTORY}/${_CUTILE_KERNEL_BASENAME}_${data_type}_tileir.h") + set(_tileir_cpp "${_CUTILE_OUTPUT_DIRECTORY}/${_CUTILE_KERNEL_BASENAME}_${data_type}_tileir.cpp") + set(tileir_header_file "${_CUTILE_KERNEL_BASENAME}_${data_type}_tileir.h") + + add_custom_command( + OUTPUT "${_tileir_file}" + COMMAND + "${Python3_EXECUTABLE}" "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" "${_tileir_file}" + --format tileir_bytecode --data-type "${data_type}" --gpu-code "${export_gpu_code}" + --bytecode-version "${bytecode_version}" + DEPENDS "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" + "${_CUTILE_KERNEL_DIR}/${_CUTILE_KERNEL_PYTHON}" + COMMENT "Exporting cuTile ${_CUTILE_KERNEL_BASENAME} TileIR bytecode ${data_type}" + VERBATIM + ) + + add_custom_command( + OUTPUT "${_tileir_header}" + COMMAND "${CUTILE_BIN2C}" --const --name embedded_tileir --static "${_tileir_file}" + > "${_tileir_header}" + DEPENDS "${_tileir_file}" + VERBATIM + ) + + configure_file( + "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/register_tileir.cpp.in" "${_tileir_cpp}" @ONLY + ) + list(APPEND ${source_list_var} "${_tileir_header}" "${_tileir_cpp}") + set(${source_list_var} + 
"${${source_list_var}}" + PARENT_SCOPE + ) +endfunction() + +function(generate_cutile_cubin_kernels source_list_var) + set(options) + set(one_value + KERNEL_DIR + KERNEL_BASENAME + KERNEL_PYTHON + EXPORT_SCRIPT + OUTPUT_DIRECTORY + MATRIX_JSON_FILE + FRAGMENT_TAG_FORMAT + ) + set(multi_value FRAGMENT_TAG_HEADER_FILES) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + if(NOT _CUTILE_KERNEL_BASENAME) + message(FATAL_ERROR "generate_cutile_cubin_kernels: KERNEL_BASENAME is required") + endif() + if(NOT _CUTILE_KERNEL_PYTHON) + set(_CUTILE_KERNEL_PYTHON "fused_1nn_kernel.py") + endif() + + _cutile_kernels_setup( + MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}" + OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" + ) + if(NOT _CUTILE_SETUP_OK) + generate_cutile_kernels_stub() + set(${source_list_var} + "" + PARENT_SCOPE + ) + return() + endif() + + compute_matrix_product(matrix_product MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}") + + string(JSON len LENGTH "${matrix_product}") + math(EXPR last "${len} - 1") + + # cmake-lint: disable=C0103,E1120 + foreach(i RANGE "${last}") + string(JSON matrix_json_entry GET "${matrix_product}" "${i}") + process_cutile_cubin_matrix_entry( + "${source_list_var}" + KERNEL_DIR "${_CUTILE_KERNEL_DIR}" + KERNEL_BASENAME "${_CUTILE_KERNEL_BASENAME}" + KERNEL_PYTHON "${_CUTILE_KERNEL_PYTHON}" + EXPORT_SCRIPT "${_CUTILE_EXPORT_SCRIPT}" + OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" + FRAGMENT_TAG_FORMAT "${_CUTILE_FRAGMENT_TAG_FORMAT}" + FRAGMENT_TAG_HEADER_FILES ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} + MATRIX_JSON_ENTRY "${matrix_json_entry}" + ) + endforeach() + + set(CUVS_CUTILE_ENABLED 1 PARENT_SCOPE) + set(${source_list_var} + "${${source_list_var}}" + PARENT_SCOPE + ) +endfunction() + +function(generate_cutile_tileir_kernels source_list_var) + set(options) + set(one_value + KERNEL_DIR + KERNEL_BASENAME + KERNEL_PYTHON + EXPORT_SCRIPT + OUTPUT_DIRECTORY + MATRIX_JSON_FILE + FRAGMENT_TAG_FORMAT + ) + 
set(multi_value FRAGMENT_TAG_HEADER_FILES) + cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + + if(NOT _CUTILE_KERNEL_BASENAME) + message(FATAL_ERROR "generate_cutile_tileir_kernels: KERNEL_BASENAME is required") + endif() + if(NOT _CUTILE_KERNEL_PYTHON) + set(_CUTILE_KERNEL_PYTHON "fused_1nn_kernel.py") + endif() + + _cutile_kernels_setup( + MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}" + OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" + ) + if(NOT _CUTILE_SETUP_OK) + generate_cutile_kernels_stub() + return() + endif() + + compute_matrix_product(matrix_product MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}") + + string(JSON len LENGTH "${matrix_product}") + math(EXPR last "${len} - 1") + + # cmake-lint: disable=C0103,E1120 + foreach(i RANGE "${last}") + string(JSON matrix_json_entry GET "${matrix_product}" "${i}") + process_cutile_tileir_matrix_entry( + "${source_list_var}" + KERNEL_DIR "${_CUTILE_KERNEL_DIR}" + KERNEL_BASENAME "${_CUTILE_KERNEL_BASENAME}" + KERNEL_PYTHON "${_CUTILE_KERNEL_PYTHON}" + EXPORT_SCRIPT "${_CUTILE_EXPORT_SCRIPT}" + OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" + FRAGMENT_TAG_FORMAT "${_CUTILE_FRAGMENT_TAG_FORMAT}" + FRAGMENT_TAG_HEADER_FILES ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} + MATRIX_JSON_ENTRY "${matrix_json_entry}" + ) + endforeach() + + set(CUVS_CUTILE_ENABLED 1 PARENT_SCOPE) + set(${source_list_var} + "${${source_list_var}}" + PARENT_SCOPE + ) +endfunction() diff --git a/cpp/cmake/modules/register_cubin.cpp.in b/cpp/cmake/modules/register_cubin.cpp.in new file mode 100644 index 0000000000..c27d6829ee --- /dev/null +++ b/cpp/cmake/modules/register_cubin.cpp.in @@ -0,0 +1,22 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "@cubin_header_file@" +#include + +@fragment_tag_header_files@ + +namespace { + +using fragment_tag = @fragment_tag@; +using fragment_entry = StaticCubinFragmentEntry; + +} // namespace + +template <> +const uint8_t* const fragment_entry::data = embedded_cubin; + +template <> +const size_t fragment_entry::length = sizeof(embedded_cubin); diff --git a/cpp/cmake/modules/register_tileir.cpp.in b/cpp/cmake/modules/register_tileir.cpp.in new file mode 100644 index 0000000000..fb81acedbc --- /dev/null +++ b/cpp/cmake/modules/register_tileir.cpp.in @@ -0,0 +1,22 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "@tileir_header_file@" +#include + +@fragment_tag_header_files@ + +namespace { + +using fragment_tag = @fragment_tag@; +using fragment_entry = StaticTileIrBytecodeFragmentEntry; + +} // namespace + +template <> +const uint8_t* const fragment_entry::data = embedded_tileir; + +template <> +const size_t fragment_entry::length = sizeof(embedded_tileir); diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp index 7f275b1285..d727c73b9d 100644 --- a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -19,6 +20,7 @@ struct LauncherJitCache { std::shared_mutex mutex; std::unordered_map> launchers; + std::unordered_set build_failed; }; struct AlgorithmPlanner { @@ -27,9 +29,32 @@ struct AlgorithmPlanner { { } + virtual ~AlgorithmPlanner() = default; + std::shared_ptr get_launcher(); + /** Returns nullptr when no module can be loaded for the current device (does not RAFT_FAIL). 
*/ + std::shared_ptr try_get_launcher(); + std::string entrypoint; + + protected: + virtual std::shared_ptr build() = 0; + + virtual std::string get_planner_key() const = 0; + + std::shared_ptr read_cache(std::string const& launch_key) const; + + LauncherJitCache& jit_cache_; +}; + +/** Links embedded LTO fatbin fragments at runtime via nvJitLink. */ +struct LTOAlgorithmPlanner : AlgorithmPlanner { + LTOAlgorithmPlanner(std::string entrypoint, LauncherJitCache& jit_cache) + : AlgorithmPlanner(std::move(entrypoint), jit_cache) + { + } + std::vector> fragments; template >> @@ -45,16 +70,38 @@ struct AlgorithmPlanner { } protected: - /** Extra link-time option strings passed to nvJitLink. Base build() - * always passes "-lto" and "-arch=sm_XX" first; derived planners may append here in their - * constructor body. */ + /** Extra link-time option strings passed to nvJitLink. */ std::vector linktime_extra_options; - private: - std::string get_fragments_key() const; - std::shared_ptr build(); + std::string get_planner_key() const override; - std::shared_ptr read_cache(std::string const& launch_key) const; + std::shared_ptr build() override; +}; - LauncherJitCache& jit_cache_; +/** Loads prebuilt cubins or TileIR bytecode via cudaLibraryLoadData. 
*/ +struct TileAlgorithmPlanner : AlgorithmPlanner { + TileAlgorithmPlanner(std::string entrypoint, LauncherJitCache& jit_cache) + : AlgorithmPlanner(std::move(entrypoint), jit_cache) + { + } + + template + void add_static_fragment() + { + cubin_fragments_.push_back(std::make_unique>()); + } + + template + void add_static_tileir_fragment() + { + tileir_fragment_ = std::make_unique>(); + } + + protected: + std::vector> cubin_fragments_; + std::unique_ptr tileir_fragment_; + + std::string get_planner_key() const override; + + std::shared_ptr build() override; }; diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp index 35aa46633c..df69ec1d7b 100644 --- a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp +++ b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp @@ -62,3 +62,66 @@ struct UDFFatbinFragment final : FatbinFragmentEntry { std::string key_; std::vector bytes_; }; + +/** Embedded CUDA binary module (cubin), loaded directly via cudaLibraryLoadData. */ +struct CubinFragmentEntry { + virtual ~CubinFragmentEntry() = default; + + virtual const uint8_t* get_data() const = 0; + + virtual size_t get_length() const = 0; + + virtual const char* get_key() const = 0; + + virtual int get_cc_major() const = 0; + + virtual int get_cc_minor() const = 0; +}; + +template +struct StaticCubinFragmentEntry final : CubinFragmentEntry { + const uint8_t* get_data() const override { return StaticCubinFragmentEntry::data; } + + size_t get_length() const override { return StaticCubinFragmentEntry::length; } + + const char* get_key() const override + { + return typeid(StaticCubinFragmentEntry).name(); + } + + int get_cc_major() const override { return FragmentTag::cc_major; } + + int get_cc_minor() const override { return FragmentTag::cc_minor; } + + static const uint8_t* const data; + static const size_t length; +}; + +/** Embedded TileIR bytecode, JIT-compiled by the driver when no matching cubin exists. 
*/ +struct TileIrBytecodeFragmentEntry { + virtual ~TileIrBytecodeFragmentEntry() = default; + + virtual const uint8_t* get_data() const = 0; + + virtual size_t get_length() const = 0; + + virtual const char* get_key() const = 0; +}; + +template +struct StaticTileIrBytecodeFragmentEntry final : TileIrBytecodeFragmentEntry { + const uint8_t* get_data() const override + { + return StaticTileIrBytecodeFragmentEntry::data; + } + + size_t get_length() const override { return StaticTileIrBytecodeFragmentEntry::length; } + + const char* get_key() const override + { + return typeid(StaticTileIrBytecodeFragmentEntry).name(); + } + + static const uint8_t* const data; + static const size_t length; +}; diff --git a/cpp/include/cuvs/detail/jit_lto/cutile_arch_tags.hpp b/cpp/include/cuvs/detail/jit_lto/cutile_arch_tags.hpp new file mode 100644 index 0000000000..2c915a278b --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/cutile_arch_tags.hpp @@ -0,0 +1,52 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_CUTILE_ENABLED +#define CUVS_CUTILE_ENABLED 0 +#endif + +namespace cuvs::detail::jit_lto { + +#if CUVS_CUTILE_ENABLED + +/** Must stay in sync with cuTile matrix _arch entries and planner add_static_fragment calls. 
*/ +struct cutile_arch_8_0 { + static constexpr int cc_major = 8; + static constexpr int cc_minor = 0; +}; + +struct cutile_arch_8_6 { + static constexpr int cc_major = 8; + static constexpr int cc_minor = 6; +}; + +struct cutile_arch_9_0 { + static constexpr int cc_major = 9; + static constexpr int cc_minor = 0; +}; + +struct cutile_arch_12_0 { + static constexpr int cc_major = 12; + static constexpr int cc_minor = 0; +}; + +inline bool is_embedded_cubin_arch(int cc_major, int cc_minor) +{ + if (cc_major == 8 && cc_minor == 0) { return true; } + if (cc_major == 8 && cc_minor == 6) { return true; } + if (cc_major == 9 && cc_minor == 0) { return true; } + if (cc_major == 12 && cc_minor == 0) { return true; } + return false; +} + +#else + +inline bool is_embedded_cubin_arch(int, int) { return false; } + +#endif + +} // namespace cuvs::detail::jit_lto diff --git a/cpp/include/cuvs/detail/jit_lto/cutile_module.hpp b/cpp/include/cuvs/detail/jit_lto/cutile_module.hpp new file mode 100644 index 0000000000..dff0f472a7 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/cutile_module.hpp @@ -0,0 +1,75 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace cuvs::detail::jit_lto { + +struct CutileModuleImage { + const uint8_t* data; + size_t size; +}; + +inline bool get_device_compute_capability(int& cc_major, int& cc_minor) +{ + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { return false; } + if (cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device) != cudaSuccess) { + return false; + } + if (cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device) != cudaSuccess) { + return false; + } + return true; +} + +/** Selects a prebuilt cubin for the device CC, or embedded TileIR when the driver can JIT it. 
*/ +inline std::optional resolve_cutile_module_image( + int cc_major, + int cc_minor, + int driver_version, + const std::vector>& cubin_fragments, + const TileIrBytecodeFragmentEntry* tileir_fragment) +{ + for (const auto& fragment : cubin_fragments) { + if (fragment->get_cc_major() == cc_major && fragment->get_cc_minor() == cc_minor) { + return CutileModuleImage{fragment->get_data(), fragment->get_length()}; + } + } + if (tileir_fragment != nullptr && tileir_fallback_available(driver_version)) { + return CutileModuleImage{tileir_fragment->get_data(), tileir_fragment->get_length()}; + } + return std::nullopt; +} + +inline std::shared_ptr load_cutile_launcher(const CutileModuleImage& image, + const std::string& kernel_symbol) +{ + cudaLibrary_t library{}; + RAFT_CUDA_TRY( + cudaLibraryLoadData(&library, image.data, nullptr, nullptr, 0, nullptr, nullptr, 0)); + + cudaKernel_t kernel{}; + RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, kernel_symbol.c_str())); + + return std::make_shared(kernel, library); +} + +} // namespace cuvs::detail::jit_lto diff --git a/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp new file mode 100644 index 0000000000..517118bbe2 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp @@ -0,0 +1,21 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::distance::detail { + +template +struct fragment_tag_fused_1nn_cubin { + static constexpr int cc_major = ArchTag::cc_major; + static constexpr int cc_minor = ArchTag::cc_minor; +}; + +template +struct fragment_tag_fused_1nn_tileir {}; + +} // namespace cuvs::distance::detail diff --git a/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp new file mode 100644 index 0000000000..d63759fb36 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp @@ -0,0 +1,99 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_CUTILE_ENABLED +#define CUVS_CUTILE_ENABLED 0 +#endif + +#include +#include + +#include + +namespace cuvs::detail::jit_lto { + +/** Minimum CUDA driver version (from cudaDriverGetVersion) for TileIR JIT of embedded bytecode. */ +inline constexpr int kMinTileIrJitDriverVersion = 13010; // CUDA 13.1 / driver >= 590.44 + +/** Minimum CUDA runtime version (from cudaRuntimeGetVersion) for cuTile integration. */ +inline constexpr int kMinCutileRuntimeVersion = 13000; + +inline constexpr bool library_built_with_cutile() +{ +#if CUVS_CUTILE_ENABLED + return true; +#else + return false; +#endif +} + +inline bool runtime_cuda13_or_newer() +{ + int runtime_version = 0; + if (cudaRuntimeGetVersion(&runtime_version) != cudaSuccess) { return false; } + return runtime_version >= kMinCutileRuntimeVersion; +} + +/** True when this build embeds cuTile artifacts and the runtime is CUDA 13+. */ +inline bool cutile_integration_enabled() +{ + return library_built_with_cutile() && runtime_cuda13_or_newer(); +} + +/** True when this build embeds a prebuilt cubin for the given compute capability. 
*/ +inline bool has_embedded_cubin_for_arch(int cc_major, int cc_minor) +{ + return is_embedded_cubin_arch(cc_major, cc_minor); +} + +/** True when the driver can JIT-compile embedded TileIR bytecode at load time. */ +inline bool tileir_fallback_available(int driver_version) +{ + return driver_version >= kMinTileIrJitDriverVersion; +} + +/** + * True when a cuTile launch may be attempted for the given device: cuTile is enabled, the runtime + * is CUDA 13+, and either a matching embedded cubin exists (no driver JIT required) or the driver + * can JIT the embedded TileIR bytecode fallback. + */ +inline bool cutile_launch_available_for_arch(int cc_major, int cc_minor, int driver_version) +{ + if (!cutile_integration_enabled()) { return false; } + if (has_embedded_cubin_for_arch(cc_major, cc_minor)) { return true; } + return tileir_fallback_available(driver_version); +} + +inline bool query_driver_version(int& driver_version) +{ + return cudaDriverGetVersion(&driver_version) == cudaSuccess; +} + +inline bool query_current_device_arch(int& cc_major, int& cc_minor) +{ + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess) { return false; } + if (cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device) != cudaSuccess) { + return false; + } + if (cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device) != cudaSuccess) { + return false; + } + return true; +} + +inline bool cutile_launch_available_on_current_device() +{ + int cc_major = 0; + int cc_minor = 0; + int driver_version = 0; + if (!query_current_device_arch(cc_major, cc_minor)) { return false; } + if (!query_driver_version(driver_version)) { return false; } + return cutile_launch_available_for_arch(cc_major, cc_minor, driver_version); +} + +} // namespace cuvs::detail::jit_lto diff --git a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp index 7416ea396d..486d6f1aa5 100644 --- a/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp +++ 
b/cpp/src/detail/jit_lto/AlgorithmPlanner.cpp @@ -3,33 +3,16 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include -#include #include #include -#include #include #include -#include #include -#include - -#include "cuda_runtime.h" -#include "nvJitLink.h" #include #include -std::string AlgorithmPlanner::get_fragments_key() const -{ - std::string key = ""; - for (const auto& fragment : this->fragments) { - key += fragment->get_key(); - } - return key; -} - std::shared_ptr AlgorithmPlanner::read_cache(std::string const& launch_key) const { auto& launchers = jit_cache_.launchers; @@ -38,79 +21,37 @@ std::shared_ptr AlgorithmPlanner::read_cache(std::string cons return nullptr; } -std::shared_ptr AlgorithmPlanner::get_launcher() +std::shared_ptr AlgorithmPlanner::try_get_launcher() { - auto& launchers = jit_cache_.launchers; - auto launch_key = this->get_fragments_key(); + auto launch_key = this->get_planner_key(); - if (auto hit = read_cache(launch_key)) { return hit; } + { + std::shared_lock read_lock(jit_cache_.mutex); + if (jit_cache_.build_failed.count(launch_key)) { return nullptr; } + if (auto hit = read_cache(launch_key)) { return hit; } + } std::unique_lock write_lock(jit_cache_.mutex); - if (auto it = launchers.find(launch_key); it != launchers.end()) { return it->second; } + if (jit_cache_.build_failed.count(launch_key)) { return nullptr; } + if (auto it = jit_cache_.launchers.find(launch_key); it != jit_cache_.launchers.end()) { + return it->second; + } - std::string log_message = - "JIT compiling launcher for kernel: " + this->entrypoint + " and device functions: "; - for (const auto& fragment : this->fragments) { - log_message += std::string{fragment->get_key()} + ","; + RAFT_LOG_DEBUG("Building launcher for kernel entrypoint: %s", this->entrypoint.c_str()); + auto launcher = this->build(); + if (!launcher) { + jit_cache_.build_failed.insert(launch_key); + return nullptr; } - log_message.pop_back(); - RAFT_LOG_DEBUG("%s", log_message.c_str()); - auto launcher 
= this->build(); - launchers[launch_key] = launcher; + jit_cache_.launchers[launch_key] = launcher; return launcher; } -std::shared_ptr AlgorithmPlanner::build() +std::shared_ptr AlgorithmPlanner::get_launcher() { - int device = 0; - int major = 0; - int minor = 0; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - - std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor)); - - // Load the generated LTO IR and link them together - nvJitLinkHandle handle; - std::vector lopts; - lopts.reserve(2 + linktime_extra_options.size()); - lopts.push_back("-lto"); - lopts.push_back(archs.c_str()); - for (auto const& opt : linktime_extra_options) { - lopts.push_back(opt.c_str()); - } - auto result = nvJitLinkCreate(&handle, static_cast(lopts.size()), lopts.data()); - check_nvjitlink_result(handle, result); - - for (const auto& frag : this->fragments) { - frag->add_to(handle); + auto launcher = try_get_launcher(); + if (!launcher) { + RAFT_FAIL("Failed to build launcher for kernel entrypoint: %s", this->entrypoint.c_str()); } - - // Call to nvJitLinkComplete causes linker to link together all the LTO-IR - // modules perform any optimizations and generate cubin from it. 
- result = nvJitLinkComplete(handle); - check_nvjitlink_result(handle, result); - - // get cubin from nvJitLink - size_t cubin_size; - result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size); - check_nvjitlink_result(handle, result); - - std::unique_ptr cubin{new char[cubin_size]}; - result = nvJitLinkGetLinkedCubin(handle, cubin.get()); - check_nvjitlink_result(handle, result); - - result = nvJitLinkDestroy(&handle); - RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed"); - - // cubin is linked, so now load it - cudaLibrary_t library; - RAFT_CUDA_TRY( - cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0)); - - cudaKernel_t kernel; - RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, this->entrypoint.c_str())); - - return std::make_shared(kernel, library); + return launcher; } diff --git a/cpp/src/detail/jit_lto/LTOAlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/LTOAlgorithmPlanner.cpp new file mode 100644 index 0000000000..da7c0408b4 --- /dev/null +++ b/cpp/src/detail/jit_lto/LTOAlgorithmPlanner.cpp @@ -0,0 +1,76 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include +#include + +#include +#include + +#include "cuda_runtime.h" +#include "nvJitLink.h" + +#include + +std::string LTOAlgorithmPlanner::get_planner_key() const +{ + std::string key; + for (const auto& fragment : this->fragments) { + key += fragment->get_key(); + } + return key; +} + +std::shared_ptr LTOAlgorithmPlanner::build() +{ + int device = 0; + int major = 0; + int minor = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + + std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor)); + + nvJitLinkHandle handle; + std::vector lopts; + lopts.reserve(2 + linktime_extra_options.size()); + lopts.push_back("-lto"); + lopts.push_back(archs.c_str()); + for (auto const& opt : linktime_extra_options) { + lopts.push_back(opt.c_str()); + } + auto result = nvJitLinkCreate(&handle, static_cast(lopts.size()), lopts.data()); + check_nvjitlink_result(handle, result); + + for (const auto& frag : this->fragments) { + frag->add_to(handle); + } + + result = nvJitLinkComplete(handle); + check_nvjitlink_result(handle, result); + + size_t cubin_size; + result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size); + check_nvjitlink_result(handle, result); + + std::unique_ptr cubin{new char[cubin_size]}; + result = nvJitLinkGetLinkedCubin(handle, cubin.get()); + check_nvjitlink_result(handle, result); + + result = nvJitLinkDestroy(&handle); + RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed"); + + cudaLibrary_t library; + RAFT_CUDA_TRY( + cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + + cudaKernel_t kernel; + RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, this->entrypoint.c_str())); + + return std::make_shared(kernel, library); +} diff --git 
a/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp new file mode 100644 index 0000000000..edb6269213 --- /dev/null +++ b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp @@ -0,0 +1,38 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include +#include + +std::string TileAlgorithmPlanner::get_planner_key() const +{ + std::string key = this->entrypoint; + for (const auto& fragment : cubin_fragments_) { + key += fragment->get_key(); + } + if (tileir_fragment_) { key += tileir_fragment_->get_key(); } + return key; +} + +std::shared_ptr TileAlgorithmPlanner::build() +{ + int cc_major = 0; + int cc_minor = 0; + if (!cuvs::detail::jit_lto::get_device_compute_capability(cc_major, cc_minor)) { + return nullptr; + } + + int driver_version = 0; + if (cudaDriverGetVersion(&driver_version) != cudaSuccess) { return nullptr; } + + auto image = cuvs::detail::jit_lto::resolve_cutile_module_image( + cc_major, cc_minor, driver_version, cubin_fragments_, tileir_fragment_.get()); + if (!image) { return nullptr; } + + return cuvs::detail::jit_lto::load_cutile_launcher(*image, this->entrypoint); +} diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh index f9dbd968ec..8b47092b58 100644 --- a/cpp/src/distance/detail/fused_distance_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -5,14 +5,22 @@ #pragma once +#ifndef CUVS_CUTILE_ENABLED +#define CUVS_CUTILE_ENABLED 0 +#endif + #include "distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op #include "fused_distance_nn/cutlass_base.cuh" +#if CUVS_CUTILE_ENABLED +#include "fused_distance_nn/cutile/fused_1nn_tile.hpp" +#endif #include "fused_distance_nn/fused_cosine_nn.cuh" #include "fused_distance_nn/fused_l2_nn.cuh" #include "fused_distance_nn/helper_structs.cuh" #include "fused_distance_nn/simt_kernel.cuh" #include 
"pairwise_distance_base.cuh" // PairwiseDistances #include +#include #include // raft::KeyValuePair #include // raft::identity_op #include // Policy @@ -54,6 +62,13 @@ void fusedDistanceNNImpl(OutT* min, // The kernel policy is determined by fusedDistanceNN. typedef Policy P; +#if CUVS_CUTILE_ENABLED + if (cuvs::detail::jit_lto::cutile_launch_available_on_current_device() && + try_fused_1nn_tile(min, x, y, m, n, k, metric, stream)) { + return; + } +#endif + dim3 blk(P::Nthreads); auto nblks = raft::ceildiv(m, P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py new file mode 100644 index 0000000000..6a20be24ef --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +"""Export fused 1-NN cuTile kernels to cubin or TileIR bytecode.""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Literal + +import cuda.tile as ct +from cuda.tile.compilation import ( + ArrayConstraint, + CallingConvention, + ConstantConstraint, + KernelSignature, + ScalarConstraint, + export_kernel, +) + +from fused_1nn_kernel import KERNELS, KERNEL_SYMBOLS, TILE_CONSTANTS + +DEFAULT_TILEIR_BYTECODE_VERSION = "13.1" +# cuTile requires a gpu_code even for TileIR bytecode export: it selects the compilation +# target / feature set for lowering, not the runtime architecture (the driver JITs at load). 
+DEFAULT_TILEIR_EXPORT_GPU_CODE = "sm_80" + + +def _dtype_for(data_type: str): + if data_type == "half": + return ct.float16 + if data_type == "float": + return ct.float32 + raise ValueError(f"Unsupported data_type {data_type!r}") + + +def _kernel_signature(data_type: str) -> KernelSignature: + elem = _dtype_for(data_type) + array = ArrayConstraint( + elem, + 2, + index_dtype=ct.int64, + stride_lower_bound_incl=0, + alias_groups=(), + may_alias_internally=False, + ) + idx_array = ArrayConstraint( + ct.int64, + 1, + index_dtype=ct.int64, + stride_lower_bound_incl=0, + alias_groups=(), + may_alias_internally=False, + stride_constant=(1,), + ) + dist_array = ArrayConstraint( + ct.float32, + 1, + index_dtype=ct.int64, + stride_lower_bound_incl=0, + alias_groups=(), + may_alias_internally=False, + stride_constant=(1,), + ) + tm, tn, tk = TILE_CONSTANTS + return KernelSignature( + parameters=[ + array, + array, + idx_array, + dist_array, + ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), + ConstantConstraint(tm), + ConstantConstraint(tn), + ConstantConstraint(tk), + ], + calling_convention=CallingConvention.cutile_python_v1(), + ).with_symbol(KERNEL_SYMBOLS[data_type]) + + +def export_binary( + output_file: Path, + *, + output_format: Literal["cubin", "tileir_bytecode"], + data_type: str, + gpu_code: str, + bytecode_version: str | None = None, +) -> str: + kernel = KERNELS[data_type] + signature = _kernel_signature(data_type) + + export_kwargs = { + "kernel": kernel, + "signatures": [signature], + "output_file": str(output_file), + "gpu_code": gpu_code, + "output_format": output_format, + } + if output_format == "tileir_bytecode": + export_kwargs["bytecode_version"] = bytecode_version or DEFAULT_TILEIR_BYTECODE_VERSION + + export_kernel(**export_kwargs) + + return signature.symbol + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("output_file", type=Path) + 
parser.add_argument("--format", choices=("cubin", "tileir_bytecode"), default="cubin") + parser.add_argument("--data-type", choices=tuple(KERNELS.keys()), required=True) + parser.add_argument( + "--gpu-code", + default=DEFAULT_TILEIR_EXPORT_GPU_CODE, + help="Target SM for cubin export, or compile hint for TileIR bytecode export", + ) + parser.add_argument("--bytecode-version", default=DEFAULT_TILEIR_BYTECODE_VERSION) + args = parser.parse_args() + + print( + export_binary( + args.output_file, + output_format=args.format, + data_type=args.data_type, + gpu_code=args.gpu_code, + bytecode_version=args.bytecode_version, + ) + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json new file mode 100644 index 0000000000..fbd4bfdd64 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json @@ -0,0 +1,40 @@ +[ + { + "_data": [ + { + "data_type": "half", + "data_abbrev": "h" + }, + { + "data_type": "float", + "data_abbrev": "f" + } + ], + "_arch": [ + { + "gpu_code": "sm_80", + "cc_major": 8, + "cc_minor": 0, + "arch_tag": "cutile_arch_8_0" + }, + { + "gpu_code": "sm_86", + "cc_major": 8, + "cc_minor": 6, + "arch_tag": "cutile_arch_8_6" + }, + { + "gpu_code": "sm_90", + "cc_major": 9, + "cc_minor": 0, + "arch_tag": "cutile_arch_9_0" + }, + { + "gpu_code": "sm_120", + "cc_major": 12, + "cc_minor": 0, + "arch_tag": "cutile_arch_12_0" + } + ] + } +] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json new file mode 100644 index 0000000000..364c94594c --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json @@ -0,0 +1,20 @@ +[ + { + "_data": [ + { + "data_type": "half", 
+ "data_abbrev": "h" + }, + { + "data_type": "float", + "data_abbrev": "f" + } + ], + "_tileir": [ + { + "export_gpu_code": "sm_80", + "bytecode_version": "13.1" + } + ] + } +] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py new file mode 100644 index 0000000000..232b9506af --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +"""cuTile fused GEMM + inner-product 1-NN (argmax dot product) for cuVS.""" + +from __future__ import annotations + +import cuda.tile as ct + +ConstInt = ct.Constant[int] + +TILE_M = 128 +TILE_N = 256 +TILE_K = 64 + + +def _make_kernel(data_type: str): + if data_type == "half": + dtype = ct.float16 + acc_dtype = ct.float32 + elif data_type == "float": + dtype = ct.float32 + acc_dtype = ct.float32 + else: + raise ValueError(f"Unsupported data_type {data_type!r}") + + @ct.kernel + def fused_1nn_kernel(A, B, OutIdx, OutDist, M, N, K, tm: ConstInt, tn: ConstInt, tk: ConstInt): + bidm = ct.bid(0) + + best_dist = ct.full((tm,), -3.4e38, acc_dtype) + best_idx = ct.zeros((tm,), ct.int64) + + num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) + num_tiles_n = ct.num_tiles(B, axis=0, shape=(tn, tk)) + zero_pad = ct.PaddingMode.ZERO + + for n in range(num_tiles_n): + accumulator = ct.full((tm, tn), 0, dtype=acc_dtype) + + for k in range(num_tiles_k): + a = ct.load(A, index=(bidm, k), shape=(tm, tk), padding_mode=zero_pad) + b_T = ct.load(B, index=(n, k), shape=(tn, tk), padding_mode=zero_pad) + accumulator = ct.mma(a, ct.transpose(b_T), accumulator) + + curr_max = ct.max(accumulator, axis=1) + curr_idx = ct.argmax(accumulator, axis=1) + + update = curr_max > best_dist + best_dist = ct.where(update, curr_max, best_dist) + best_idx = ct.where(update, n * tn + curr_idx, best_idx) + + 
ct.store(OutIdx, index=(bidm,), tile=best_idx) + ct.store(OutDist, index=(bidm,), tile=best_dist) + + return fused_1nn_kernel + + +KERNELS = { + "half": _make_kernel("half"), + "float": _make_kernel("float"), +} + +KERNEL_SYMBOLS = { + "half": "fused_1nn_half", + "float": "fused_1nn_float", +} + +TILE_CONSTANTS = (TILE_M, TILE_N, TILE_K) diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp new file mode 100644 index 0000000000..dd2a539528 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp @@ -0,0 +1,60 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include +#include +#include +#include + +namespace cuvs::distance::detail { + +/** Must match KERNEL_SYMBOLS in fused_1nn_kernel.py (export uses with_symbol). */ +template +inline const char* fused_1nn_kernel_entrypoint() +{ + if constexpr (std::is_same_v) { + return "fused_1nn_half"; + } else if constexpr (std::is_same_v) { + return "fused_1nn_float"; + } else { + static_assert(sizeof(DataTag) == 0, "unsupported fused 1-NN cuTile data type"); + return ""; + } +} + +template +struct Fused1nnTilePlanner : TileAlgorithmPlanner { + inline static LauncherJitCache launcher_jit_cache{}; + + Fused1nnTilePlanner() + : TileAlgorithmPlanner(fused_1nn_kernel_entrypoint(), launcher_jit_cache) + { + } + + /** Registers embedded cubin modules (one per SM); see register_cubin.cpp object files. 
*/ + void add_entrypoint() + { + using cuvs::detail::jit_lto::cutile_arch_12_0; + using cuvs::detail::jit_lto::cutile_arch_8_0; + using cuvs::detail::jit_lto::cutile_arch_8_6; + using cuvs::detail::jit_lto::cutile_arch_9_0; + + this->add_static_fragment>(); + this->add_static_fragment>(); + this->add_static_fragment>(); + this->add_static_fragment>(); + } + + void add_tileir_fallback() + { + this->add_static_tileir_fragment>(); + } +}; + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu new file mode 100644 index 0000000000..af8b0b181f --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu @@ -0,0 +1,173 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "fused_1nn_tile.hpp" + +#include "fused_1nn_planner.hpp" + +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { + +namespace { + +template +__global__ void pack_fused_1nn_kvp(OutT* out, const int64_t* idx, const float* dist, IdxT len) +{ + IdxT i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + out[i].key = static_cast(idx[i]); + out[i].value = static_cast(dist[i]); + } +} + +template +bool launch_fused_1nn_tile(const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + cudaStream_t stream) +{ + Fused1nnTilePlanner planner; + planner.add_entrypoint(); + planner.add_tileir_fallback(); + auto launcher = planner.try_get_launcher(); + if (!launcher) { return false; } + + int64_t* d_idx = nullptr; + float* d_dist = nullptr; + RAFT_CUDA_TRY(cudaMallocAsync(&d_idx, m * sizeof(int64_t), stream)); + RAFT_CUDA_TRY(cudaMallocAsync(&d_dist, m * sizeof(float), stream)); + + int64_t shape_x[2] = {m, k}; + int64_t stride_x[2] = {k, 1}; + int64_t shape_y[2] = {n, k}; + int64_t stride_y[2] = {k, 1}; + int64_t shape_idx[1] = 
{m}; + int64_t stride_idx[1] = {1}; + int64_t shape_dist[1] = {m}; + int64_t stride_dist[1] = {1}; + + int64_t M = m, N = n, K = k; + constexpr int64_t tm = 128, tn = 256, tk = 64; + + void* x_ptr = const_cast(x); + void* y_ptr = const_cast(y); + void* idx_ptr = d_idx; + void* dist_ptr = d_dist; + + dim3 grid((m + tm - 1) / tm, 1, 1); + dim3 block(1, 1, 1); + + using fused_1nn_cutile_kernel_t = void(void*, + int64_t*, + int64_t*, + void*, + int64_t*, + int64_t*, + void*, + int64_t*, + int64_t*, + void*, + int64_t*, + int64_t*, + int64_t, + int64_t, + int64_t, + int64_t, + int64_t, + int64_t); + launcher->template dispatch( + stream, + grid, + block, + 0, + x_ptr, + shape_x, + stride_x, + y_ptr, + shape_y, + stride_y, + idx_ptr, + shape_idx, + stride_idx, + dist_ptr, + shape_dist, + stride_dist, + M, + N, + K, + tm, + tn, + tk); + + pack_fused_1nn_kvp<<<(m + 255) / 256, 256, 0, stream>>>(out, d_idx, d_dist, m); + RAFT_CUDA_TRY(cudaGetLastError()); + RAFT_CUDA_TRY(cudaFreeAsync(d_idx, stream)); + RAFT_CUDA_TRY(cudaFreeAsync(d_dist, stream)); + return true; +} + +} // namespace + +template , int>> +bool try_fused_1nn_tile(OutT* min, + const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + cudaStream_t stream) +{ + if (metric != cuvs::distance::DistanceType::InnerProduct) { return false; } + + if constexpr (std::is_same_v) { + return launch_fused_1nn_tile( + x, y, min, m, n, k, stream); + } else if constexpr (std::is_same_v) { + return launch_fused_1nn_tile( + x, y, min, m, n, k, stream); + } else { + return false; + } +} + +using kvp_i_f = raft::KeyValuePair; +using kvp_i64_f = raft::KeyValuePair; +using kvp_i_h = raft::KeyValuePair; +using kvp_i64_h = raft::KeyValuePair; + +#define CUVS_INST_TRY_FUSED_1NN_TILE(DataT, OutT, IdxT) \ + template CUVS_EXPORT bool try_fused_1nn_tile(OutT*, \ + const DataT*, \ + const DataT*, \ + IdxT, \ + IdxT, \ + IdxT, \ + cuvs::distance::DistanceType, \ + cudaStream_t) + +// int and 
int32_t are the same on LP64; one instantiation covers both. +CUVS_INST_TRY_FUSED_1NN_TILE(float, kvp_i_f, int); +CUVS_INST_TRY_FUSED_1NN_TILE(float, kvp_i64_f, int64_t); +CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i_f, int); +CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i64_f, int64_t); +CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i_h, int); +CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i64_h, int64_t); + +#undef CUVS_INST_TRY_FUSED_1NN_TILE + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp new file mode 100644 index 0000000000..30f804d399 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { + +template +inline constexpr bool is_fused_1nn_kvp_output_v = + std::is_same_v> || + std::is_same_v>; + +template , int> = 0> +bool try_fused_1nn_tile(OutT* min, + const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + cudaStream_t stream); + +template , int> = 0> +bool try_fused_1nn_tile(OutT*, + const DataT*, + const DataT*, + IdxT, + IdxT, + IdxT, + cuvs::distance::DistanceType, + cudaStream_t) +{ + return false; +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp index 0d00b3eca6..f89a383596 100644 --- a/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp +++ b/cpp/src/distance/detail/pairwise_matrix/jit_lto_kernels/pairwise_matrix_planner.hpp @@ -20,7 
+20,7 @@ template -struct PairwiseMatrixPlanner : AlgorithmPlanner { +struct PairwiseMatrixPlanner : LTOAlgorithmPlanner { using DistanceTag = DistanceTag_; using DataTag = DataTag_; using AccTag = AccTag_; @@ -33,7 +33,7 @@ struct PairwiseMatrixPlanner : AlgorithmPlanner { inline static LauncherJitCache launcher_jit_cache{}; - PairwiseMatrixPlanner() : AlgorithmPlanner(kPairwiseMatrixJitEntrypoint, launcher_jit_cache) {} + PairwiseMatrixPlanner() : LTOAlgorithmPlanner(kPairwiseMatrixJitEntrypoint, launcher_jit_cache) {} void add_entrypoint() { diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 0c3ed64d13..b44a7f044e 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -25,7 +25,7 @@ template -struct CagraPlannerBase : AlgorithmPlanner { +struct CagraPlannerBase : LTOAlgorithmPlanner { using DataTag = DataTag_; using IndexTag = IndexTag_; using DistanceTag = DistanceTag_; @@ -34,7 +34,7 @@ struct CagraPlannerBase : AlgorithmPlanner { using SampleFilterJitTag = SampleFilterJitTag_; explicit CagraPlannerBase(std::string entrypoint, LauncherJitCache& jit_cache) - : AlgorithmPlanner(std::move(entrypoint), jit_cache) + : LTOAlgorithmPlanner(std::move(entrypoint), jit_cache) { } diff --git a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp index ed8191016b..7899d970ab 100644 --- a/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp +++ b/cpp/src/neighbors/ivf_flat/detail/jit_lto_kernels/interleaved_scan_planner.hpp @@ -14,10 +14,10 @@ namespace cuvs::neighbors::ivf_flat::detail { -struct InterleavedScanPlanner : AlgorithmPlanner { +struct InterleavedScanPlanner : LTOAlgorithmPlanner { inline static LauncherJitCache 
launcher_jit_cache{}; - InterleavedScanPlanner() : AlgorithmPlanner("interleaved_scan", launcher_jit_cache) {} + InterleavedScanPlanner() : LTOAlgorithmPlanner("interleaved_scan", launcher_jit_cache) {} template void add_entrypoint() diff --git a/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp b/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp index 0621966cad..7152aaeebd 100644 --- a/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp +++ b/cpp/src/neighbors/ivf_pq/detail/jit_lto_kernels/compute_similarity_planner.hpp @@ -12,10 +12,10 @@ namespace cuvs::neighbors::ivf_pq::detail { -struct ComputeSimilarityPlanner : AlgorithmPlanner { +struct ComputeSimilarityPlanner : LTOAlgorithmPlanner { inline static LauncherJitCache launcher_jit_cache{}; - ComputeSimilarityPlanner() : AlgorithmPlanner("compute_similarity", launcher_jit_cache) {} + ComputeSimilarityPlanner() : LTOAlgorithmPlanner("compute_similarity", launcher_jit_cache) {} template void add_entrypoint() diff --git a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp index 05ea34532e..5dc47dc612 100644 --- a/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp +++ b/cpp/src/neighbors/ivf_sq/detail/jit_lto_kernels/scan_planner.hpp @@ -13,10 +13,10 @@ namespace cuvs::neighbors::ivf_sq::detail { -struct IvfSqScanPlanner : AlgorithmPlanner { +struct IvfSqScanPlanner : LTOAlgorithmPlanner { inline static LauncherJitCache launcher_jit_cache{}; - IvfSqScanPlanner() : AlgorithmPlanner("ivf_sq_scan", launcher_jit_cache) {} + IvfSqScanPlanner() : LTOAlgorithmPlanner("ivf_sq_scan", launcher_jit_cache) {} template void add_entrypoint() diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index ba6ed6e0e7..006b35b5c4 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -386,7 +386,8 @@ ConfigureTest( PERCENT 100 ) 
-add_subdirectory(cutile) +# cuTile vector-add example test disabled; fused 1-NN cuTile is covered via libcuvs integration. +# add_subdirectory(cutile) # ################################################################################################## # Install tests #################################################################################### diff --git a/cpp/tests/cutile/cutile_vector_add.cu b/cpp/tests/cutile/cutile_vector_add.cu index 77a5e51311..07d694bef1 100644 --- a/cpp/tests/cutile/cutile_vector_add.cu +++ b/cpp/tests/cutile/cutile_vector_add.cu @@ -11,10 +11,15 @@ #include "vector_add_sm_80_cubin.h" #include "vector_add_sm_86_cubin.h" #include "vector_add_sm_90_cubin.h" +#include "vector_add_tileir_bytecode.h" + +#include #include #include +#include +#include namespace cuvs { namespace { @@ -26,7 +31,7 @@ struct EmbeddedCubin { size_t size; }; -// Lookup table for cubins built at configure time (see export_vector_add_cubin.py). +// Prebuilt cubins for known library targets (see export_vector_add_cubin.py). 
constexpr EmbeddedCubin kEmbeddedCubins[] = { {8, 0, vector_add_sm_80_cubin, sizeof(vector_add_sm_80_cubin)}, {8, 6, vector_add_sm_86_cubin, sizeof(vector_add_sm_86_cubin)}, @@ -35,53 +40,128 @@ constexpr EmbeddedCubin kEmbeddedCubins[] = { {12, 0, vector_add_sm_120_cubin, sizeof(vector_add_sm_120_cubin)}, }; -const EmbeddedCubin* find_embedded_cubin(int cc_major, int cc_minor) +constexpr EmbeddedCubin kTileIrBytecode = { + -1, + -1, + vector_add_tileir_bytecode, + sizeof(vector_add_tileir_bytecode), +}; + +struct CutileModuleImage { + const uint8_t* data; + size_t size; +}; + +std::optional resolve_vector_add_module(int cc_major, int cc_minor) { for (const auto& entry : kEmbeddedCubins) { - if (entry.cc_major == cc_major && entry.cc_minor == cc_minor) { return &entry; } + if (entry.cc_major == cc_major && entry.cc_minor == cc_minor) { + return CutileModuleImage{reinterpret_cast(entry.data), entry.size}; + } } - // Fall back to a cubin for the same major version (e.g. minor SKUs within a generation). 
- for (const auto& entry : kEmbeddedCubins) { - if (entry.cc_major == cc_major) { return &entry; } + + int driver_version = 0; + if (cudaDriverGetVersion(&driver_version) != cudaSuccess) { return std::nullopt; } + if (!cuvs::detail::jit_lto::tileir_fallback_available(driver_version)) { + return std::nullopt; } - return nullptr; + return CutileModuleImage{ + reinterpret_cast(kTileIrBytecode.data), kTileIrBytecode.size}; } -class CutileVectorAddTest : public ::testing::Test { - protected: - void SetUp() override +struct LoadedKernel { + cudaLibrary_t library = nullptr; + cudaKernel_t kernel = nullptr; + bool used_tileir_jit{false}; + const char* skip_reason{nullptr}; + + LoadedKernel() = default; + + LoadedKernel(LoadedKernel&& other) noexcept { *this = std::move(other); } + + LoadedKernel& operator=(LoadedKernel&& other) noexcept { - int device = 0; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - RAFT_CUDA_TRY( - cudaDeviceGetAttribute(&cc_major_, cudaDevAttrComputeCapabilityMajor, device)); - RAFT_CUDA_TRY( - cudaDeviceGetAttribute(&cc_minor_, cudaDevAttrComputeCapabilityMinor, device)); + if (this != &other) { + unload(); + library = other.library; + kernel = other.kernel; + used_tileir_jit = other.used_tileir_jit; + skip_reason = other.skip_reason; + other.library = nullptr; + other.kernel = nullptr; + } + return *this; } - int cc_major_{}; - int cc_minor_{}; -}; + LoadedKernel(const LoadedKernel&) = delete; + LoadedKernel& operator=(const LoadedKernel&) = delete; -} // namespace + ~LoadedKernel() { unload(); } -TEST_F(CutileVectorAddTest, EmbeddedCubinVectorAdd) + explicit operator bool() const { return kernel != nullptr; } + + private: + void unload() + { + if (library != nullptr) { + RAFT_CUDA_TRY_NO_THROW(cudaLibraryUnload(library)); // must not throw: unload() runs from the destructor and the noexcept move assignment, where an exception would call std::terminate + library = nullptr; + kernel = nullptr; + } + } +}; + +LoadedKernel load_vector_add_kernel(int cc_major, int cc_minor) { - const EmbeddedCubin* cubin = find_embedded_cubin(cc_major_, cc_minor_); - ASSERT_NE(cubin, nullptr) - << "No embedded 
cuTile cubin for compute capability " << cc_major_ << "." << cc_minor_; + LoadedKernel result{}; + result.used_tileir_jit = !cuvs::detail::jit_lto::is_embedded_cubin_arch(cc_major, cc_minor); + + auto image = resolve_vector_add_module(cc_major, cc_minor); + if (!image) { + if (result.used_tileir_jit) { + result.skip_reason = + "TileIR driver JIT unavailable for this GPU. Requires CUDA 13.1+ driver (>= 590.44)."; + } else { + ADD_FAILURE() << "No embedded cuTile module for compute capability " << cc_major << "." + << cc_minor; + } + return result; + } - cudaLibrary_t library{}; - ASSERT_EQ(cudaSuccess, - cudaLibraryLoadData( - &library, cubin->data, nullptr, nullptr, 0, nullptr, nullptr, 0)) - << "cudaLibraryLoadData failed: " << cudaGetErrorString(cudaGetLastError()); + const cudaError_t load_status = + cudaLibraryLoadData(&result.library, image->data, nullptr, nullptr, 0, nullptr, nullptr, 0); + if (load_status != cudaSuccess) { + if (result.used_tileir_jit) { + result.skip_reason = + "TileIR driver JIT unavailable for this GPU (requires CUDA 13.1+ driver >= 590.44)."; + SCOPED_TRACE(cudaGetErrorString(load_status)); + } else { + ADD_FAILURE() << "cudaLibraryLoadData failed: " << cudaGetErrorString(load_status); + } + return result; + } - cudaKernel_t kernel{}; - ASSERT_EQ(cudaSuccess, - cudaLibraryGetKernel(&kernel, library, CUTILE_VECTOR_ADD_KERNEL_SYMBOL)) - << "cudaLibraryGetKernel failed: " << cudaGetErrorString(cudaGetLastError()); + const cudaError_t kernel_status = + cudaLibraryGetKernel(&result.kernel, result.library, CUTILE_VECTOR_ADD_KERNEL_SYMBOL); + if (kernel_status != cudaSuccess) { + if (result.library != nullptr) { + RAFT_CUDA_TRY(cudaLibraryUnload(result.library)); + result.library = nullptr; + } + result.kernel = nullptr; + if (result.used_tileir_jit) { + result.skip_reason = + "TileIR driver JIT unavailable for this GPU (requires CUDA 13.1+ driver >= 590.44)."; + SCOPED_TRACE(cudaGetErrorString(kernel_status)); + } else { + ADD_FAILURE() << 
"cudaLibraryGetKernel failed: " << cudaGetErrorString(kernel_status); + } + } + return result; +} +void run_vector_add(cudaKernel_t kernel) +{ constexpr int kN = 1024; constexpr int kTile = 256; constexpr int kGridDim = (kN + kTile - 1) / kTile; @@ -122,7 +202,35 @@ TEST_F(CutileVectorAddTest, EmbeddedCubinVectorAdd) RAFT_CUDA_TRY(cudaFree(d_a)); RAFT_CUDA_TRY(cudaFree(d_b)); RAFT_CUDA_TRY(cudaFree(d_c)); - RAFT_CUDA_TRY(cudaLibraryUnload(library)); +} + +class CutileVectorAddTest : public ::testing::Test { + protected: + void SetUp() override + { + int device = 0; + RAFT_CUDA_TRY(cudaGetDevice(&device)); + RAFT_CUDA_TRY( + cudaDeviceGetAttribute(&cc_major_, cudaDevAttrComputeCapabilityMajor, device)); + RAFT_CUDA_TRY( + cudaDeviceGetAttribute(&cc_minor_, cudaDevAttrComputeCapabilityMinor, device)); + } + + int cc_major_{}; + int cc_minor_{}; +}; + +} // namespace + +TEST_F(CutileVectorAddTest, EmbeddedCubinVectorAdd) +{ + LoadedKernel loaded = load_vector_add_kernel(cc_major_, cc_minor_); + if (loaded.skip_reason) { GTEST_SKIP() << loaded.skip_reason; } + if (!loaded) { return; } + + SCOPED_TRACE(loaded.used_tileir_jit ? "loaded via TileIR driver JIT" + : "loaded via prebuilt cubin"); + run_vector_add(loaded.kernel); } } // namespace cuvs diff --git a/cpp/tests/cutile/export_vector_add_cubin.py b/cpp/tests/cutile/export_vector_add_cubin.py index bf40a4ad80..fa099189cd 100644 --- a/cpp/tests/cutile/export_vector_add_cubin.py +++ b/cpp/tests/cutile/export_vector_add_cubin.py @@ -1,12 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 -"""Export the cuTile vector-add kernel to a cubin for a single GPU target.""" +"""Export the cuTile vector-add kernel to cubin or TileIR bytecode.""" from __future__ import annotations import argparse import sys from pathlib import Path +from typing import Literal import cuda.tile as ct from cuda.tile.compilation import ( @@ -28,6 +29,9 @@ # sm_120 -> 120a-real SUPPORTED_GPU_CODES = ("sm_80", "sm_86", "sm_90", "sm_100", "sm_120") +# Minimum TileIR bytecode version supported by cuTile; also the most portable choice. +DEFAULT_TILEIR_BYTECODE_VERSION = "13.1" + def _kernel_signature() -> KernelSignature: array = ArrayConstraint( @@ -45,20 +49,31 @@ def _kernel_signature() -> KernelSignature: ).with_mangled_symbol("vector_add") -def export_cubin(output_file: Path, gpu_code: str, symbol_header: Path | None) -> str: - if gpu_code not in SUPPORTED_GPU_CODES: +def export_kernel_binary( + output_file: Path, + *, + output_format: Literal["cubin", "tileir_bytecode"], + gpu_code: str, + bytecode_version: str | None = None, + symbol_header: Path | None = None, +) -> str: + if output_format == "cubin" and gpu_code not in SUPPORTED_GPU_CODES: raise ValueError( f"Unsupported gpu_code {gpu_code!r}; expected one of {SUPPORTED_GPU_CODES}" ) signature = _kernel_signature() - export_kernel( - vector_add, - signatures=[signature], - output_file=str(output_file), - gpu_code=gpu_code, - output_format="cubin", - ) + export_kwargs: dict = { + "kernel": vector_add, + "signatures": [signature], + "output_file": str(output_file), + "gpu_code": gpu_code, + "output_format": output_format, + } + if output_format == "tileir_bytecode": + export_kwargs["bytecode_version"] = bytecode_version or DEFAULT_TILEIR_BYTECODE_VERSION + + export_kernel(**export_kwargs) if symbol_header is not None: symbol_header.write_text( @@ -77,12 +92,23 @@ def export_cubin(output_file: Path, gpu_code: str, symbol_header: Path | None) - def main() -> int: parser = 
argparse.ArgumentParser(description=__doc__) - parser.add_argument("output_file", type=Path, help="Output cubin path") + parser.add_argument("output_file", type=Path, help="Output cubin or .tilebc path") + parser.add_argument( + "--format", + choices=("cubin", "tileir_bytecode"), + default="cubin", + help="Export format (default: cubin)", + ) parser.add_argument( "--gpu-code", required=True, choices=SUPPORTED_GPU_CODES, - help="tileiras / export_kernel target (e.g. sm_120)", + help="tileiras / export_kernel compile target (e.g. sm_120)", + ) + parser.add_argument( + "--bytecode-version", + default=DEFAULT_TILEIR_BYTECODE_VERSION, + help="TileIR bytecode version when --format=tileir_bytecode (default: 13.1)", ) parser.add_argument( "--symbol-header", @@ -92,7 +118,13 @@ def main() -> int: ) args = parser.parse_args() - symbol = export_cubin(args.output_file, args.gpu_code, args.symbol_header) + symbol = export_kernel_binary( + args.output_file, + output_format=args.format, + gpu_code=args.gpu_code, + bytecode_version=args.bytecode_version, + symbol_header=args.symbol_header, + ) print(symbol) return 0 diff --git a/cpp/tests/cutile/generate_cutile_cubins.cmake b/cpp/tests/cutile/generate_cutile_cubins.cmake index 3425b03028..766d3167c6 100644 --- a/cpp/tests/cutile/generate_cutile_cubins.cmake +++ b/cpp/tests/cutile/generate_cutile_cubins.cmake @@ -78,6 +78,33 @@ function(generate_cutile_vector_add_cubins output_include_dir_var) list(APPEND _generated_headers "${_cubin_header}") endforeach() + # Portable TileIR bytecode for driver JIT on architectures without a prebuilt cubin. + # Requires a CUDA 13.1+ driver (>= 590.44); see Tile IR bytecode docs. 
+ set(_tileir_file "${_cutile_binary_dir}/vector_add.tilebc") + set(_tileir_header "${_cutile_binary_dir}/vector_add_tileir_bytecode.h") + + add_custom_command( + OUTPUT "${_tileir_file}" + COMMAND + "${Python3_EXECUTABLE}" "${_cutile_source_dir}/export_vector_add_cubin.py" + "${_tileir_file}" --format tileir_bytecode --gpu-code sm_80 --bytecode-version 13.1 + DEPENDS "${_cutile_source_dir}/export_vector_add_cubin.py" + "${_cutile_source_dir}/vector_add_kernel.py" + COMMENT "Exporting cuTile vector_add TileIR bytecode (v13.1)" + VERBATIM + ) + + add_custom_command( + OUTPUT "${_tileir_header}" + COMMAND "${CUTILE_BIN2C}" --const --name vector_add_tileir_bytecode --static "${_tileir_file}" + > "${_tileir_header}" + DEPENDS "${_tileir_file}" + COMMENT "Embedding vector_add TileIR bytecode via bin2c" + VERBATIM + ) + + list(APPEND _generated_headers "${_tileir_header}") + add_custom_target( cutile_vector_add_cubins DEPENDS "${_symbol_header}" ${_generated_headers} diff --git a/cpp/tests/neighbors/distance_nn.cu b/cpp/tests/neighbors/distance_nn.cu index f31f3ebacf..f5efaa5bec 100644 --- a/cpp/tests/neighbors/distance_nn.cu +++ b/cpp/tests/neighbors/distance_nn.cu @@ -187,6 +187,7 @@ const std::vector> input_fp32 = { {4096, 16384, 128, DistanceType::L2Expanded, true, uint64_t(31415926), 0.1}, {4096, 4096, 64, DistanceType::L2SqrtExpanded, false, uint64_t(31415926), 0.1}, {4096, 16384, 128, DistanceType::L2SqrtExpanded, false, uint64_t(31415926), 0.1}, + {512, 1024, 64, DistanceType::InnerProduct, false, uint64_t(31415926), 0.1}, {4096, 4096, 64, DistanceType::CosineExpanded, false, uint64_t(31415926), 0.1}, {8192, 4096, 64, DistanceType::CosineExpanded, false, uint64_t(31415926), 0.1}, // Fused implementation for cosine distance ignores the sqrt parameter, therefore diff --git a/cpp/tests/neighbors/distance_nn_helper.cuh b/cpp/tests/neighbors/distance_nn_helper.cuh index fda7b76573..422879918f 100644 --- a/cpp/tests/neighbors/distance_nn_helper.cuh +++ 
b/cpp/tests/neighbors/distance_nn_helper.cuh @@ -66,6 +66,16 @@ __device__ AccT cosine_distance(const DataT* v1, const DataT* v2, IdxT K) } // This is a naive implementation of 1-NN computation +template +__device__ AccT inner_product_score(const DataT* v1, const DataT* v2, IdxT K) +{ + AccT score = AccT(0.0); + for (IdxT i = 0; i < K; i++) { + score += AccT(v1[i]) * AccT(v2[i]); + } + return score; +} + template RAFT_KERNEL ref_nn_kernel( OutT* out, const DataT* A, const DataT* B, IdxT M, IdxT N, IdxT K, bool sqrt, DistanceType metric) @@ -73,22 +83,47 @@ RAFT_KERNEL ref_nn_kernel( IdxT tid = threadIdx.x + blockIdx.x * IdxT(blockDim.x); for (IdxT m = tid; m < M; m += (blockDim.x * gridDim.x)) { - IdxT min_index = N + 1; - AccT min_dist = max_val(); + IdxT best_index = N + 1; + AccT best_score = min_val(); + AccT best_dist = max_val(); for (IdxT n = 0; n < N; n++) { + if (metric == DistanceType::InnerProduct) { + AccT score = inner_product_score(&A[m * K], &B[n * K], K); + if (score > best_score) { + best_score = score; + best_index = n; + } + continue; + } + AccT dist; if (metric == DistanceType::L2SqrtExpanded || metric == DistanceType::L2Expanded) { dist = l2_distance(&A[m * K], &B[n * K], K); } else if (metric == DistanceType::CosineExpanded) { dist = cosine_distance(&A[m * K], &B[n * K], K); + } else { + continue; + } + if (dist < best_dist) { + best_dist = dist; + best_index = n; } - if (dist < min_dist) { - min_dist = dist; - min_index = n; + } + + if (metric == DistanceType::InnerProduct) { + if constexpr (std::is_fundamental::value) { + out[m] = AccT(best_score); + } else { + out[m].key = IdxT(best_index); + out[m].value = AccT(best_score); } + continue; } + IdxT min_index = best_index; + AccT min_dist = best_dist; + if constexpr (std::is_fundamental::value) { static_assert(std::is_same::value, "OutT and AccT are not same type"); out[m] = AccT(min_dist); From 1b934ddaa684bb09d2df9e611e6c7482b7ee975b Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 24 
Jun 2026 19:33:18 +0000 Subject: [PATCH 03/10] attempt to fix tile linkage --- cpp/CMakeLists.txt | 22 +-- .../modules/generate_cutile_kernels.cmake | 183 +++++------------- ...cpp.in => register_cutile_fragment.cpp.in} | 8 +- cpp/cmake/modules/register_tileir.cpp.in | 22 --- .../cutile/fused_1nn_cutile_cubin_matrix.json | 40 ---- .../cutile/fused_1nn_cutile_matrix.json | 64 ++++++ .../fused_1nn_cutile_tileir_matrix.json | 20 -- 7 files changed, 120 insertions(+), 239 deletions(-) rename cpp/cmake/modules/{register_cubin.cpp.in => register_cutile_fragment.cpp.in} (57%) delete mode 100644 cpp/cmake/modules/register_tileir.cpp.in delete mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json create mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json delete mode 100644 cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index cc6e1975b3..a1f3f3973c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -963,33 +963,21 @@ if(NOT BUILD_CPU_ONLY) "${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/fused_distance_nn/cutile") set(cutile_fused_1nn_generated_dir "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/fused_1nn/cutile") - generate_cutile_cubin_kernels( + generate_cutile_kernels( cutile_fused_1nn_files KERNEL_DIR "${fused_1nn_cutile_dir}" KERNEL_BASENAME "fused_1nn" KERNEL_PYTHON "fused_1nn_kernel.py" EXPORT_SCRIPT "export_fused_1nn.py" OUTPUT_DIRECTORY "${cutile_fused_1nn_generated_dir}" - MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_cubin_matrix.json" - FRAGMENT_TAG_FORMAT + MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_matrix.json" + FRAGMENT_TAG_FORMAT_CUBIN "cuvs::distance::detail::fragment_tag_fused_1nn_cubin" - FRAGMENT_TAG_HEADER_FILES - "" - "" - "" - ) - generate_cutile_tileir_kernels( - cutile_fused_1nn_files - KERNEL_DIR "${fused_1nn_cutile_dir}" - 
KERNEL_BASENAME "fused_1nn" - KERNEL_PYTHON "fused_1nn_kernel.py" - EXPORT_SCRIPT "export_fused_1nn.py" - OUTPUT_DIRECTORY "${cutile_fused_1nn_generated_dir}" - MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_tileir_matrix.json" - FRAGMENT_TAG_FORMAT + FRAGMENT_TAG_FORMAT_TILEIR "cuvs::distance::detail::fragment_tag_fused_1nn_tileir" FRAGMENT_TAG_HEADER_FILES "" + "" "" ) if(NOT DEFINED CUVS_CUTILE_ENABLED) diff --git a/cpp/cmake/modules/generate_cutile_kernels.cmake b/cpp/cmake/modules/generate_cutile_kernels.cmake index 7b9c2521c4..f0219dc842 100644 --- a/cpp/cmake/modules/generate_cutile_kernels.cmake +++ b/cpp/cmake/modules/generate_cutile_kernels.cmake @@ -77,13 +77,15 @@ function(_cutile_kernels_setup) file(MAKE_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}") + set(Python3_EXECUTABLE "${Python3_EXECUTABLE}" PARENT_SCOPE) + set(CUTILE_BIN2C "${CUTILE_BIN2C}" PARENT_SCOPE) set(_CUTILE_SETUP_OK TRUE PARENT_SCOPE ) endfunction() -function(process_cutile_cubin_matrix_entry source_list_var) +function(process_cutile_matrix_entry source_list_var) set(options) set(one_value KERNEL_DIR @@ -91,110 +93,75 @@ function(process_cutile_cubin_matrix_entry source_list_var) KERNEL_PYTHON EXPORT_SCRIPT OUTPUT_DIRECTORY - FRAGMENT_TAG_FORMAT + FRAGMENT_TAG_FORMAT_CUBIN + FRAGMENT_TAG_FORMAT_TILEIR MATRIX_JSON_ENTRY ) set(multi_value FRAGMENT_TAG_HEADER_FILES) cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) - populate_matrix_variables("${_CUTILE_MATRIX_JSON_ENTRY}") - _cutile_fragment_tag_header_files( - fragment_tag_header_files ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} - ) - - string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT}" fragment_tag @ONLY) - - set(_artifact_basename "${_CUTILE_KERNEL_BASENAME}_${data_type}_${gpu_code}") - set(_cubin_file "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_basename}.cubin") - set(_cubin_header "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_basename}_cubin.h") - set(_cubin_cpp 
"${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_basename}_cubin.cpp") - set(cubin_header_file "${_artifact_basename}_cubin.h") - - add_custom_command( - OUTPUT "${_cubin_file}" - COMMAND - "${Python3_EXECUTABLE}" "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" "${_cubin_file}" - --format cubin --data-type "${data_type}" --gpu-code "${gpu_code}" - DEPENDS "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" - "${_CUTILE_KERNEL_DIR}/${_CUTILE_KERNEL_PYTHON}" - COMMENT "Exporting cuTile ${_CUTILE_KERNEL_BASENAME} cubin ${data_type} ${gpu_code}" - VERBATIM - ) - - add_custom_command( - OUTPUT "${_cubin_header}" - COMMAND "${CUTILE_BIN2C}" --const --name embedded_cubin --static "${_cubin_file}" - > "${_cubin_header}" - DEPENDS "${_cubin_file}" - VERBATIM - ) + find_package(Python3 REQUIRED COMPONENTS Interpreter) - configure_file( - "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/register_cubin.cpp.in" "${_cubin_cpp}" @ONLY - ) - list(APPEND ${source_list_var} "${_cubin_header}" "${_cubin_cpp}") - set(${source_list_var} - "${${source_list_var}}" - PARENT_SCOPE - ) -endfunction() + populate_matrix_variables("${_CUTILE_MATRIX_JSON_ENTRY}") -function(process_cutile_tileir_matrix_entry source_list_var) - set(options) - set(one_value - KERNEL_DIR - KERNEL_BASENAME - KERNEL_PYTHON - EXPORT_SCRIPT - OUTPUT_DIRECTORY - FRAGMENT_TAG_FORMAT - MATRIX_JSON_ENTRY - ) - set(multi_value FRAGMENT_TAG_HEADER_FILES) - cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) + if(register STREQUAL "cubin") + string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT_CUBIN}" fragment_tag @ONLY) + set(bin2c_symbol embedded_cubin) + set(fragment_entry_type "StaticCubinFragmentEntry") + elseif(register STREQUAL "tileir") + string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT_TILEIR}" fragment_tag @ONLY) + set(bin2c_symbol embedded_tileir) + set(fragment_entry_type "StaticTileIrBytecodeFragmentEntry") + else() + message(FATAL_ERROR "Unknown cuTile register kind '${register}'") + endif() - 
populate_matrix_variables("${_CUTILE_MATRIX_JSON_ENTRY}") _cutile_fragment_tag_header_files( fragment_tag_header_files ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} ) - string(CONFIGURE "${_CUTILE_FRAGMENT_TAG_FORMAT}" fragment_tag @ONLY) - set(_tileir_file "${_CUTILE_OUTPUT_DIRECTORY}/${_CUTILE_KERNEL_BASENAME}_${data_type}.tilebc") - set(_tileir_header "${_CUTILE_OUTPUT_DIRECTORY}/${_CUTILE_KERNEL_BASENAME}_${data_type}_tileir.h") - set(_tileir_cpp "${_CUTILE_OUTPUT_DIRECTORY}/${_CUTILE_KERNEL_BASENAME}_${data_type}_tileir.cpp") - set(tileir_header_file "${_CUTILE_KERNEL_BASENAME}_${data_type}_tileir.h") + string(CONFIGURE "${artifact_basename}" _artifact_basename @ONLY) + set(_artifact_stem "${_CUTILE_KERNEL_BASENAME}_${_artifact_basename}") + set(_artifact_file "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_stem}.${artifact_ext}") + set(_embedded_header "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_stem}_${register}.h") + set(_fragment_cpp "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_stem}_${register}.cpp") + set(embedded_header_file "${_artifact_stem}_${register}.h") + + set(_python_args --format "${output_format}" --data-type "${data_type}" --gpu-code "${gpu_code}") + if(DEFINED bytecode_version AND NOT "${bytecode_version}" STREQUAL "") + list(APPEND _python_args --bytecode-version "${bytecode_version}") + endif() add_custom_command( - OUTPUT "${_tileir_file}" - COMMAND - "${Python3_EXECUTABLE}" "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" "${_tileir_file}" - --format tileir_bytecode --data-type "${data_type}" --gpu-code "${export_gpu_code}" - --bytecode-version "${bytecode_version}" + OUTPUT "${_artifact_file}" + COMMAND "${Python3_EXECUTABLE}" "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" + "${_artifact_file}" ${_python_args} + WORKING_DIRECTORY "${_CUTILE_KERNEL_DIR}" DEPENDS "${_CUTILE_KERNEL_DIR}/${_CUTILE_EXPORT_SCRIPT}" "${_CUTILE_KERNEL_DIR}/${_CUTILE_KERNEL_PYTHON}" - COMMENT "Exporting cuTile ${_CUTILE_KERNEL_BASENAME} TileIR bytecode ${data_type}" + COMMENT 
"Exporting cuTile ${_CUTILE_KERNEL_BASENAME} ${output_format} ${data_type}" VERBATIM ) add_custom_command( - OUTPUT "${_tileir_header}" - COMMAND "${CUTILE_BIN2C}" --const --name embedded_tileir --static "${_tileir_file}" - > "${_tileir_header}" - DEPENDS "${_tileir_file}" + OUTPUT "${_embedded_header}" + COMMAND "${CUTILE_BIN2C}" --const --name ${bin2c_symbol} --static "${_artifact_file}" + > "${_embedded_header}" + DEPENDS "${_artifact_file}" VERBATIM ) configure_file( - "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/register_tileir.cpp.in" "${_tileir_cpp}" @ONLY + "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/register_cutile_fragment.cpp.in" "${_fragment_cpp}" @ONLY ) - list(APPEND ${source_list_var} "${_tileir_header}" "${_tileir_cpp}") + list(APPEND ${source_list_var} "${_embedded_header}" "${_fragment_cpp}") set(${source_list_var} "${${source_list_var}}" PARENT_SCOPE ) endfunction() -function(generate_cutile_cubin_kernels source_list_var) +function(generate_cutile_kernels source_list_var) set(options) set(one_value KERNEL_DIR @@ -203,13 +170,14 @@ function(generate_cutile_cubin_kernels source_list_var) EXPORT_SCRIPT OUTPUT_DIRECTORY MATRIX_JSON_FILE - FRAGMENT_TAG_FORMAT + FRAGMENT_TAG_FORMAT_CUBIN + FRAGMENT_TAG_FORMAT_TILEIR ) set(multi_value FRAGMENT_TAG_HEADER_FILES) cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) if(NOT _CUTILE_KERNEL_BASENAME) - message(FATAL_ERROR "generate_cutile_cubin_kernels: KERNEL_BASENAME is required") + message(FATAL_ERROR "generate_cutile_kernels: KERNEL_BASENAME is required") endif() if(NOT _CUTILE_KERNEL_PYTHON) set(_CUTILE_KERNEL_PYTHON "fused_1nn_kernel.py") @@ -236,72 +204,15 @@ function(generate_cutile_cubin_kernels source_list_var) # cmake-lint: disable=C0103,E1120 foreach(i RANGE "${last}") string(JSON matrix_json_entry GET "${matrix_product}" "${i}") - process_cutile_cubin_matrix_entry( - "${source_list_var}" - KERNEL_DIR "${_CUTILE_KERNEL_DIR}" - KERNEL_BASENAME "${_CUTILE_KERNEL_BASENAME}" - 
KERNEL_PYTHON "${_CUTILE_KERNEL_PYTHON}" - EXPORT_SCRIPT "${_CUTILE_EXPORT_SCRIPT}" - OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" - FRAGMENT_TAG_FORMAT "${_CUTILE_FRAGMENT_TAG_FORMAT}" - FRAGMENT_TAG_HEADER_FILES ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} - MATRIX_JSON_ENTRY "${matrix_json_entry}" - ) - endforeach() - - set(CUVS_CUTILE_ENABLED 1 PARENT_SCOPE) - set(${source_list_var} - "${${source_list_var}}" - PARENT_SCOPE - ) -endfunction() - -function(generate_cutile_tileir_kernels source_list_var) - set(options) - set(one_value - KERNEL_DIR - KERNEL_BASENAME - KERNEL_PYTHON - EXPORT_SCRIPT - OUTPUT_DIRECTORY - MATRIX_JSON_FILE - FRAGMENT_TAG_FORMAT - ) - set(multi_value FRAGMENT_TAG_HEADER_FILES) - cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) - - if(NOT _CUTILE_KERNEL_BASENAME) - message(FATAL_ERROR "generate_cutile_tileir_kernels: KERNEL_BASENAME is required") - endif() - if(NOT _CUTILE_KERNEL_PYTHON) - set(_CUTILE_KERNEL_PYTHON "fused_1nn_kernel.py") - endif() - - _cutile_kernels_setup( - MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}" - OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" - ) - if(NOT _CUTILE_SETUP_OK) - generate_cutile_kernels_stub() - return() - endif() - - compute_matrix_product(matrix_product MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}") - - string(JSON len LENGTH "${matrix_product}") - math(EXPR last "${len} - 1") - - # cmake-lint: disable=C0103,E1120 - foreach(i RANGE "${last}") - string(JSON matrix_json_entry GET "${matrix_product}" "${i}") - process_cutile_tileir_matrix_entry( + process_cutile_matrix_entry( "${source_list_var}" KERNEL_DIR "${_CUTILE_KERNEL_DIR}" KERNEL_BASENAME "${_CUTILE_KERNEL_BASENAME}" KERNEL_PYTHON "${_CUTILE_KERNEL_PYTHON}" EXPORT_SCRIPT "${_CUTILE_EXPORT_SCRIPT}" OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" - FRAGMENT_TAG_FORMAT "${_CUTILE_FRAGMENT_TAG_FORMAT}" + FRAGMENT_TAG_FORMAT_CUBIN "${_CUTILE_FRAGMENT_TAG_FORMAT_CUBIN}" + FRAGMENT_TAG_FORMAT_TILEIR 
"${_CUTILE_FRAGMENT_TAG_FORMAT_TILEIR}" FRAGMENT_TAG_HEADER_FILES ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} MATRIX_JSON_ENTRY "${matrix_json_entry}" ) diff --git a/cpp/cmake/modules/register_cubin.cpp.in b/cpp/cmake/modules/register_cutile_fragment.cpp.in similarity index 57% rename from cpp/cmake/modules/register_cubin.cpp.in rename to cpp/cmake/modules/register_cutile_fragment.cpp.in index c27d6829ee..0fc074bdbb 100644 --- a/cpp/cmake/modules/register_cubin.cpp.in +++ b/cpp/cmake/modules/register_cutile_fragment.cpp.in @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "@cubin_header_file@" +#include "@embedded_header_file@" #include @fragment_tag_header_files@ @@ -11,12 +11,12 @@ namespace { using fragment_tag = @fragment_tag@; -using fragment_entry = StaticCubinFragmentEntry; +using fragment_entry = @fragment_entry_type@; } // namespace template <> -const uint8_t* const fragment_entry::data = embedded_cubin; +const uint8_t* const fragment_entry::data = @bin2c_symbol@; template <> -const size_t fragment_entry::length = sizeof(embedded_cubin); +const size_t fragment_entry::length = sizeof(@bin2c_symbol@); diff --git a/cpp/cmake/modules/register_tileir.cpp.in b/cpp/cmake/modules/register_tileir.cpp.in deleted file mode 100644 index fb81acedbc..0000000000 --- a/cpp/cmake/modules/register_tileir.cpp.in +++ /dev/null @@ -1,22 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -#include "@tileir_header_file@" -#include - -@fragment_tag_header_files@ - -namespace { - -using fragment_tag = @fragment_tag@; -using fragment_entry = StaticTileIrBytecodeFragmentEntry; - -} // namespace - -template <> -const uint8_t* const fragment_entry::data = embedded_tileir; - -template <> -const size_t fragment_entry::length = sizeof(embedded_tileir); diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json deleted file mode 100644 index fbd4bfdd64..0000000000 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_cubin_matrix.json +++ /dev/null @@ -1,40 +0,0 @@ -[ - { - "_data": [ - { - "data_type": "half", - "data_abbrev": "h" - }, - { - "data_type": "float", - "data_abbrev": "f" - } - ], - "_arch": [ - { - "gpu_code": "sm_80", - "cc_major": 8, - "cc_minor": 0, - "arch_tag": "cutile_arch_8_0" - }, - { - "gpu_code": "sm_86", - "cc_major": 8, - "cc_minor": 6, - "arch_tag": "cutile_arch_8_6" - }, - { - "gpu_code": "sm_90", - "cc_major": 9, - "cc_minor": 0, - "arch_tag": "cutile_arch_9_0" - }, - { - "gpu_code": "sm_120", - "cc_major": 12, - "cc_minor": 0, - "arch_tag": "cutile_arch_12_0" - } - ] - } -] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json new file mode 100644 index 0000000000..52955863c5 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json @@ -0,0 +1,64 @@ +[ + { + "_data": [ + { + "data_type": "half", + "data_abbrev": "h" + }, + { + "data_type": "float", + "data_abbrev": "f" + } + ], + "_export": [ + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_80", + "cc_major": 8, + "cc_minor": 0, + "arch_tag": 
"cutile_arch_8_0" + }, + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_86", + "cc_major": 8, + "cc_minor": 6, + "arch_tag": "cutile_arch_8_6" + }, + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_90", + "cc_major": 9, + "cc_minor": 0, + "arch_tag": "cutile_arch_9_0" + }, + { + "output_format": "cubin", + "artifact_ext": "cubin", + "artifact_basename": "@data_type@_@gpu_code@", + "register": "cubin", + "gpu_code": "sm_120", + "cc_major": 12, + "cc_minor": 0, + "arch_tag": "cutile_arch_12_0" + }, + { + "output_format": "tileir_bytecode", + "artifact_ext": "tilebc", + "artifact_basename": "@data_type@", + "register": "tileir", + "gpu_code": "sm_80", + "bytecode_version": "13.1" + } + ] + } +] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json deleted file mode 100644 index 364c94594c..0000000000 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_tileir_matrix.json +++ /dev/null @@ -1,20 +0,0 @@ -[ - { - "_data": [ - { - "data_type": "half", - "data_abbrev": "h" - }, - { - "data_type": "float", - "data_abbrev": "f" - } - ], - "_tileir": [ - { - "export_gpu_code": "sm_80", - "bytecode_version": "13.1" - } - ] - } -] From c7f7cbd30bcf408c4823606d7f93f9e0df1526cc Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 24 Jun 2026 21:10:59 +0000 Subject: [PATCH 04/10] working test, remove example --- .../cutile/fused_1nn_tile.cu | 110 ++++---- .../cutile/fused_1nn_tile.hpp | 15 +- cpp/tests/CMakeLists.txt | 3 - cpp/tests/cutile/CMakeLists.txt | 23 -- cpp/tests/cutile/cutile_vector_add.cu | 236 ------------------ cpp/tests/cutile/export_vector_add_cubin.py | 133 ---------- cpp/tests/cutile/generate_cutile_cubins.cmake | 117 
--------- cpp/tests/cutile/vector_add_kernel.py | 17 -- 8 files changed, 60 insertions(+), 594 deletions(-) delete mode 100644 cpp/tests/cutile/CMakeLists.txt delete mode 100644 cpp/tests/cutile/cutile_vector_add.cu delete mode 100644 cpp/tests/cutile/export_vector_add_cubin.py delete mode 100644 cpp/tests/cutile/generate_cutile_cubins.cmake delete mode 100644 cpp/tests/cutile/vector_add_kernel.py diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu index af8b0b181f..0ad4ee62a5 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu @@ -16,6 +16,8 @@ namespace detail { namespace { +constexpr int64_t TILE_M = 128; + template __global__ void pack_fused_1nn_kvp(OutT* out, const int64_t* idx, const float* dist, IdxT len) { @@ -27,13 +29,8 @@ __global__ void pack_fused_1nn_kvp(OutT* out, const int64_t* idx, const float* d } template -bool launch_fused_1nn_tile(const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - cudaStream_t stream) +bool launch_fused_1nn_tile( + const DataT* x, const DataT* y, OutT* out, IdxT m, IdxT n, IdxT k, cudaStream_t stream) { Fused1nnTilePlanner planner; planner.add_entrypoint(); @@ -46,67 +43,70 @@ bool launch_fused_1nn_tile(const DataT* x, RAFT_CUDA_TRY(cudaMallocAsync(&d_idx, m * sizeof(int64_t), stream)); RAFT_CUDA_TRY(cudaMallocAsync(&d_dist, m * sizeof(float), stream)); - int64_t shape_x[2] = {m, k}; - int64_t stride_x[2] = {k, 1}; - int64_t shape_y[2] = {n, k}; - int64_t stride_y[2] = {k, 1}; - int64_t shape_idx[1] = {m}; - int64_t stride_idx[1] = {1}; - int64_t shape_dist[1] = {m}; - int64_t stride_dist[1] = {1}; + int64_t shape_x[2] = {m, k}; + int64_t stride_x[2] = {k, 1}; + int64_t shape_y[2] = {n, k}; + int64_t stride_y[2] = {k, 1}; + int64_t shape_idx = m; + int64_t stride_idx = 1; + int64_t shape_dist = m; + int64_t 
stride_dist = 1; int64_t M = m, N = n, K = k; - constexpr int64_t tm = 128, tn = 256, tk = 64; void* x_ptr = const_cast(x); void* y_ptr = const_cast(y); void* idx_ptr = d_idx; void* dist_ptr = d_dist; - dim3 grid((m + tm - 1) / tm, 1, 1); + dim3 grid((m + TILE_M - 1) / TILE_M, 1, 1); dim3 block(1, 1, 1); + // cutile_python_v1 (see fused_1nn_float PTX): each 2D array is (ptr, shape0, shape1, + // stride0, stride1); each 1D array is (ptr, shape, stride); ConstantConstraint tile sizes + // are embedded in the module. using fused_1nn_cutile_kernel_t = void(void*, - int64_t*, - int64_t*, - void*, - int64_t*, - int64_t*, + int64_t, + int64_t, + int64_t, + int64_t, void*, - int64_t*, - int64_t*, + int64_t, + int64_t, + int64_t, + int64_t, void*, - int64_t*, - int64_t*, int64_t, int64_t, + void*, + int64_t, int64_t, int64_t, int64_t, int64_t); - launcher->template dispatch( - stream, - grid, - block, - 0, - x_ptr, - shape_x, - stride_x, - y_ptr, - shape_y, - stride_y, - idx_ptr, - shape_idx, - stride_idx, - dist_ptr, - shape_dist, - stride_dist, - M, - N, - K, - tm, - tn, - tk); + launcher->template dispatch(stream, + grid, + block, + 0, + x_ptr, + shape_x[0], + shape_x[1], + stride_x[0], + stride_x[1], + y_ptr, + shape_y[0], + shape_y[1], + stride_y[0], + stride_y[1], + idx_ptr, + shape_idx, + stride_idx, + dist_ptr, + shape_dist, + stride_dist, + M, + N, + K); pack_fused_1nn_kvp<<<(m + 255) / 256, 256, 0, stream>>>(out, d_idx, d_dist, m); RAFT_CUDA_TRY(cudaGetLastError()); @@ -148,13 +148,13 @@ using kvp_i64_f = raft::KeyValuePair; using kvp_i_h = raft::KeyValuePair; using kvp_i64_h = raft::KeyValuePair; -#define CUVS_INST_TRY_FUSED_1NN_TILE(DataT, OutT, IdxT) \ +#define CUVS_INST_TRY_FUSED_1NN_TILE(DataT, OutT, IdxT) \ template CUVS_EXPORT bool try_fused_1nn_tile(OutT*, \ - const DataT*, \ - const DataT*, \ - IdxT, \ - IdxT, \ - IdxT, \ + const DataT*, \ + const DataT*, \ + IdxT, \ + IdxT, \ + IdxT, \ cuvs::distance::DistanceType, \ cudaStream_t) diff --git 
a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp index 30f804d399..d72a020ba7 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp @@ -18,8 +18,9 @@ namespace detail { template inline constexpr bool is_fused_1nn_kvp_output_v = - std::is_same_v> || - std::is_same_v>; + (std::is_same_v || std::is_same_v) && + (std::is_same_v> || + std::is_same_v>); template , int> = 0> -bool try_fused_1nn_tile(OutT*, - const DataT*, - const DataT*, - IdxT, - IdxT, - IdxT, - cuvs::distance::DistanceType, - cudaStream_t) +bool try_fused_1nn_tile( + OutT*, const DataT*, const DataT*, IdxT, IdxT, IdxT, cuvs::distance::DistanceType, cudaStream_t) { return false; } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 006b35b5c4..9b96f94bf0 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -386,9 +386,6 @@ ConfigureTest( PERCENT 100 ) -# cuTile vector-add example test disabled; fused 1-NN cuTile is covered via libcuvs integration. -# add_subdirectory(cutile) - # ################################################################################################## # Install tests #################################################################################### # ################################################################################################## diff --git a/cpp/tests/cutile/CMakeLists.txt b/cpp/tests/cutile/CMakeLists.txt deleted file mode 100644 index 989c8137d0..0000000000 --- a/cpp/tests/cutile/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -# ============================================================================= -# cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
-# SPDX-License-Identifier: Apache-2.0 -# cmake-format: on -# ============================================================================= - -include("${CMAKE_CURRENT_LIST_DIR}/generate_cutile_cubins.cmake") - -generate_cutile_vector_add_cubins(CUTILE_GENERATED_INCLUDE_DIR) - -ConfigureTest( - NAME CUTILE_VECTOR_ADD_TEST - PATH "${CMAKE_CURRENT_LIST_DIR}/cutile_vector_add.cu" - GPUS 1 - PERCENT 100 -) - -add_dependencies(CUTILE_VECTOR_ADD_TEST cutile_vector_add_cubins) - -target_include_directories( - CUTILE_VECTOR_ADD_TEST PRIVATE "${CUTILE_GENERATED_INCLUDE_DIR}" -) diff --git a/cpp/tests/cutile/cutile_vector_add.cu b/cpp/tests/cutile/cutile_vector_add.cu deleted file mode 100644 index 07d694bef1..0000000000 --- a/cpp/tests/cutile/cutile_vector_add.cu +++ /dev/null @@ -1,236 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "../test_utils.cuh" - -#include "vector_add_kernel_symbol.h" -#include "vector_add_sm_100_cubin.h" -#include "vector_add_sm_120_cubin.h" -#include "vector_add_sm_80_cubin.h" -#include "vector_add_sm_86_cubin.h" -#include "vector_add_sm_90_cubin.h" -#include "vector_add_tileir_bytecode.h" - -#include - -#include - -#include -#include -#include - -namespace cuvs { -namespace { - -struct EmbeddedCubin { - int cc_major; - int cc_minor; - const unsigned char* data; - size_t size; -}; - -// Prebuilt cubins for known library targets (see export_vector_add_cubin.py). 
-constexpr EmbeddedCubin kEmbeddedCubins[] = { - {8, 0, vector_add_sm_80_cubin, sizeof(vector_add_sm_80_cubin)}, - {8, 6, vector_add_sm_86_cubin, sizeof(vector_add_sm_86_cubin)}, - {9, 0, vector_add_sm_90_cubin, sizeof(vector_add_sm_90_cubin)}, - {10, 0, vector_add_sm_100_cubin, sizeof(vector_add_sm_100_cubin)}, - {12, 0, vector_add_sm_120_cubin, sizeof(vector_add_sm_120_cubin)}, -}; - -constexpr EmbeddedCubin kTileIrBytecode = { - -1, - -1, - vector_add_tileir_bytecode, - sizeof(vector_add_tileir_bytecode), -}; - -struct CutileModuleImage { - const uint8_t* data; - size_t size; -}; - -std::optional resolve_vector_add_module(int cc_major, int cc_minor) -{ - for (const auto& entry : kEmbeddedCubins) { - if (entry.cc_major == cc_major && entry.cc_minor == cc_minor) { - return CutileModuleImage{reinterpret_cast(entry.data), entry.size}; - } - } - - int driver_version = 0; - if (cudaDriverGetVersion(&driver_version) != cudaSuccess) { return std::nullopt; } - if (!cuvs::detail::jit_lto::tileir_fallback_available(driver_version)) { - return std::nullopt; - } - return CutileModuleImage{ - reinterpret_cast(kTileIrBytecode.data), kTileIrBytecode.size}; -} - -struct LoadedKernel { - cudaLibrary_t library = nullptr; - cudaKernel_t kernel = nullptr; - bool used_tileir_jit{false}; - const char* skip_reason{nullptr}; - - LoadedKernel() = default; - - LoadedKernel(LoadedKernel&& other) noexcept { *this = std::move(other); } - - LoadedKernel& operator=(LoadedKernel&& other) noexcept - { - if (this != &other) { - unload(); - library = other.library; - kernel = other.kernel; - used_tileir_jit = other.used_tileir_jit; - skip_reason = other.skip_reason; - other.library = nullptr; - other.kernel = nullptr; - } - return *this; - } - - LoadedKernel(const LoadedKernel&) = delete; - LoadedKernel& operator=(const LoadedKernel&) = delete; - - ~LoadedKernel() { unload(); } - - explicit operator bool() const { return kernel != nullptr; } - - private: - void unload() - { - if (library != 
nullptr) { - RAFT_CUDA_TRY(cudaLibraryUnload(library)); - library = nullptr; - kernel = nullptr; - } - } -}; - -LoadedKernel load_vector_add_kernel(int cc_major, int cc_minor) -{ - LoadedKernel result{}; - result.used_tileir_jit = !cuvs::detail::jit_lto::is_embedded_cubin_arch(cc_major, cc_minor); - - auto image = resolve_vector_add_module(cc_major, cc_minor); - if (!image) { - if (result.used_tileir_jit) { - result.skip_reason = - "TileIR driver JIT unavailable for this GPU. Requires CUDA 13.1+ driver (>= 590.44)."; - } else { - ADD_FAILURE() << "No embedded cuTile module for compute capability " << cc_major << "." - << cc_minor; - } - return result; - } - - const cudaError_t load_status = - cudaLibraryLoadData(&result.library, image->data, nullptr, nullptr, 0, nullptr, nullptr, 0); - if (load_status != cudaSuccess) { - if (result.used_tileir_jit) { - result.skip_reason = - "TileIR driver JIT unavailable for this GPU (requires CUDA 13.1+ driver >= 590.44)."; - SCOPED_TRACE(cudaGetErrorString(load_status)); - } else { - ADD_FAILURE() << "cudaLibraryLoadData failed: " << cudaGetErrorString(load_status); - } - return result; - } - - const cudaError_t kernel_status = - cudaLibraryGetKernel(&result.kernel, result.library, CUTILE_VECTOR_ADD_KERNEL_SYMBOL); - if (kernel_status != cudaSuccess) { - if (result.library != nullptr) { - RAFT_CUDA_TRY(cudaLibraryUnload(result.library)); - result.library = nullptr; - } - result.kernel = nullptr; - if (result.used_tileir_jit) { - result.skip_reason = - "TileIR driver JIT unavailable for this GPU (requires CUDA 13.1+ driver >= 590.44)."; - SCOPED_TRACE(cudaGetErrorString(kernel_status)); - } else { - ADD_FAILURE() << "cudaLibraryGetKernel failed: " << cudaGetErrorString(kernel_status); - } - } - return result; -} - -void run_vector_add(cudaKernel_t kernel) -{ - constexpr int kN = 1024; - constexpr int kTile = 256; - constexpr int kGridDim = (kN + kTile - 1) / kTile; - - float *d_a = nullptr, *d_b = nullptr, *d_c = nullptr; - 
RAFT_CUDA_TRY(cudaMalloc(&d_a, kN * sizeof(float))); - RAFT_CUDA_TRY(cudaMalloc(&d_b, kN * sizeof(float))); - RAFT_CUDA_TRY(cudaMalloc(&d_c, kN * sizeof(float))); - - std::vector h_a(kN), h_b(kN); - for (int i = 0; i < kN; ++i) { - h_a[i] = static_cast(i); - h_b[i] = static_cast(i * 2); - } - RAFT_CUDA_TRY(cudaMemcpy(d_a, h_a.data(), kN * sizeof(float), cudaMemcpyHostToDevice)); - RAFT_CUDA_TRY(cudaMemcpy(d_b, h_b.data(), kN * sizeof(float), cudaMemcpyHostToDevice)); - RAFT_CUDA_TRY(cudaMemset(d_c, 0, kN * sizeof(float))); - - int64_t shape = kN; - int64_t stride = 1; - void* kernel_args[] = { - &d_a, &shape, &stride, &d_b, &shape, &stride, &d_c, &shape, &stride, - }; - - dim3 grid(kGridDim); - dim3 block(1); - ASSERT_EQ(cudaSuccess, cudaLaunchKernel(kernel, grid, block, kernel_args, 0, 0)) - << "cudaLaunchKernel failed: " << cudaGetErrorString(cudaGetLastError()); - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - - std::vector h_c(kN); - RAFT_CUDA_TRY(cudaMemcpy(h_c.data(), d_c, kN * sizeof(float), cudaMemcpyDeviceToHost)); - - for (int i = 0; i < kN; ++i) { - ASSERT_FLOAT_EQ(h_a[i] + h_b[i], h_c[i]) << "@" << i; - } - - RAFT_CUDA_TRY(cudaFree(d_a)); - RAFT_CUDA_TRY(cudaFree(d_b)); - RAFT_CUDA_TRY(cudaFree(d_c)); -} - -class CutileVectorAddTest : public ::testing::Test { - protected: - void SetUp() override - { - int device = 0; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - RAFT_CUDA_TRY( - cudaDeviceGetAttribute(&cc_major_, cudaDevAttrComputeCapabilityMajor, device)); - RAFT_CUDA_TRY( - cudaDeviceGetAttribute(&cc_minor_, cudaDevAttrComputeCapabilityMinor, device)); - } - - int cc_major_{}; - int cc_minor_{}; -}; - -} // namespace - -TEST_F(CutileVectorAddTest, EmbeddedCubinVectorAdd) -{ - LoadedKernel loaded = load_vector_add_kernel(cc_major_, cc_minor_); - if (loaded.skip_reason) { GTEST_SKIP() << loaded.skip_reason; } - if (!loaded) { return; } - - SCOPED_TRACE(loaded.used_tileir_jit ? 
"loaded via TileIR driver JIT" - : "loaded via prebuilt cubin"); - run_vector_add(loaded.kernel); -} - -} // namespace cuvs diff --git a/cpp/tests/cutile/export_vector_add_cubin.py b/cpp/tests/cutile/export_vector_add_cubin.py deleted file mode 100644 index fa099189cd..0000000000 --- a/cpp/tests/cutile/export_vector_add_cubin.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. -# SPDX-License-Identifier: Apache-2.0 -"""Export the cuTile vector-add kernel to cubin or TileIR bytecode.""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path -from typing import Literal - -import cuda.tile as ct -from cuda.tile.compilation import ( - ArrayConstraint, - CallingConvention, - ConstantConstraint, - KernelSignature, - export_kernel, -) - -from vector_add_kernel import TILE_SIZE, vector_add - -# cuTile / tileiras gpu_code values used at build time. These correspond to the -# cuvs library CUDA 13 real targets as follows (tileiras has no sm_*a/sm_*f names): -# sm_80 -> 80-real -# sm_86 -> 86-real -# sm_90 -> 90a-real -# sm_100 -> 100f-real -# sm_120 -> 120a-real -SUPPORTED_GPU_CODES = ("sm_80", "sm_86", "sm_90", "sm_100", "sm_120") - -# Minimum TileIR bytecode version supported by cuTile; also the most portable choice. 
-DEFAULT_TILEIR_BYTECODE_VERSION = "13.1" - - -def _kernel_signature() -> KernelSignature: - array = ArrayConstraint( - ct.float32, - 1, - index_dtype=ct.int64, - stride_lower_bound_incl=0, - alias_groups=(), - may_alias_internally=False, - stride_constant=(1,), - ) - return KernelSignature( - parameters=[array, array, array, ConstantConstraint(TILE_SIZE)], - calling_convention=CallingConvention.cutile_python_v1(), - ).with_mangled_symbol("vector_add") - - -def export_kernel_binary( - output_file: Path, - *, - output_format: Literal["cubin", "tileir_bytecode"], - gpu_code: str, - bytecode_version: str | None = None, - symbol_header: Path | None = None, -) -> str: - if output_format == "cubin" and gpu_code not in SUPPORTED_GPU_CODES: - raise ValueError( - f"Unsupported gpu_code {gpu_code!r}; expected one of {SUPPORTED_GPU_CODES}" - ) - - signature = _kernel_signature() - export_kwargs: dict = { - "kernel": vector_add, - "signatures": [signature], - "output_file": str(output_file), - "gpu_code": gpu_code, - "output_format": output_format, - } - if output_format == "tileir_bytecode": - export_kwargs["bytecode_version"] = bytecode_version or DEFAULT_TILEIR_BYTECODE_VERSION - - export_kernel(**export_kwargs) - - if symbol_header is not None: - symbol_header.write_text( - "\n".join( - [ - "// Generated by export_vector_add_cubin.py; do not edit.", - "#pragma once", - f'#define CUTILE_VECTOR_ADD_KERNEL_SYMBOL "{signature.symbol}"', - "", - ] - ) - ) - - return signature.symbol - - -def main() -> int: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("output_file", type=Path, help="Output cubin or .tilebc path") - parser.add_argument( - "--format", - choices=("cubin", "tileir_bytecode"), - default="cubin", - help="Export format (default: cubin)", - ) - parser.add_argument( - "--gpu-code", - required=True, - choices=SUPPORTED_GPU_CODES, - help="tileiras / export_kernel compile target (e.g. 
sm_120)", - ) - parser.add_argument( - "--bytecode-version", - default=DEFAULT_TILEIR_BYTECODE_VERSION, - help="TileIR bytecode version when --format=tileir_bytecode (default: 13.1)", - ) - parser.add_argument( - "--symbol-header", - type=Path, - default=None, - help="Optional header that defines CUTILE_VECTOR_ADD_KERNEL_SYMBOL", - ) - args = parser.parse_args() - - symbol = export_kernel_binary( - args.output_file, - output_format=args.format, - gpu_code=args.gpu_code, - bytecode_version=args.bytecode_version, - symbol_header=args.symbol_header, - ) - print(symbol) - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/cpp/tests/cutile/generate_cutile_cubins.cmake b/cpp/tests/cutile/generate_cutile_cubins.cmake deleted file mode 100644 index 766d3167c6..0000000000 --- a/cpp/tests/cutile/generate_cutile_cubins.cmake +++ /dev/null @@ -1,117 +0,0 @@ -# ============================================================================= -# cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. -# SPDX-License-Identifier: Apache-2.0 -# cmake-format: on -# ============================================================================= - -include_guard(GLOBAL) - -# Build-time cuTile cubin targets. Maps to cuvs CUDA 13 -real library arches (75-real omitted). -set(CUTILE_VECTOR_ADD_GPU_CODES sm_80 sm_86 sm_90 sm_100 sm_120) - -function(generate_cutile_vector_add_cubins output_include_dir_var) - find_package(Python3 REQUIRED COMPONENTS Interpreter) - find_package(CUDAToolkit REQUIRED) - - find_program( - CUTILE_BIN2C - NAMES bin2c - PATHS ${CUDAToolkit_BIN_DIR} - REQUIRED - ) - - execute_process( - COMMAND "${Python3_EXECUTABLE}" -c "import cuda.tile" - RESULT_VARIABLE _cutile_import_result - OUTPUT_QUIET - ERROR_QUIET - ) - if(NOT _cutile_import_result EQUAL 0) - message( - FATAL_ERROR - "cuda.tile (cuTile Python) is required to build CUTILE_VECTOR_ADD_TEST. " - "Install it in the active Python environment, e.g. 
pip install cuda-tile[tileiras]." - ) - endif() - - set(_cutile_source_dir "${CMAKE_CURRENT_FUNCTION_LIST_DIR}") - set(_cutile_binary_dir "${CMAKE_CURRENT_BINARY_DIR}/cutile_generated") - file(MAKE_DIRECTORY "${_cutile_binary_dir}") - - set(_symbol_header "${_cutile_binary_dir}/vector_add_kernel_symbol.h") - set(_first_gpu_code TRUE) - - foreach(_gpu_code IN LISTS CUTILE_VECTOR_ADD_GPU_CODES) - set(_cubin_file "${_cutile_binary_dir}/vector_add_${_gpu_code}.cubin") - set(_cubin_header "${_cutile_binary_dir}/vector_add_${_gpu_code}_cubin.h") - - if(_first_gpu_code) - set(_symbol_arg --symbol-header "${_symbol_header}") - set(_cubin_outputs "${_cubin_file}" "${_symbol_header}") - set(_first_gpu_code FALSE) - else() - set(_symbol_arg) - set(_cubin_outputs "${_cubin_file}") - endif() - - add_custom_command( - OUTPUT ${_cubin_outputs} - COMMAND - "${Python3_EXECUTABLE}" "${_cutile_source_dir}/export_vector_add_cubin.py" - "${_cubin_file}" --gpu-code "${_gpu_code}" ${_symbol_arg} - DEPENDS "${_cutile_source_dir}/export_vector_add_cubin.py" - "${_cutile_source_dir}/vector_add_kernel.py" - COMMENT "Exporting cuTile vector_add cubin for ${_gpu_code}" - VERBATIM - ) - - add_custom_command( - OUTPUT "${_cubin_header}" - COMMAND "${CUTILE_BIN2C}" --const --name "vector_add_${_gpu_code}_cubin" --static - "${_cubin_file}" > "${_cubin_header}" - DEPENDS "${_cubin_file}" - COMMENT "Embedding vector_add ${_gpu_code} cubin via bin2c" - VERBATIM - ) - - list(APPEND _generated_headers "${_cubin_header}") - endforeach() - - # Portable TileIR bytecode for driver JIT on architectures without a prebuilt cubin. - # Requires a CUDA 13.1+ driver (>= 590.44); see Tile IR bytecode docs. 
- set(_tileir_file "${_cutile_binary_dir}/vector_add.tilebc") - set(_tileir_header "${_cutile_binary_dir}/vector_add_tileir_bytecode.h") - - add_custom_command( - OUTPUT "${_tileir_file}" - COMMAND - "${Python3_EXECUTABLE}" "${_cutile_source_dir}/export_vector_add_cubin.py" - "${_tileir_file}" --format tileir_bytecode --gpu-code sm_80 --bytecode-version 13.1 - DEPENDS "${_cutile_source_dir}/export_vector_add_cubin.py" - "${_cutile_source_dir}/vector_add_kernel.py" - COMMENT "Exporting cuTile vector_add TileIR bytecode (v13.1)" - VERBATIM - ) - - add_custom_command( - OUTPUT "${_tileir_header}" - COMMAND "${CUTILE_BIN2C}" --const --name vector_add_tileir_bytecode --static "${_tileir_file}" - > "${_tileir_header}" - DEPENDS "${_tileir_file}" - COMMENT "Embedding vector_add TileIR bytecode via bin2c" - VERBATIM - ) - - list(APPEND _generated_headers "${_tileir_header}") - - add_custom_target( - cutile_vector_add_cubins - DEPENDS "${_symbol_header}" ${_generated_headers} - ) - - set(${output_include_dir_var} - "${_cutile_binary_dir}" - PARENT_SCOPE - ) -endfunction() diff --git a/cpp/tests/cutile/vector_add_kernel.py b/cpp/tests/cutile/vector_add_kernel.py deleted file mode 100644 index 46b7a607c6..0000000000 --- a/cpp/tests/cutile/vector_add_kernel.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
-# SPDX-License-Identifier: Apache-2.0 -"""cuTile Python vector-add kernel used by the embedded-cubin example test.""" - -from __future__ import annotations - -import cuda.tile as ct - -TILE_SIZE = 256 - - -@ct.kernel -def vector_add(a, b, c, TILE_SIZE: ct.Constant): - bid = ct.bid(0) - ta = ct.load(a, bid, TILE_SIZE) - tb = ct.load(b, bid, TILE_SIZE) - ct.store(c, bid, ta + tb) From 86c9311b87fae027be808b1a4d04f174e0bfbbb2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 24 Jun 2026 21:16:48 +0000 Subject: [PATCH 05/10] style check --- cpp/CMakeLists.txt | 47 +++++----- .../modules/generate_cutile_kernels.cmake | 86 ++++++++++--------- .../modules/register_cutile_fragment.cpp.in | 8 +- .../cuvs/detail/jit_lto/FragmentEntry.hpp | 5 +- .../cuvs/detail/jit_lto/tileir_compat.hpp | 4 +- .../detail/jit_lto/TileAlgorithmPlanner.cpp | 4 +- cpp/src/distance/detail/fused_distance_nn.cuh | 4 +- .../cutile/export_fused_1nn.py | 16 +++- .../cutile/fused_1nn_kernel.py | 30 ++++--- cpp/tests/neighbors/distance_nn_helper.cuh | 4 +- python/libcuvs/pyproject.toml | 1 + 11 files changed, 118 insertions(+), 91 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a1f3f3973c..70e2509a88 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -960,32 +960,38 @@ if(NOT BUILD_CPU_ONLY) include(cmake/modules/generate_cutile_kernels.cmake) set(fused_1nn_cutile_dir - "${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/fused_distance_nn/cutile") + "${CMAKE_CURRENT_SOURCE_DIR}/src/distance/detail/fused_distance_nn/cutile" + ) set(cutile_fused_1nn_generated_dir - "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/fused_1nn/cutile") + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/distance/fused_1nn/cutile" + ) generate_cutile_kernels( cutile_fused_1nn_files - KERNEL_DIR "${fused_1nn_cutile_dir}" - KERNEL_BASENAME "fused_1nn" - KERNEL_PYTHON "fused_1nn_kernel.py" - EXPORT_SCRIPT "export_fused_1nn.py" - OUTPUT_DIRECTORY "${cutile_fused_1nn_generated_dir}" - 
MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_matrix.json" + KERNEL_DIR + "${fused_1nn_cutile_dir}" + KERNEL_BASENAME + "fused_1nn" + KERNEL_PYTHON + "fused_1nn_kernel.py" + EXPORT_SCRIPT + "export_fused_1nn.py" + OUTPUT_DIRECTORY + "${cutile_fused_1nn_generated_dir}" + MATRIX_JSON_FILE + "${fused_1nn_cutile_dir}/fused_1nn_cutile_matrix.json" FRAGMENT_TAG_FORMAT_CUBIN - "cuvs::distance::detail::fragment_tag_fused_1nn_cubin" + "cuvs::distance::detail::fragment_tag_fused_1nn_cubin" FRAGMENT_TAG_FORMAT_TILEIR - "cuvs::distance::detail::fragment_tag_fused_1nn_tileir" + "cuvs::distance::detail::fragment_tag_fused_1nn_tileir" FRAGMENT_TAG_HEADER_FILES - "" - "" - "" + "" + "" + "" ) if(NOT DEFINED CUVS_CUTILE_ENABLED) set(CUVS_CUTILE_ENABLED 0) endif() - target_compile_definitions( - cuvs_cpp_headers INTERFACE CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED} - ) + target_compile_definitions(cuvs_cpp_headers INTERFACE CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED}) generate_inst_matrix( cagra_build_inst_files MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/cagra_build_matrix.json" @@ -1288,9 +1294,9 @@ if(NOT BUILD_CPU_ONLY) ) target_compile_definitions( - cuvs_objs PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> - $<$:NVTX_ENABLED> - CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED} + cuvs_objs + PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> CUVS_CUTILE_ENABLED=${CUVS_CUTILE_ENABLED} ) target_link_libraries( @@ -1308,8 +1314,7 @@ if(NOT BUILD_CPU_ONLY) PUBLIC "$" "$" INTERFACE "$" - PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src" - "${CMAKE_CURRENT_BINARY_DIR}/src" + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src" "${CMAKE_CURRENT_BINARY_DIR}/src" "${cutile_fused_1nn_generated_dir}" ) diff --git a/cpp/cmake/modules/generate_cutile_kernels.cmake b/cpp/cmake/modules/generate_cutile_kernels.cmake index f0219dc842..ac8d369cdc 100644 --- a/cpp/cmake/modules/generate_cutile_kernels.cmake +++ b/cpp/cmake/modules/generate_cutile_kernels.cmake @@ -10,7 +10,10 @@ 
include_guard(GLOBAL) include(${CMAKE_CURRENT_LIST_DIR}/compute_matrix_product.cmake) function(generate_cutile_kernels_stub) - set(CUVS_CUTILE_ENABLED 0 PARENT_SCOPE) + set(CUVS_CUTILE_ENABLED + 0 + PARENT_SCOPE + ) endfunction() function(_cutile_fragment_tag_header_files output_var) @@ -51,15 +54,13 @@ function(_cutile_kernels_setup) find_program( CUTILE_BIN2C NAMES bin2c - PATHS ${CUDAToolkit_BIN_DIR} - REQUIRED + PATHS ${CUDAToolkit_BIN_DIR} REQUIRED ) execute_process( COMMAND "${Python3_EXECUTABLE}" -c "import cuda.tile" RESULT_VARIABLE _cutile_import_result - OUTPUT_QUIET - ERROR_QUIET + OUTPUT_QUIET ERROR_QUIET ) if(NOT _cutile_import_result EQUAL 0) message( @@ -77,8 +78,14 @@ function(_cutile_kernels_setup) file(MAKE_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}") - set(Python3_EXECUTABLE "${Python3_EXECUTABLE}" PARENT_SCOPE) - set(CUTILE_BIN2C "${CUTILE_BIN2C}" PARENT_SCOPE) + set(Python3_EXECUTABLE + "${Python3_EXECUTABLE}" + PARENT_SCOPE + ) + set(CUTILE_BIN2C + "${CUTILE_BIN2C}" + PARENT_SCOPE + ) set(_CUTILE_SETUP_OK TRUE PARENT_SCOPE @@ -87,15 +94,8 @@ endfunction() function(process_cutile_matrix_entry source_list_var) set(options) - set(one_value - KERNEL_DIR - KERNEL_BASENAME - KERNEL_PYTHON - EXPORT_SCRIPT - OUTPUT_DIRECTORY - FRAGMENT_TAG_FORMAT_CUBIN - FRAGMENT_TAG_FORMAT_TILEIR - MATRIX_JSON_ENTRY + set(one_value KERNEL_DIR KERNEL_BASENAME KERNEL_PYTHON EXPORT_SCRIPT OUTPUT_DIRECTORY + FRAGMENT_TAG_FORMAT_CUBIN FRAGMENT_TAG_FORMAT_TILEIR MATRIX_JSON_ENTRY ) set(multi_value FRAGMENT_TAG_HEADER_FILES) cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) @@ -116,9 +116,7 @@ function(process_cutile_matrix_entry source_list_var) message(FATAL_ERROR "Unknown cuTile register kind '${register}'") endif() - _cutile_fragment_tag_header_files( - fragment_tag_header_files ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} - ) + _cutile_fragment_tag_header_files(fragment_tag_header_files ${_CUTILE_FRAGMENT_TAG_HEADER_FILES}) string(CONFIGURE 
"${artifact_basename}" _artifact_basename @ONLY) set(_artifact_stem "${_CUTILE_KERNEL_BASENAME}_${_artifact_basename}") @@ -145,8 +143,8 @@ function(process_cutile_matrix_entry source_list_var) add_custom_command( OUTPUT "${_embedded_header}" - COMMAND "${CUTILE_BIN2C}" --const --name ${bin2c_symbol} --static "${_artifact_file}" - > "${_embedded_header}" + COMMAND "${CUTILE_BIN2C}" --const --name ${bin2c_symbol} --static "${_artifact_file}" > + "${_embedded_header}" DEPENDS "${_artifact_file}" VERBATIM ) @@ -163,15 +161,8 @@ endfunction() function(generate_cutile_kernels source_list_var) set(options) - set(one_value - KERNEL_DIR - KERNEL_BASENAME - KERNEL_PYTHON - EXPORT_SCRIPT - OUTPUT_DIRECTORY - MATRIX_JSON_FILE - FRAGMENT_TAG_FORMAT_CUBIN - FRAGMENT_TAG_FORMAT_TILEIR + set(one_value KERNEL_DIR KERNEL_BASENAME KERNEL_PYTHON EXPORT_SCRIPT OUTPUT_DIRECTORY + MATRIX_JSON_FILE FRAGMENT_TAG_FORMAT_CUBIN FRAGMENT_TAG_FORMAT_TILEIR ) set(multi_value FRAGMENT_TAG_HEADER_FILES) cmake_parse_arguments(_CUTILE "${options}" "${one_value}" "${multi_value}" ${ARGN}) @@ -184,8 +175,7 @@ function(generate_cutile_kernels source_list_var) endif() _cutile_kernels_setup( - MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}" - OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" + MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}" OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" ) if(NOT _CUTILE_SETUP_OK) generate_cutile_kernels_stub() @@ -206,19 +196,31 @@ function(generate_cutile_kernels source_list_var) string(JSON matrix_json_entry GET "${matrix_product}" "${i}") process_cutile_matrix_entry( "${source_list_var}" - KERNEL_DIR "${_CUTILE_KERNEL_DIR}" - KERNEL_BASENAME "${_CUTILE_KERNEL_BASENAME}" - KERNEL_PYTHON "${_CUTILE_KERNEL_PYTHON}" - EXPORT_SCRIPT "${_CUTILE_EXPORT_SCRIPT}" - OUTPUT_DIRECTORY "${_CUTILE_OUTPUT_DIRECTORY}" - FRAGMENT_TAG_FORMAT_CUBIN "${_CUTILE_FRAGMENT_TAG_FORMAT_CUBIN}" - FRAGMENT_TAG_FORMAT_TILEIR "${_CUTILE_FRAGMENT_TAG_FORMAT_TILEIR}" - FRAGMENT_TAG_HEADER_FILES 
${_CUTILE_FRAGMENT_TAG_HEADER_FILES} - MATRIX_JSON_ENTRY "${matrix_json_entry}" + KERNEL_DIR + "${_CUTILE_KERNEL_DIR}" + KERNEL_BASENAME + "${_CUTILE_KERNEL_BASENAME}" + KERNEL_PYTHON + "${_CUTILE_KERNEL_PYTHON}" + EXPORT_SCRIPT + "${_CUTILE_EXPORT_SCRIPT}" + OUTPUT_DIRECTORY + "${_CUTILE_OUTPUT_DIRECTORY}" + FRAGMENT_TAG_FORMAT_CUBIN + "${_CUTILE_FRAGMENT_TAG_FORMAT_CUBIN}" + FRAGMENT_TAG_FORMAT_TILEIR + "${_CUTILE_FRAGMENT_TAG_FORMAT_TILEIR}" + FRAGMENT_TAG_HEADER_FILES + ${_CUTILE_FRAGMENT_TAG_HEADER_FILES} + MATRIX_JSON_ENTRY + "${matrix_json_entry}" ) endforeach() - set(CUVS_CUTILE_ENABLED 1 PARENT_SCOPE) + set(CUVS_CUTILE_ENABLED + 1 + PARENT_SCOPE + ) set(${source_list_var} "${${source_list_var}}" PARENT_SCOPE diff --git a/cpp/cmake/modules/register_cutile_fragment.cpp.in b/cpp/cmake/modules/register_cutile_fragment.cpp.in index 0fc074bdbb..3ffd5c0d0c 100644 --- a/cpp/cmake/modules/register_cutile_fragment.cpp.in +++ b/cpp/cmake/modules/register_cutile_fragment.cpp.in @@ -8,10 +8,10 @@ @fragment_tag_header_files@ -namespace { - -using fragment_tag = @fragment_tag@; -using fragment_entry = @fragment_entry_type@; + namespace +{ + using fragment_tag = @fragment_tag@; + using fragment_entry = @fragment_entry_type@; } // namespace diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp index df69ec1d7b..6c399d860a 100644 --- a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp +++ b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp @@ -115,7 +115,10 @@ struct StaticTileIrBytecodeFragmentEntry final : TileIrBytecodeFragmentEntry { return StaticTileIrBytecodeFragmentEntry::data; } - size_t get_length() const override { return StaticTileIrBytecodeFragmentEntry::length; } + size_t get_length() const override + { + return StaticTileIrBytecodeFragmentEntry::length; + } const char* get_key() const override { diff --git a/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp 
b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp index d63759fb36..f15407fd4c 100644 --- a/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp +++ b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp @@ -88,8 +88,8 @@ inline bool query_current_device_arch(int& cc_major, int& cc_minor) inline bool cutile_launch_available_on_current_device() { - int cc_major = 0; - int cc_minor = 0; + int cc_major = 0; + int cc_minor = 0; int driver_version = 0; if (!query_current_device_arch(cc_major, cc_minor)) { return false; } if (!query_driver_version(driver_version)) { return false; } diff --git a/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp index edb6269213..e0ce77e789 100644 --- a/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp +++ b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp @@ -23,9 +23,7 @@ std::shared_ptr TileAlgorithmPlanner::build() { int cc_major = 0; int cc_minor = 0; - if (!cuvs::detail::jit_lto::get_device_compute_capability(cc_major, cc_minor)) { - return nullptr; - } + if (!cuvs::detail::jit_lto::get_device_compute_capability(cc_major, cc_minor)) { return nullptr; } int driver_version = 0; if (cudaDriverGetVersion(&driver_version) != cudaSuccess) { return nullptr; } diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh index 8b47092b58..b1b18e58f6 100644 --- a/cpp/src/distance/detail/fused_distance_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -19,8 +19,8 @@ #include "fused_distance_nn/helper_structs.cuh" #include "fused_distance_nn/simt_kernel.cuh" #include "pairwise_distance_base.cuh" // PairwiseDistances -#include #include +#include #include // raft::KeyValuePair #include // raft::identity_op #include // Policy diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py index 6a20be24ef..10a4fa9ec1 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py @@ -100,7 +100,9 @@ def export_binary( "output_format": output_format, } if output_format == "tileir_bytecode": - export_kwargs["bytecode_version"] = bytecode_version or DEFAULT_TILEIR_BYTECODE_VERSION + export_kwargs["bytecode_version"] = ( + bytecode_version or DEFAULT_TILEIR_BYTECODE_VERSION + ) export_kernel(**export_kwargs) @@ -110,14 +112,20 @@ def export_binary( def main() -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("output_file", type=Path) - parser.add_argument("--format", choices=("cubin", "tileir_bytecode"), default="cubin") - parser.add_argument("--data-type", choices=tuple(KERNELS.keys()), required=True) + parser.add_argument( + "--format", choices=("cubin", "tileir_bytecode"), default="cubin" + ) + parser.add_argument( + "--data-type", choices=tuple(KERNELS.keys()), required=True + ) parser.add_argument( "--gpu-code", default=DEFAULT_TILEIR_EXPORT_GPU_CODE, help="Target SM for cubin export, or compile hint for TileIR bytecode export", ) - parser.add_argument("--bytecode-version", default=DEFAULT_TILEIR_BYTECODE_VERSION) + parser.add_argument( + "--bytecode-version", default=DEFAULT_TILEIR_BYTECODE_VERSION + ) args = parser.parse_args() print( diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py 
b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py index 232b9506af..65fe165b70 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py @@ -14,17 +14,23 @@ def _make_kernel(data_type: str): - if data_type == "half": - dtype = ct.float16 - acc_dtype = ct.float32 - elif data_type == "float": - dtype = ct.float32 - acc_dtype = ct.float32 - else: + if data_type not in ("half", "float"): raise ValueError(f"Unsupported data_type {data_type!r}") + acc_dtype = ct.float32 @ct.kernel - def fused_1nn_kernel(A, B, OutIdx, OutDist, M, N, K, tm: ConstInt, tn: ConstInt, tk: ConstInt): + def fused_1nn_kernel( + A, + B, + OutIdx, + OutDist, + M, + N, + K, + tm: ConstInt, + tn: ConstInt, + tk: ConstInt, + ): bidm = ct.bid(0) best_dist = ct.full((tm,), -3.4e38, acc_dtype) @@ -38,8 +44,12 @@ def fused_1nn_kernel(A, B, OutIdx, OutDist, M, N, K, tm: ConstInt, tn: ConstInt, accumulator = ct.full((tm, tn), 0, dtype=acc_dtype) for k in range(num_tiles_k): - a = ct.load(A, index=(bidm, k), shape=(tm, tk), padding_mode=zero_pad) - b_T = ct.load(B, index=(n, k), shape=(tn, tk), padding_mode=zero_pad) + a = ct.load( + A, index=(bidm, k), shape=(tm, tk), padding_mode=zero_pad + ) + b_T = ct.load( + B, index=(n, k), shape=(tn, tk), padding_mode=zero_pad + ) accumulator = ct.mma(a, ct.transpose(b_T), accumulator) curr_max = ct.max(accumulator, axis=1) diff --git a/cpp/tests/neighbors/distance_nn_helper.cuh b/cpp/tests/neighbors/distance_nn_helper.cuh index 422879918f..ea440387b4 100644 --- a/cpp/tests/neighbors/distance_nn_helper.cuh +++ b/cpp/tests/neighbors/distance_nn_helper.cuh @@ -91,8 +91,8 @@ RAFT_KERNEL ref_nn_kernel( if (metric == DistanceType::InnerProduct) { AccT score = inner_product_score(&A[m * K], &B[n * K], K); if (score > best_score) { - best_score = score; - best_index = n; + best_score = score; + best_index = n; } continue; } diff --git 
a/python/libcuvs/pyproject.toml b/python/libcuvs/pyproject.toml index 5025daa66d..b4e848304f 100644 --- a/python/libcuvs/pyproject.toml +++ b/python/libcuvs/pyproject.toml @@ -19,6 +19,7 @@ authors = [ license = "Apache-2.0" requires-python = ">=3.11" dependencies = [ + "cuda-tile[tileiras]", "cuda-toolkit[cublas,curand,cusolver,cusparse,nvrtc]==13.*", "libraft==26.8.*,>=0.0.0a0", "librmm==26.8.*,>=0.0.0a0", From e352629e5398e33dc780888f3b96c8c610f4dafd Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 25 Jun 2026 21:16:05 +0000 Subject: [PATCH 06/10] start integrating other metrics --- cpp/CMakeLists.txt | 4 +- .../modules/generate_cutile_kernels.cmake | 32 ++++- .../modules/register_cutile_fragment.cpp.in | 9 ++ .../cuvs/detail/jit_lto/AlgorithmPlanner.hpp | 3 + .../cuvs/detail/jit_lto/FragmentEntry.hpp | 39 ++++++ .../fused_distance_nn/fused_1nn_fragments.hpp | 70 +++++++++- .../detail/jit_lto/TileAlgorithmPlanner.cpp | 46 +++++++ cpp/src/distance/detail/fused_distance_nn.cuh | 2 +- .../cutile/export_fused_1nn.py | 97 +++++++++----- .../cutile/fused_1nn_cutile_matrix.json | 31 ++++- .../cutile/fused_1nn_kernel.py | 85 ++++++++++--- .../cutile/fused_1nn_planner.hpp | 56 +++++--- .../cutile/fused_1nn_tile.cu | 120 +++++++++++++----- .../cutile/fused_1nn_tile.hpp | 32 +++-- 14 files changed, 508 insertions(+), 118 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 70e2509a88..87716cd296 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -980,9 +980,9 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_matrix.json" FRAGMENT_TAG_FORMAT_CUBIN - "cuvs::distance::detail::fragment_tag_fused_1nn_cubin" + "cuvs::distance::detail::fragment_tag_fused_1nn_cubin, cuvs::detail::jit_lto::@arch_tag@>" FRAGMENT_TAG_FORMAT_TILEIR - "cuvs::distance::detail::fragment_tag_fused_1nn_tileir" + "cuvs::distance::detail::fragment_tag_fused_1nn_tileir>" FRAGMENT_TAG_HEADER_FILES "" "" diff --git 
a/cpp/cmake/modules/generate_cutile_kernels.cmake b/cpp/cmake/modules/generate_cutile_kernels.cmake index ac8d369cdc..abdca118ad 100644 --- a/cpp/cmake/modules/generate_cutile_kernels.cmake +++ b/cpp/cmake/modules/generate_cutile_kernels.cmake @@ -92,6 +92,30 @@ function(_cutile_kernels_setup) ) endfunction() +function(_cutile_generate_matrix_tiles_header header_path matrix_json_file) + file(READ "${matrix_json_file}" _matrix_json) + string(JSON _tile0 GET "${_matrix_json}" 0 "_tile" 0) + string(JSON _tile_m GET "${_tile0}" "tile_m") + string(JSON _tile_n GET "${_tile0}" "tile_n") + string(JSON _tile_k GET "${_tile0}" "tile_k") + file( + WRITE "${header_path}" + "/* + * Generated from ${matrix_json_file} by generate_cutile_kernels.cmake — do not edit. + */ +#pragma once + +#include + +namespace cuvs::distance::detail { + +using fused_1nn_matrix_tile = cutile_tile_config<${_tile_m}, ${_tile_n}, ${_tile_k}>; + +} // namespace cuvs::distance::detail +" + ) +endfunction() + function(process_cutile_matrix_entry source_list_var) set(options) set(one_value KERNEL_DIR KERNEL_BASENAME KERNEL_PYTHON EXPORT_SCRIPT OUTPUT_DIRECTORY @@ -125,7 +149,10 @@ function(process_cutile_matrix_entry source_list_var) set(_fragment_cpp "${_CUTILE_OUTPUT_DIRECTORY}/${_artifact_stem}_${register}.cpp") set(embedded_header_file "${_artifact_stem}_${register}.h") - set(_python_args --format "${output_format}" --data-type "${data_type}" --gpu-code "${gpu_code}") + set(_python_args + --format "${output_format}" --data-type "${data_type}" --metric "${metric}" --tile-m + "${tile_m}" --tile-n "${tile_n}" --tile-k "${tile_k}" --gpu-code "${gpu_code}" + ) if(DEFINED bytecode_version AND NOT "${bytecode_version}" STREQUAL "") list(APPEND _python_args --bytecode-version "${bytecode_version}") endif() @@ -188,6 +215,9 @@ function(generate_cutile_kernels source_list_var) compute_matrix_product(matrix_product MATRIX_JSON_FILE "${_CUTILE_MATRIX_JSON_FILE}") + set(_matrix_tiles_header 
"${_CUTILE_OUTPUT_DIRECTORY}/fused_1nn_cutile_tiles.hpp") + _cutile_generate_matrix_tiles_header("${_matrix_tiles_header}" "${_CUTILE_MATRIX_JSON_FILE}") + string(JSON len LENGTH "${matrix_product}") math(EXPR last "${len} - 1") diff --git a/cpp/cmake/modules/register_cutile_fragment.cpp.in b/cpp/cmake/modules/register_cutile_fragment.cpp.in index 3ffd5c0d0c..de0472a779 100644 --- a/cpp/cmake/modules/register_cutile_fragment.cpp.in +++ b/cpp/cmake/modules/register_cutile_fragment.cpp.in @@ -20,3 +20,12 @@ const uint8_t* const fragment_entry::data = @bin2c_symbol@; template <> const size_t fragment_entry::length = sizeof(@bin2c_symbol@); + +template <> +const int fragment_entry::tile_m = @tile_m@; + +template <> +const int fragment_entry::tile_n = @tile_n@; + +template <> +const int fragment_entry::tile_k = @tile_k@; diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp index d727c73b9d..7ff8487d20 100644 --- a/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp @@ -97,6 +97,9 @@ struct TileAlgorithmPlanner : AlgorithmPlanner { tileir_fragment_ = std::make_unique>(); } + /** Tile geometry from the cubin or TileIR fragment that would load on this device. */ + CutileTileConfig tile_config() const; + protected: std::vector> cubin_fragments_; std::unique_ptr tileir_fragment_; diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp index 6c399d860a..0961595f8d 100644 --- a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp +++ b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp @@ -63,6 +63,13 @@ struct UDFFatbinFragment final : FatbinFragmentEntry { std::vector bytes_; }; +/** cuTile GEMM-style block geometry embedded in generated Static*FragmentEntry specializations. 
*/ +struct CutileTileConfig { + int tile_m; + int tile_n; + int tile_k; +}; + /** Embedded CUDA binary module (cubin), loaded directly via cudaLibraryLoadData. */ struct CubinFragmentEntry { virtual ~CubinFragmentEntry() = default; @@ -76,6 +83,12 @@ struct CubinFragmentEntry { virtual int get_cc_major() const = 0; virtual int get_cc_minor() const = 0; + + virtual int get_tile_m() const { return 0; } + + virtual int get_tile_n() const { return 0; } + + virtual int get_tile_k() const { return 0; } }; template @@ -93,6 +106,16 @@ struct StaticCubinFragmentEntry final : CubinFragmentEntry { int get_cc_minor() const override { return FragmentTag::cc_minor; } + int get_tile_m() const override { return tile_m; } + + int get_tile_n() const override { return tile_n; } + + int get_tile_k() const override { return tile_k; } + + static const int tile_m; + static const int tile_n; + static const int tile_k; + static const uint8_t* const data; static const size_t length; }; @@ -106,6 +129,12 @@ struct TileIrBytecodeFragmentEntry { virtual size_t get_length() const = 0; virtual const char* get_key() const = 0; + + virtual int get_tile_m() const { return 0; } + + virtual int get_tile_n() const { return 0; } + + virtual int get_tile_k() const { return 0; } }; template @@ -125,6 +154,16 @@ struct StaticTileIrBytecodeFragmentEntry final : TileIrBytecodeFragmentEntry { return typeid(StaticTileIrBytecodeFragmentEntry).name(); } + int get_tile_m() const override { return tile_m; } + + int get_tile_n() const override { return tile_n; } + + int get_tile_k() const override { return tile_k; } + + static const int tile_m; + static const int tile_n; + static const int tile_k; + static const uint8_t* const data; static const size_t length; }; diff --git a/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp index 517118bbe2..658c6e882b 100644 --- 
a/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp @@ -6,16 +6,82 @@ #pragma once #include +#include namespace cuvs::distance::detail { -template +struct metric_tag_ip {}; +struct metric_tag_l2 {}; +struct metric_tag_cos {}; + +template +struct cutile_tile_config { + static constexpr int tile_m = TileM; + static constexpr int tile_n = TileN; + static constexpr int tile_k = TileK; +}; + +template +struct fused_1nn_metric_tag; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_ip; +}; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_l2; +}; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_l2; +}; + +template <> +struct fused_1nn_metric_tag { + using type = metric_tag_cos; +}; + +/** Whether sqrt is applied when packing distance into KVP output. */ +template +constexpr bool fused_1nn_apply_sqrt_at_pack(bool is_sqrt) +{ + if constexpr (Metric == cuvs::distance::DistanceType::L2Expanded || + Metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return is_sqrt; + } else { + return false; + } +} + +template +using fused_1nn_metric_tag_t = typename fused_1nn_metric_tag::type; + +template +struct fused_1nn_data_tag; + +template <> +struct fused_1nn_data_tag { + using type = cuvs::neighbors::detail::tag_f; +}; + +template <> +struct fused_1nn_data_tag { + using type = cuvs::neighbors::detail::tag_h; +}; + +template +using fused_1nn_data_tag_t = typename fused_1nn_data_tag::type; + +template struct fragment_tag_fused_1nn_cubin { static constexpr int cc_major = ArchTag::cc_major; static constexpr int cc_minor = ArchTag::cc_minor; }; -template +template struct fragment_tag_fused_1nn_tileir {}; } // namespace cuvs::distance::detail diff --git a/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp index e0ce77e789..1487abb239 100644 --- 
a/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp +++ b/cpp/src/detail/jit_lto/TileAlgorithmPlanner.cpp @@ -9,6 +9,31 @@ #include #include +#include +#include + +namespace { + +template +CutileTileConfig tile_config_from_fragment(const FragmentT* fragment, const std::string& entrypoint) +{ + if (fragment == nullptr) { + RAFT_FAIL("cuTile planner '%s' has no registered fragments", entrypoint.c_str()); + } + const int tile_m = fragment->get_tile_m(); + const int tile_n = fragment->get_tile_n(); + const int tile_k = fragment->get_tile_k(); + if (tile_m <= 0 || tile_n <= 0 || tile_k <= 0) { + RAFT_FAIL( + "cuTile planner '%s' is missing tile geometry in its static fragment (check " + "register_cutile_fragment.cpp generation)", + entrypoint.c_str()); + } + return CutileTileConfig{tile_m, tile_n, tile_k}; +} + +} // namespace + std::string TileAlgorithmPlanner::get_planner_key() const { std::string key = this->entrypoint; @@ -19,6 +44,27 @@ std::string TileAlgorithmPlanner::get_planner_key() const return key; } +CutileTileConfig TileAlgorithmPlanner::tile_config() const +{ + int cc_major = 0; + int cc_minor = 0; + if (cuvs::detail::jit_lto::get_device_compute_capability(cc_major, cc_minor)) { + for (const auto& fragment : cubin_fragments_) { + if (fragment->get_cc_major() == cc_major && fragment->get_cc_minor() == cc_minor) { + return tile_config_from_fragment(fragment.get(), entrypoint); + } + } + } + + if (tileir_fragment_) { return tile_config_from_fragment(tileir_fragment_.get(), entrypoint); } + + if (!cubin_fragments_.empty()) { + return tile_config_from_fragment(cubin_fragments_.front().get(), entrypoint); + } + + RAFT_FAIL("cuTile planner '%s' has no registered fragments", entrypoint.c_str()); +} + std::shared_ptr TileAlgorithmPlanner::build() { int cc_major = 0; diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh index b1b18e58f6..63ef5396ac 100644 --- a/cpp/src/distance/detail/fused_distance_nn.cuh +++ 
b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -64,7 +64,7 @@ void fusedDistanceNNImpl(OutT* min, #if CUVS_CUTILE_ENABLED if (cuvs::detail::jit_lto::cutile_launch_available_on_current_device() && - try_fused_1nn_tile(min, x, y, m, n, k, metric, stream)) { + try_fused_1nn_tile(min, x, y, xn, yn, m, n, k, metric, sqrt, stream)) { return; } #endif diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py index 10a4fa9ec1..fcefe7a027 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py @@ -19,7 +19,7 @@ export_kernel, ) -from fused_1nn_kernel import KERNELS, KERNEL_SYMBOLS, TILE_CONSTANTS +from fused_1nn_kernel import METRICS, kernel_symbol, make_kernel, metric_abbrev DEFAULT_TILEIR_BYTECODE_VERSION = "13.1" # cuTile requires a gpu_code even for TileIR bytecode export: it selects the compilation @@ -35,50 +35,75 @@ def _dtype_for(data_type: str): raise ValueError(f"Unsupported data_type {data_type!r}") -def _kernel_signature(data_type: str) -> KernelSignature: - elem = _dtype_for(data_type) - array = ArrayConstraint( - elem, - 2, - index_dtype=ct.int64, - stride_lower_bound_incl=0, - alias_groups=(), - may_alias_internally=False, - ) - idx_array = ArrayConstraint( - ct.int64, - 1, +def _data_abbrev(data_type: str) -> str: + return {"half": "h", "float": "f"}[data_type] + + +def _relaxed_matrix_constraint(elem_dtype): + """Array constraints matching the relaxed TMA-friendly layout from gemm_nn_cutile.""" + return ArrayConstraint( + elem_dtype, + ndim=2, index_dtype=ct.int64, - stride_lower_bound_incl=0, + stride_lower_bound_incl=(0, None), alias_groups=(), may_alias_internally=False, - stride_constant=(1,), + stride_constant=(None, 1), + stride_divisible_by=(8, 1), + shape_divisible_by=(1, 1), + base_addr_divisible_by=16, ) - dist_array = ArrayConstraint( - ct.float32, - 
1, + + +def _relaxed_vector_constraint(elem_dtype, *, tma_friendly: bool = False): + base_div = 16 if tma_friendly else 1 + return ArrayConstraint( + elem_dtype, + ndim=1, index_dtype=ct.int64, - stride_lower_bound_incl=0, + stride_lower_bound_incl=(None,), alias_groups=(), may_alias_internally=False, stride_constant=(1,), + stride_divisible_by=(1,), + shape_divisible_by=(1,), + base_addr_divisible_by=base_div, ) - tm, tn, tk = TILE_CONSTANTS + + +def _kernel_signature( + data_type: str, + metric: str, + tile_m: int, + tile_n: int, + tile_k: int, +) -> KernelSignature: + elem = _dtype_for(data_type) + matrix = _relaxed_matrix_constraint(elem) + norm_array = _relaxed_vector_constraint(elem, tma_friendly=True) + idx_array = _relaxed_vector_constraint(ct.int64) + dist_array = _relaxed_vector_constraint(ct.float32) + + abbrev = _data_abbrev(data_type) + symbol = kernel_symbol(abbrev, metric_abbrev(metric)) + return KernelSignature( parameters=[ - array, - array, + matrix, + matrix, + norm_array, + norm_array, idx_array, dist_array, ScalarConstraint(ct.int64), ScalarConstraint(ct.int64), ScalarConstraint(ct.int64), - ConstantConstraint(tm), - ConstantConstraint(tn), - ConstantConstraint(tk), + ConstantConstraint(tile_m), + ConstantConstraint(tile_n), + ConstantConstraint(tile_k), ], calling_convention=CallingConvention.cutile_python_v1(), - ).with_symbol(KERNEL_SYMBOLS[data_type]) + ).with_symbol(symbol) def export_binary( @@ -86,11 +111,15 @@ def export_binary( *, output_format: Literal["cubin", "tileir_bytecode"], data_type: str, + metric: str, + tile_m: int, + tile_n: int, + tile_k: int, gpu_code: str, bytecode_version: str | None = None, ) -> str: - kernel = KERNELS[data_type] - signature = _kernel_signature(data_type) + kernel = make_kernel(data_type, metric, tile_m, tile_n, tile_k) + signature = _kernel_signature(data_type, metric, tile_m, tile_n, tile_k) export_kwargs = { "kernel": kernel, @@ -116,8 +145,12 @@ def main() -> int: "--format", choices=("cubin", 
"tileir_bytecode"), default="cubin" ) parser.add_argument( - "--data-type", choices=tuple(KERNELS.keys()), required=True + "--data-type", choices=("half", "float"), required=True ) + parser.add_argument("--metric", choices=METRICS, required=True) + parser.add_argument("--tile-m", type=int, required=True) + parser.add_argument("--tile-n", type=int, required=True) + parser.add_argument("--tile-k", type=int, required=True) parser.add_argument( "--gpu-code", default=DEFAULT_TILEIR_EXPORT_GPU_CODE, @@ -133,6 +166,10 @@ def main() -> int: args.output_file, output_format=args.format, data_type=args.data_type, + metric=args.metric, + tile_m=args.tile_m, + tile_n=args.tile_n, + tile_k=args.tile_k, gpu_code=args.gpu_code, bytecode_version=args.bytecode_version, ) diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json index 52955863c5..3aa9dffd8a 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json @@ -10,11 +10,32 @@ "data_abbrev": "f" } ], + "_metric": [ + { + "metric": "inner_product", + "metric_abbrev": "ip" + }, + { + "metric": "l2_expanded", + "metric_abbrev": "l2" + }, + { + "metric": "cosine_expanded", + "metric_abbrev": "cos" + } + ], + "_tile": [ + { + "tile_m": 128, + "tile_n": 128, + "tile_k": 64 + } + ], "_export": [ { "output_format": "cubin", "artifact_ext": "cubin", - "artifact_basename": "@data_type@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_80", "cc_major": 8, @@ -24,7 +45,7 @@ { "output_format": "cubin", "artifact_ext": "cubin", - "artifact_basename": "@data_type@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_86", "cc_major": 8, @@ -34,7 +55,7 @@ { "output_format": "cubin", 
"artifact_ext": "cubin", - "artifact_basename": "@data_type@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_90", "cc_major": 9, @@ -44,7 +65,7 @@ { "output_format": "cubin", "artifact_ext": "cubin", - "artifact_basename": "@data_type@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_120", "cc_major": 12, @@ -54,7 +75,7 @@ { "output_format": "tileir_bytecode", "artifact_ext": "tilebc", - "artifact_basename": "@data_type@", + "artifact_basename": "@data_type@_@metric_abbrev@", "register": "tileir", "gpu_code": "sm_80", "bytecode_version": "13.1" diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py index 65fe165b70..162e6ceb6b 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 -"""cuTile fused GEMM + inner-product 1-NN (argmax dot product) for cuVS.""" +"""cuTile fused GEMM + 1-NN kernels (InnerProduct, L2Expanded, CosineExpanded).""" from __future__ import annotations @@ -8,20 +8,38 @@ ConstInt = ct.Constant[int] -TILE_M = 128 -TILE_N = 256 -TILE_K = 64 +# Default tile geometry; overridden per export via make_kernel(..., tile_m, tile_n, tile_k). 
+DEFAULT_TILE_M = 128 +DEFAULT_TILE_N = 128 +DEFAULT_TILE_K = 64 +METRICS = ("inner_product", "l2_expanded", "cosine_expanded") -def _make_kernel(data_type: str): + +def make_kernel( + data_type: str, + metric: str, + tile_m: int = DEFAULT_TILE_M, + tile_n: int = DEFAULT_TILE_N, + tile_k: int = DEFAULT_TILE_K, +): + """Build a cuTile kernel with metric and tile sizes baked in at compile time.""" if data_type not in ("half", "float"): raise ValueError(f"Unsupported data_type {data_type!r}") + if metric not in METRICS: + raise ValueError(f"Unsupported metric {metric!r}") + acc_dtype = ct.float32 + is_ip = metric == "inner_product" + is_l2 = metric == "l2_expanded" + is_cos = metric == "cosine_expanded" @ct.kernel def fused_1nn_kernel( A, B, + A_norm, + B_norm, OutIdx, OutDist, M, @@ -33,7 +51,10 @@ def fused_1nn_kernel( ): bidm = ct.bid(0) - best_dist = ct.full((tm,), -3.4e38, acc_dtype) + if is_ip: + best_dist = ct.full((tm,), -3.4e38, acc_dtype) + else: + best_dist = ct.full((tm,), 3.4e38, acc_dtype) best_idx = ct.zeros((tm,), ct.int64) num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) @@ -52,11 +73,37 @@ def fused_1nn_kernel( ) accumulator = ct.mma(a, ct.transpose(b_T), accumulator) - curr_max = ct.max(accumulator, axis=1) - curr_idx = ct.argmax(accumulator, axis=1) + if is_ip: + score = accumulator + elif is_l2 or is_cos: + a_norm = ct.load( + A_norm, index=(bidm,), shape=(tm,), padding_mode=zero_pad + ) + b_norm = ct.load( + B_norm, index=(n,), shape=(tn,), padding_mode=zero_pad + ) + if is_l2: + # L2 expanded: ||x||^2 + ||y||^2 - 2 * dot(x, y); norms are squared. + score = ( + a_norm[:, None] + b_norm[None, :] - (2.0 * accumulator) + ) + elif is_cos: + # Cosine expanded distance: 1 - dot / (||x|| * ||y||); norms are L2 (not squared). + # No sqrt during the reduction — only arithmetic on stored distance if needed. 
+ denom = a_norm[:, None] * b_norm[None, :] + score = 1.0 - (accumulator / denom) + + if is_ip: + curr_best = ct.max(score, axis=1) + curr_idx = ct.argmax(score, axis=1) + update = curr_best > best_dist + best_dist = ct.where(update, curr_best, best_dist) + else: + curr_best = ct.min(score, axis=1) + curr_idx = ct.argmin(score, axis=1) + update = curr_best < best_dist + best_dist = ct.where(update, curr_best, best_dist) - update = curr_max > best_dist - best_dist = ct.where(update, curr_max, best_dist) best_idx = ct.where(update, n * tn + curr_idx, best_idx) ct.store(OutIdx, index=(bidm,), tile=best_idx) @@ -65,14 +112,14 @@ def fused_1nn_kernel( return fused_1nn_kernel -KERNELS = { - "half": _make_kernel("half"), - "float": _make_kernel("float"), -} +def kernel_symbol(data_abbrev: str, metric_abbrev: str) -> str: + """Must stay in sync with fused_1nn_kernel_entrypoint() in fused_1nn_planner.hpp.""" + return f"fused_1nn_{data_abbrev}_{metric_abbrev}" -KERNEL_SYMBOLS = { - "half": "fused_1nn_half", - "float": "fused_1nn_float", -} -TILE_CONSTANTS = (TILE_M, TILE_N, TILE_K) +def metric_abbrev(metric: str) -> str: + return { + "inner_product": "ip", + "l2_expanded": "l2", + "cosine_expanded": "cos", + }[metric] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp index dd2a539528..ae0ae118bd 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp @@ -12,32 +12,49 @@ #include #include +#include "fused_1nn_cutile_tiles.hpp" + namespace cuvs::distance::detail { -/** Must match KERNEL_SYMBOLS in fused_1nn_kernel.py (export uses with_symbol). */ -template +/** Must match kernel_symbol() in fused_1nn_kernel.py (export uses with_symbol). 
*/ +template inline const char* fused_1nn_kernel_entrypoint() { - if constexpr (std::is_same_v) { - return "fused_1nn_half"; - } else if constexpr (std::is_same_v) { - return "fused_1nn_float"; - } else { - static_assert(sizeof(DataTag) == 0, "unsupported fused 1-NN cuTile data type"); - return ""; + if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + return "fused_1nn_f_ip"; + } else if constexpr (std::is_same_v) { + return "fused_1nn_f_l2"; + } else if constexpr (std::is_same_v) { + return "fused_1nn_f_cos"; + } + } else if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + return "fused_1nn_h_ip"; + } else if constexpr (std::is_same_v) { + return "fused_1nn_h_l2"; + } else if constexpr (std::is_same_v) { + return "fused_1nn_h_cos"; + } } + static_assert(sizeof(DataTag) == 0, "unsupported fused 1-NN cuTile data/metric combination"); + return ""; } -template +template struct Fused1nnTilePlanner : TileAlgorithmPlanner { + using DataTag = fused_1nn_data_tag_t; + using MetricTag = fused_1nn_metric_tag_t; + inline static LauncherJitCache launcher_jit_cache{}; Fused1nnTilePlanner() - : TileAlgorithmPlanner(fused_1nn_kernel_entrypoint(), launcher_jit_cache) + : TileAlgorithmPlanner(fused_1nn_kernel_entrypoint(), launcher_jit_cache) { } - /** Registers embedded cubin modules (one per SM); see register_cubin.cpp object files. */ + /** Registers embedded cubin modules (one per SM); see register_cutile_fragment.cpp object files. 
+ */ void add_entrypoint() { using cuvs::detail::jit_lto::cutile_arch_12_0; @@ -45,15 +62,20 @@ struct Fused1nnTilePlanner : TileAlgorithmPlanner { using cuvs::detail::jit_lto::cutile_arch_8_6; using cuvs::detail::jit_lto::cutile_arch_9_0; - this->add_static_fragment>(); - this->add_static_fragment>(); - this->add_static_fragment>(); - this->add_static_fragment>(); + this->add_static_fragment< + fragment_tag_fused_1nn_cubin>(); + this->add_static_fragment< + fragment_tag_fused_1nn_cubin>(); + this->add_static_fragment< + fragment_tag_fused_1nn_cubin>(); + this->add_static_fragment< + fragment_tag_fused_1nn_cubin>(); } void add_tileir_fallback() { - this->add_static_tileir_fragment>(); + this->add_static_tileir_fragment< + fragment_tag_fused_1nn_tileir>(); } }; diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu index 0ad4ee62a5..e343afca30 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu @@ -16,28 +16,42 @@ namespace detail { namespace { -constexpr int64_t TILE_M = 128; - template -__global__ void pack_fused_1nn_kvp(OutT* out, const int64_t* idx, const float* dist, IdxT len) +__global__ void pack_fused_1nn_kvp( + OutT* out, const int64_t* idx, const float* dist, IdxT len, bool apply_sqrt) { IdxT i = blockIdx.x * blockDim.x + threadIdx.x; if (i < len) { - out[i].key = static_cast(idx[i]); - out[i].value = static_cast(dist[i]); + out[i].key = static_cast(idx[i]); + float value = dist[i]; + if (apply_sqrt) { value = sqrtf(value); } + out[i].value = static_cast(value); } } -template -bool launch_fused_1nn_tile( - const DataT* x, const DataT* y, OutT* out, IdxT m, IdxT n, IdxT k, cudaStream_t stream) +template +bool launch_fused_1nn_tile(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + bool is_sqrt, + 
cudaStream_t stream) { - Fused1nnTilePlanner planner; + if constexpr (!std::is_same_v && !std::is_same_v) { return false; } + + Fused1nnTilePlanner planner; planner.add_entrypoint(); planner.add_tileir_fallback(); - auto launcher = planner.try_get_launcher(); + const CutileTileConfig tile_cfg = planner.tile_config(); + auto launcher = planner.try_get_launcher(); if (!launcher) { return false; } + const bool apply_sqrt = fused_1nn_apply_sqrt_at_pack(is_sqrt); + int64_t* d_idx = nullptr; float* d_dist = nullptr; RAFT_CUDA_TRY(cudaMallocAsync(&d_idx, m * sizeof(int64_t), stream)); @@ -47,6 +61,10 @@ bool launch_fused_1nn_tile( int64_t stride_x[2] = {k, 1}; int64_t shape_y[2] = {n, k}; int64_t stride_y[2] = {k, 1}; + int64_t shape_xn = m; + int64_t stride_xn = 1; + int64_t shape_yn = n; + int64_t stride_yn = 1; int64_t shape_idx = m; int64_t stride_idx = 1; int64_t shape_dist = m; @@ -56,15 +74,17 @@ bool launch_fused_1nn_tile( void* x_ptr = const_cast(x); void* y_ptr = const_cast(y); + void* xn_ptr = const_cast(xn); + void* yn_ptr = const_cast(yn); void* idx_ptr = d_idx; void* dist_ptr = d_dist; - dim3 grid((m + TILE_M - 1) / TILE_M, 1, 1); + const int64_t tile_m = tile_cfg.tile_m; + dim3 grid((m + tile_m - 1) / tile_m, 1, 1); dim3 block(1, 1, 1); - // cutile_python_v1 (see fused_1nn_float PTX): each 2D array is (ptr, shape0, shape1, - // stride0, stride1); each 1D array is (ptr, shape, stride); ConstantConstraint tile sizes - // are embedded in the module. + // cutile_python_v1: 2D array (ptr, shape0, shape1, stride0, stride1); + // 1D array (ptr, shape, stride); tile sizes are embedded constants. 
using fused_1nn_cutile_kernel_t = void(void*, int64_t, int64_t, @@ -81,6 +101,12 @@ bool launch_fused_1nn_tile( void*, int64_t, int64_t, + void*, + int64_t, + int64_t, + void*, + int64_t, + int64_t, int64_t, int64_t, int64_t); @@ -98,6 +124,12 @@ bool launch_fused_1nn_tile( shape_y[1], stride_y[0], stride_y[1], + xn_ptr, + shape_xn, + stride_xn, + yn_ptr, + shape_yn, + stride_yn, idx_ptr, shape_idx, stride_idx, @@ -108,39 +140,62 @@ bool launch_fused_1nn_tile( N, K); - pack_fused_1nn_kvp<<<(m + 255) / 256, 256, 0, stream>>>(out, d_idx, d_dist, m); + pack_fused_1nn_kvp + <<<(m + 255) / 256, 256, 0, stream>>>(out, d_idx, d_dist, m, apply_sqrt); RAFT_CUDA_TRY(cudaGetLastError()); RAFT_CUDA_TRY(cudaFreeAsync(d_idx, stream)); RAFT_CUDA_TRY(cudaFreeAsync(d_dist, stream)); return true; } +template +bool try_fused_1nn_tile_dispatch(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + bool is_sqrt, + cudaStream_t stream) +{ + switch (metric) { + case cuvs::distance::DistanceType::InnerProduct: + return launch_fused_1nn_tile( + x, y, xn, yn, min, m, n, k, is_sqrt, stream); + case cuvs::distance::DistanceType::L2Expanded: + return launch_fused_1nn_tile( + x, y, xn, yn, min, m, n, k, is_sqrt, stream); + case cuvs::distance::DistanceType::L2SqrtExpanded: + return launch_fused_1nn_tile( + x, y, xn, yn, min, m, n, k, is_sqrt, stream); + case cuvs::distance::DistanceType::CosineExpanded: + return launch_fused_1nn_tile( + x, y, xn, yn, min, m, n, k, is_sqrt, stream); + default: return false; + } +} + } // namespace -template , int>> +template + requires Fused1nnKvpOutput bool try_fused_1nn_tile(OutT* min, const DataT* x, const DataT* y, + const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, cuvs::distance::DistanceType metric, + bool is_sqrt, cudaStream_t stream) { - if (metric != cuvs::distance::DistanceType::InnerProduct) { return false; } - - if constexpr 
(std::is_same_v) { - return launch_fused_1nn_tile( - x, y, min, m, n, k, stream); - } else if constexpr (std::is_same_v) { - return launch_fused_1nn_tile( - x, y, min, m, n, k, stream); - } else { - return false; - } + return try_fused_1nn_tile_dispatch( + min, x, y, xn, yn, m, n, k, metric, is_sqrt, stream); } using kvp_i_f = raft::KeyValuePair; @@ -150,15 +205,18 @@ using kvp_i64_h = raft::KeyValuePair; #define CUVS_INST_TRY_FUSED_1NN_TILE(DataT, OutT, IdxT) \ template CUVS_EXPORT bool try_fused_1nn_tile(OutT*, \ + const DataT*, \ + const DataT*, \ const DataT*, \ const DataT*, \ IdxT, \ IdxT, \ IdxT, \ cuvs::distance::DistanceType, \ + bool, \ cudaStream_t) -// int and int32_t are the same on LP64; one instantiation covers both. +// int and int32_t are the same on LP64; one instantiation covers both. CUVS_INST_TRY_FUSED_1NN_TILE(float, kvp_i_f, int); CUVS_INST_TRY_FUSED_1NN_TILE(float, kvp_i64_f, int64_t); CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i_f, int); diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp index d72a020ba7..807ecdb233 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp @@ -5,6 +5,7 @@ #pragma once +#include #include #include @@ -22,25 +23,36 @@ inline constexpr bool is_fused_1nn_kvp_output_v = (std::is_same_v> || std::is_same_v>); -template , int> = 0> +template +concept Fused1nnKvpOutput = is_fused_1nn_kvp_output_v; + +template + requires Fused1nnKvpOutput bool try_fused_1nn_tile(OutT* min, const DataT* x, const DataT* y, + const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, cuvs::distance::DistanceType metric, + bool is_sqrt, cudaStream_t stream); -template , int> = 0> -bool try_fused_1nn_tile( - OutT*, const DataT*, const DataT*, IdxT, IdxT, IdxT, cuvs::distance::DistanceType, cudaStream_t) +template + requires(!Fused1nnKvpOutput) +bool
try_fused_1nn_tile(OutT*, + const DataT*, + const DataT*, + const DataT*, + const DataT*, + IdxT, + IdxT, + IdxT, + cuvs::distance::DistanceType, + bool, + cudaStream_t) { return false; } From d6560fca9e678c9eabcbcf519506c000f15c993d Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 25 Jun 2026 21:37:40 +0000 Subject: [PATCH 07/10] if constexpr exit --- .../cutile/fused_1nn_planner.hpp | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp index ae0ae118bd..c70ab3f87b 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp @@ -20,25 +20,28 @@ namespace cuvs::distance::detail { template inline const char* fused_1nn_kernel_entrypoint() { - if constexpr (std::is_same_v) { - if constexpr (std::is_same_v) { - return "fused_1nn_f_ip"; - } else if constexpr (std::is_same_v) { - return "fused_1nn_f_l2"; - } else if constexpr (std::is_same_v) { - return "fused_1nn_f_cos"; - } - } else if constexpr (std::is_same_v) { - if constexpr (std::is_same_v) { - return "fused_1nn_h_ip"; - } else if constexpr (std::is_same_v) { - return "fused_1nn_h_l2"; - } else if constexpr (std::is_same_v) { - return "fused_1nn_h_cos"; - } + if constexpr (std::is_same_v && + std::is_same_v) { + return "fused_1nn_f_ip"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return "fused_1nn_f_l2"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return "fused_1nn_f_cos"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return "fused_1nn_h_ip"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return "fused_1nn_h_l2"; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return "fused_1nn_h_cos"; + } else { + static_assert(sizeof(DataTag) == 0, "unsupported fused 
1-NN cuTile data/metric combination"); + return ""; } - static_assert(sizeof(DataTag) == 0, "unsupported fused 1-NN cuTile data/metric combination"); - return ""; } template From 674285321fd5c37d9d79aea86da4b9a9a670c1ff Mon Sep 17 00:00:00 2001 From: divyegala Date: Thu, 25 Jun 2026 22:50:48 +0000 Subject: [PATCH 08/10] passing KMeans tests --- .../cuvs/detail/jit_lto/tileir_compat.hpp | 11 ++- cpp/src/cluster/detail/kmeans_balanced.cuh | 85 ++++++++++++++----- cpp/src/cluster/detail/kmeans_common.cuh | 47 +++++++--- .../detail/minClusterDistanceCompute.cu | 44 ++++++---- cpp/src/distance/detail/fused_distance_nn.cuh | 27 +++--- .../cutile/fused_1nn_kernel.py | 10 +++ .../cutile/fused_1nn_tile.cu | 2 + .../cutile/fused_1nn_tile.hpp | 33 ++++++- 8 files changed, 187 insertions(+), 72 deletions(-) diff --git a/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp index f15407fd4c..f114233179 100644 --- a/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp +++ b/cpp/include/cuvs/detail/jit_lto/tileir_compat.hpp @@ -61,12 +61,16 @@ inline bool tileir_fallback_available(int driver_version) * is CUDA 13+, and either a matching embedded cubin exists (no driver JIT required) or the driver * can JIT the embedded TileIR bytecode fallback. 
*/ +#if CUVS_CUTILE_ENABLED inline bool cutile_launch_available_for_arch(int cc_major, int cc_minor, int driver_version) { - if (!cutile_integration_enabled()) { return false; } + if (!runtime_cuda13_or_newer()) { return false; } if (has_embedded_cubin_for_arch(cc_major, cc_minor)) { return true; } return tileir_fallback_available(driver_version); } +#else +inline constexpr bool cutile_launch_available_for_arch(int, int, int) { return false; } +#endif inline bool query_driver_version(int& driver_version) { @@ -86,6 +90,7 @@ inline bool query_current_device_arch(int& cc_major, int& cc_minor) return true; } +#if CUVS_CUTILE_ENABLED inline bool cutile_launch_available_on_current_device() { int cc_major = 0; @@ -95,5 +100,9 @@ inline bool cutile_launch_available_on_current_device() if (!query_driver_version(driver_version)) { return false; } return cutile_launch_available_for_arch(cc_major, cc_minor, driver_version); } +#else +/** Compile-time false when cuTile is not built; use in if constexpr to skip cuTile-only paths. 
*/ +inline constexpr bool cutile_launch_available_on_current_device() { return false; } +#endif } // namespace cuvs::detail::jit_lto diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 7fac255810..86e254f473 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -121,32 +121,63 @@ inline std::enable_if_t> predict_core( break; } case cuvs::distance::DistanceType::InnerProduct: { - // TODO: pass buffer - rmm::device_uvector distances(n_rows * n_clusters, stream, mr); + if (use_cutile_fused_nn(handle, n_rows, n_clusters, dim)) { + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream, mr); + rmm::device_uvector workspace(0, stream, mr); + + auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); + auto centroids_view = + raft::make_device_matrix_view(centers, n_clusters, dim); + auto X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); + + auto minClusterAndDistance = + raft::make_device_mdarray, IdxT>( + handle, mr, raft::make_extents(n_rows)); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + minClusterAndDistance.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, + 0, + workspace); + + raft::linalg::map(handle, + raft::make_const_mdspan(minClusterAndDistance.view()), + raft::make_device_vector_view(labels, n_rows), + raft::compose_op, raft::key_op>()); + } else { + rmm::device_uvector distances(n_rows * n_clusters, stream, mr); - MathT alpha = -1.0; - MathT beta = 0.0; + MathT alpha = -1.0; + MathT beta = 0.0; - raft::linalg::gemm(handle, - true, - false, - n_clusters, - n_rows, - dim, - &alpha, - centers, - dim, - dataset, - dim, - &beta, - distances.data(), - n_clusters, - stream); + raft::linalg::gemm(handle, + true, + false, + n_clusters, + n_rows, + dim, + &alpha, + centers, + dim, + dataset, + dim, + &beta, + distances.data(), + n_clusters, + stream); - auto 
distances_const_view = raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters); - auto labels_view = raft::make_device_vector_view(labels, n_rows); - raft::matrix::argmin(handle, distances_const_view, labels_view); + auto distances_const_view = + raft::make_device_matrix_view( + distances.data(), n_rows, n_clusters); + auto labels_view = raft::make_device_vector_view(labels, n_rows); + raft::matrix::argmin(handle, distances_const_view, labels_view); + } break; } default: { @@ -195,6 +226,14 @@ auto calc_minibatch_size(const raft::resources& handle, mem_per_row += sizeof(MathT) * n_clusters; } } break; + case distance::DistanceType::InnerProduct: { + if (use_cutile_fused_nn(handle, n_rows, n_clusters, dim)) { + mem_per_row += sizeof(int); + mem_per_row += sizeof(raft::KeyValuePair); + } else { + mem_per_row += sizeof(MathT) * n_clusters; + } + } break; // Other metrics require storing a distance matrix. default: { mem_per_row += sizeof(MathT) * n_clusters; diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index ba98dadca6..0606d77dec 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -7,6 +7,7 @@ #include "../../distance/distance.cuh" #include #include +#include #include #include @@ -57,31 +58,51 @@ namespace cuvs::cluster::kmeans::detail { +template +inline constexpr bool is_cutile_fused_data_type_v = + std::is_same_v || std::is_same_v; + /** - * @brief Returns true if the fused distance NN implementation should be used. + * @brief Returns true if the fused distance NN implementation should be used (CUTLASS and/or + * cuTile). + * + * Float/half: use fused whenever cuTile can launch (any architecture and problem size). If cuTile + * is unavailable, fall back to legacy CUTLASS fused on Ampere and Hopper only. Double and other + * types never use cuTile; they keep the historical CUTLASS/unfused heuristics on pre-Blackwell + * GPUs. 
* - * On Ampere (SM <= 8.x) always use fused. - * On Hopper (SM 9.x) use fused when m or n >= 4096. - * On Blackwell (SM >= 10.x) use unfused. + * Callers route through fusedDistanceNNMinReduce when this returns true; cuTile dispatch inside + * that API is gated separately by dtype (see fusedDistanceNNImpl). */ template bool use_fused(const raft::resources& handle, IdxT m, IdxT n, IdxT k) { + (void)k; cudaDeviceProp prop; prop = raft::resource::get_device_properties(handle); - if (prop.major <= 8) { - // Use fused for Ampere or before - return true; - } else if (prop.major == 9 && (m >= 4096 || n >= 4096)) { - // On Hopper if m, n are bigger than 4096, use fused - return true; - } else if (prop.major >= 10) { - // On Blackwell onwards, use unfused - return false; + + if constexpr (is_cutile_fused_data_type_v) { + if constexpr (cuvs::detail::jit_lto::library_built_with_cutile()) { + if (cuvs::detail::jit_lto::cutile_launch_available_on_current_device()) { return true; } + } + return prop.major <= 9; } + + if (prop.major >= 10) { return false; } + if (prop.major <= 8) { return true; } + if (prop.major == 9 && (m >= 4096 || n >= 4096)) { return true; } return false; } +/** True when assignment should use the cuTile fused 1-NN kernel (float/half only). 
*/ +template +bool use_cutile_fused_nn(const raft::resources& /*handle*/, IdxT /*m*/, IdxT /*n*/, IdxT /*k*/) +{ + if constexpr (!is_cutile_fused_data_type_v) { return false; } + if constexpr (!cuvs::detail::jit_lto::library_built_with_cutile()) { return false; } + return cuvs::detail::jit_lto::cutile_launch_available_on_current_device(); +} + template struct SamplingOp { DataT* rnd; diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index b15119599e..65678faa08 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -27,33 +27,41 @@ void minClusterAndDistanceCompute( int batch_centroids, rmm::device_uvector& workspace) { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded; - - if (is_fused) { + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = centroids.extent(0); + bool is_l2_cos_fused = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + const bool is_ip_cutile = + metric == cuvs::distance::DistanceType::InnerProduct && + use_cutile_fused_nn(handle, n_samples, n_clusters, n_features); + + if (is_l2_cos_fused || is_ip_cutile) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - if (metric == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm( - handle, centroids, centroidsNorm, raft::sqrt_op{}); - } 
else { - raft::linalg::norm( - handle, centroids, centroidsNorm); + if (is_l2_cos_fused) { + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, centroids, centroidsNorm, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, centroids, centroidsNorm); + } } raft::KeyValuePair initial_value(0, std::numeric_limits::max()); raft::matrix::fill(handle, minClusterAndDistance, initial_value); - bool should_use_fused = + const bool should_use_fused = use_fused(handle, n_samples, n_clusters, n_features); + auto centroidsNormConst = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + if (should_use_fused) { workspace.resize((sizeof(int)) * n_samples, stream); @@ -62,7 +70,7 @@ void minClusterAndDistanceCompute( X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), - centroidsNorm.data_handle(), + centroidsNormConst.data_handle(), n_samples, n_clusters, n_features, @@ -83,7 +91,7 @@ void minClusterAndDistanceCompute( X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), - centroidsNorm.data_handle(), + centroidsNormConst.data_handle(), n_samples, n_clusters, n_features, diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh index 63ef5396ac..476ab9c2be 100644 --- a/cpp/src/distance/detail/fused_distance_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -5,21 +5,14 @@ #pragma once -#ifndef CUVS_CUTILE_ENABLED -#define CUVS_CUTILE_ENABLED 0 -#endif - #include "distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op -#include "fused_distance_nn/cutlass_base.cuh" -#if CUVS_CUTILE_ENABLED #include "fused_distance_nn/cutile/fused_1nn_tile.hpp" -#endif +#include "fused_distance_nn/cutlass_base.cuh" #include "fused_distance_nn/fused_cosine_nn.cuh" #include "fused_distance_nn/fused_l2_nn.cuh" #include "fused_distance_nn/helper_structs.cuh" #include "fused_distance_nn/simt_kernel.cuh" #include "pairwise_distance_base.cuh" // 
PairwiseDistances -#include #include #include // raft::KeyValuePair #include // raft::identity_op @@ -62,12 +55,16 @@ void fusedDistanceNNImpl(OutT* min, // The kernel policy is determined by fusedDistanceNN. typedef Policy P; -#if CUVS_CUTILE_ENABLED - if (cuvs::detail::jit_lto::cutile_launch_available_on_current_device() && - try_fused_1nn_tile(min, x, y, xn, yn, m, n, k, metric, sqrt, stream)) { - return; + // Callers (e.g. use_fused) enable this API for CUTLASS fused as well as cuTile; only try cuTile + // for float/half KVP output so double and other types never instantiate cuTile symbols here. + if constexpr (is_fused_1nn_cutile_data_v) { + if constexpr (cuvs::detail::jit_lto::library_built_with_cutile() && + is_fused_1nn_kvp_output_v) { + if (try_fused_1nn_tile(min, x, y, xn, yn, m, n, k, metric, sqrt, stream)) { + return; + } + } } -#endif dim3 blk(P::Nthreads); auto nblks = raft::ceildiv(m, P::Nthreads); @@ -88,10 +85,12 @@ void fusedDistanceNNImpl(OutT* min, break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: - // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. fusedL2NNImpl( min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); break; + case cuvs::distance::DistanceType::InnerProduct: + // cuTile is the only fused InnerProduct implementation; callers must gate on availability. 
+ break; default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; } } diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py index 162e6ceb6b..7d78525869 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py @@ -93,6 +93,16 @@ def fused_1nn_kernel( denom = a_norm[:, None] * b_norm[None, :] score = 1.0 - (accumulator / denom) + # Only the final N-tile can include zero-padded centroid columns. + if n == num_tiles_n - 1: + col = ct.arange(tn, dtype=ct.int64) + global_col = n * tn + col + valid = global_col < N + if is_ip: + score = ct.where(valid[None, :], score, -3.4e38) + else: + score = ct.where(valid[None, :], score, 3.4e38) + if is_ip: curr_best = ct.max(score, axis=1) curr_idx = ct.argmax(score, axis=1) diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu index e343afca30..d292f0522b 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu @@ -110,6 +110,7 @@ bool launch_fused_1nn_tile(const DataT* x, int64_t, int64_t, int64_t); + std::cout << "Launching cuTile kernel" << std::endl; launcher->template dispatch(stream, grid, block, @@ -194,6 +195,7 @@ bool try_fused_1nn_tile(OutT* min, bool is_sqrt, cudaStream_t stream) { + if (!cuvs::detail::jit_lto::cutile_launch_available_on_current_device()) { return false; } return try_fused_1nn_tile_dispatch( min, x, y, xn, yn, m, n, k, metric, is_sqrt, stream); } diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp index 807ecdb233..563d2583d8 100644 --- 
a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp @@ -11,21 +11,30 @@ #include #include +#include #include +#ifndef CUVS_CUTILE_ENABLED +#define CUVS_CUTILE_ENABLED 0 +#endif + namespace cuvs { namespace distance { namespace detail { +template +inline constexpr bool is_fused_1nn_cutile_data_v = + std::is_same_v || std::is_same_v; + template inline constexpr bool is_fused_1nn_kvp_output_v = - (std::is_same_v || std::is_same_v) && - (std::is_same_v> || - std::is_same_v>); + is_fused_1nn_cutile_data_v && (std::is_same_v> || + std::is_same_v>); template concept Fused1nnKvpOutput = is_fused_1nn_kvp_output_v; +#if CUVS_CUTILE_ENABLED template requires Fused1nnKvpOutput bool try_fused_1nn_tile(OutT* min, @@ -39,6 +48,24 @@ bool try_fused_1nn_tile(OutT* min, cuvs::distance::DistanceType metric, bool is_sqrt, cudaStream_t stream); +#else +template + requires Fused1nnKvpOutput +bool try_fused_1nn_tile(OutT*, + const DataT*, + const DataT*, + const DataT*, + const DataT*, + IdxT, + IdxT, + IdxT, + cuvs::distance::DistanceType, + bool, + cudaStream_t) +{ + return false; +} +#endif template requires(!Fused1nnKvpOutput) From db6b38524cc260dffc2b8d73ec61ebd887f68e43 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 30 Jun 2026 04:12:33 +0000 Subject: [PATCH 09/10] undo kvp, add constrains and alignments to tile export --- cpp/CMakeLists.txt | 4 +- .../modules/generate_cutile_kernels.cmake | 18 +- .../fused_distance_nn/fused_1nn_fragments.hpp | 26 +- cpp/src/cluster/detail/kmeans.cuh | 67 ++--- cpp/src/cluster/detail/kmeans_balanced.cuh | 139 +++++++---- cpp/src/cluster/detail/kmeans_common.cuh | 207 ++++++++-------- cpp/src/cluster/detail/kmeans_mg.cuh | 48 ++-- cpp/src/cluster/detail/kmeans_mg_batched.cuh | 16 +- .../detail/minClusterDistanceCompute.cu | 234 ++++++++++-------- cpp/src/cluster/kmeans.cuh | 25 +- cpp/src/distance/detail/fused_distance_nn.cuh | 81 +++--- 
.../cutile/export_fused_1nn.py | 77 ++++-- .../cutile/fused_1nn_cutile_matrix.json | 26 +- .../cutile/fused_1nn_kernel.py | 61 +++-- .../cutile/fused_1nn_planner.hpp | 54 ++-- .../cutile/fused_1nn_tile.cu | 135 +++++----- .../cutile/fused_1nn_tile.hpp | 40 +-- .../fused_distance_nn/fused_cosine_nn.cuh | 37 ++- .../detail/fused_distance_nn/fused_l2_nn.cuh | 37 +-- .../fused_distance_nn/helper_structs.cuh | 105 ++++++-- .../predicated_tile_iterator_reduced_vec.h | 4 +- cpp/src/distance/fused_distance_nn-inl.cuh | 157 ++++-------- cpp/tests/CMakeLists.txt | 2 +- cpp/tests/neighbors/distance_nn.cu | 73 +++--- cpp/tests/neighbors/distance_nn_helper.cuh | 28 +++ 25 files changed, 950 insertions(+), 751 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 87716cd296..84979d05c1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -980,9 +980,9 @@ if(NOT BUILD_CPU_ONLY) MATRIX_JSON_FILE "${fused_1nn_cutile_dir}/fused_1nn_cutile_matrix.json" FRAGMENT_TAG_FORMAT_CUBIN - "cuvs::distance::detail::fragment_tag_fused_1nn_cubin, cuvs::detail::jit_lto::@arch_tag@>" + "cuvs::distance::detail::fragment_tag_fused_1nn_cubin, cuvs::detail::jit_lto::@arch_tag@>" FRAGMENT_TAG_FORMAT_TILEIR - "cuvs::distance::detail::fragment_tag_fused_1nn_tileir>" + "cuvs::distance::detail::fragment_tag_fused_1nn_tileir>" FRAGMENT_TAG_HEADER_FILES "" "" diff --git a/cpp/cmake/modules/generate_cutile_kernels.cmake b/cpp/cmake/modules/generate_cutile_kernels.cmake index abdca118ad..9cd8a207c8 100644 --- a/cpp/cmake/modules/generate_cutile_kernels.cmake +++ b/cpp/cmake/modules/generate_cutile_kernels.cmake @@ -150,8 +150,22 @@ function(process_cutile_matrix_entry source_list_var) set(embedded_header_file "${_artifact_stem}_${register}.h") set(_python_args - --format "${output_format}" --data-type "${data_type}" --metric "${metric}" --tile-m - "${tile_m}" --tile-n "${tile_n}" --tile-k "${tile_k}" --gpu-code "${gpu_code}" + --format + "${output_format}" + --data-type + "${data_type}" 
+ --metric + "${metric}" + --index-type + "${index_type}" + --tile-m + "${tile_m}" + --tile-n + "${tile_n}" + --tile-k + "${tile_k}" + --gpu-code + "${gpu_code}" ) if(DEFINED bytecode_version AND NOT "${bytecode_version}" STREQUAL "") list(APPEND _python_args --bytecode-version "${bytecode_version}") diff --git a/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp index 658c6e882b..c6afe16b5c 100644 --- a/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/fused_distance_nn/fused_1nn_fragments.hpp @@ -5,6 +5,8 @@ #pragma once +#include + #include #include @@ -75,13 +77,33 @@ struct fused_1nn_data_tag { template using fused_1nn_data_tag_t = typename fused_1nn_data_tag::type; -template +template +struct fused_1nn_index_tag; + +template <> +struct fused_1nn_index_tag { + using type = cuvs::neighbors::detail::tag_index_i32; +}; + +template <> +struct fused_1nn_index_tag { + using type = cuvs::neighbors::detail::tag_index_i64; +}; + +template +using fused_1nn_index_tag_t = typename fused_1nn_index_tag::type; + +template struct fragment_tag_fused_1nn_cubin { static constexpr int cc_major = ArchTag::cc_major; static constexpr int cc_minor = ArchTag::cc_minor; }; -template +template struct fragment_tag_fused_1nn_tileir {}; } // namespace cuvs::distance::detail diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 635e8813bd..757108312f 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -682,8 +682,8 @@ void kmeans_fit( DataT* cur_centroids_ptr = cur_centroids_buf.data(); DataT* new_centroids_ptr = new_centroids_buf.data(); - auto minClusterAndDistance = raft::make_device_vector, IndexT>( - handle, streaming_batch_size); + auto nearest_idx = raft::make_device_vector(handle, streaming_batch_size); + auto nearest_dist = 
raft::make_device_vector(handle, streaming_batch_size); auto L2NormBatch = raft::make_device_vector(handle, streaming_batch_size); auto batch_weights_buf = raft::make_device_vector(handle, streaming_batch_size); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); @@ -853,8 +853,10 @@ void kmeans_fit( auto batch_weights_view = cur_batch_weights(static_cast(data_batch.offset()), wt_data, cur_batch_size); - auto minCAD_view = raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle(), cur_batch_size); + auto nearest_idx_view = + raft::make_device_vector_view(nearest_idx.data_handle(), cur_batch_size); + auto nearest_dist_view = + raft::make_device_vector_view(nearest_dist.data_handle(), cur_batch_size); if constexpr (!data_on_device) { if (need_compute_norms) { @@ -883,7 +885,8 @@ void kmeans_fit( metric, iter_params.batch_samples, iter_params.batch_centroids, - minCAD_view, + nearest_idx_view, + nearest_dist_view, l2_const_view, L2NormBuf_OR_DistBuf, ws, @@ -1071,8 +1074,7 @@ void kmeans_predict(raft::resources const& handle, raft::make_const_mdspan(weight.view())); } - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); + auto nearest_dist = raft::make_device_vector(handle, n_samples); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // L2 norm of X: ||x||^2 @@ -1082,50 +1084,35 @@ void kmeans_predict(raft::resources const& handle, raft::linalg::norm(handle, X, L2NormX.view()); } - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' auto l2normx_view = raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - pams.metric, - 
pams.batch_samples, - pams.batch_centroids, - workspace); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute(handle, + X, + centroids, + labels, + nearest_dist.view(), + l2normx_view, + L2NormBuf_OR_DistBuf, + pams.metric, + pams.batch_samples, + pams.batch_centroids, + workspace); - // calculate cluster cost phi_x(C) rmm::device_scalar clusterCostD(stream); - raft::linalg::map( - handle, - minClusterAndDistance.view(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_const_mdspan(weight.view())); + raft::linalg::map(handle, + nearest_dist.view(), + raft::mul_op{}, + raft::make_const_mdspan(nearest_dist.view()), + raft::make_const_mdspan(weight.view())); cuvs::cluster::kmeans::detail::computeClusterCost( handle, - minClusterAndDistance.view(), + nearest_dist.view(), workspace, raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, + raft::identity_op{}, raft::add_op{}); - raft::linalg::map( - handle, labels, raft::key_op{}, raft::make_const_mdspan(minClusterAndDistance.view())); - inertia[0] = clusterCostD.value(stream); } diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 86e254f473..007c462247 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -98,58 +98,90 @@ inline std::enable_if_t> predict_core( raft::make_device_matrix_view(centers, n_clusters, dim); auto X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( - handle, mr, raft::make_extents(n_rows)); - - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( - handle, - X_view, - centroids_view, - minClusterAndDistance.view(), - X_norm_view, - L2NormBuf_OR_DistBuf, - params.metric, - 0, // batch_samples (unused for 
fused reduction) - 0, // batch_centroids (unused for fused reduction) - workspace); - - // Copy keys to output labels - raft::linalg::map(handle, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_device_vector_view(labels, n_rows), - raft::compose_op, raft::key_op>()); - break; - } - case cuvs::distance::DistanceType::InnerProduct: { - if (use_cutile_fused_nn(handle, n_rows, n_clusters, dim)) { - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream, mr); - rmm::device_uvector workspace(0, stream, mr); - - auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); - auto centroids_view = - raft::make_device_matrix_view(centers, n_clusters, dim); - auto X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); - - auto minClusterAndDistance = - raft::make_device_mdarray, IdxT>( - handle, mr, raft::make_extents(n_rows)); + auto nearest_dist = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + if constexpr (std::is_same_v) { + auto labels_view = raft::make_device_vector_view(labels, n_rows); cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, X_view, centroids_view, - minClusterAndDistance.view(), + labels_view, + nearest_dist.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, // batch_samples (unused for fused reduction) + 0, // batch_centroids (unused for fused reduction) + workspace); + } else { + auto nearest_idx = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + nearest_idx.view(), + nearest_dist.view(), X_norm_view, L2NormBuf_OR_DistBuf, params.metric, 0, 0, workspace); + raft::copy( + handle, raft::make_device_vector_view(labels, n_rows), nearest_idx.view()); + } + break; + } + case cuvs::distance::DistanceType::InnerProduct: { + if (uses_fused_distance_nn( + use_fused(handle, n_rows, n_clusters, dim, params.metric))) { + rmm::device_uvector 
L2NormBuf_OR_DistBuf(0, stream, mr); + rmm::device_uvector workspace(0, stream, mr); - raft::linalg::map(handle, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_device_vector_view(labels, n_rows), - raft::compose_op, raft::key_op>()); + auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); + auto centroids_view = + raft::make_device_matrix_view(centers, n_clusters, dim); + auto X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); + + auto nearest_dist = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + + if constexpr (std::is_same_v) { + auto labels_view = raft::make_device_vector_view(labels, n_rows); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + labels_view, + nearest_dist.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, + 0, + workspace); + } else { + auto nearest_idx = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows)); + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + X_view, + centroids_view, + nearest_idx.view(), + nearest_dist.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, + params.metric, + 0, + 0, + workspace); + raft::copy(handle, + raft::make_device_vector_view(labels, n_rows), + nearest_idx.view()); + } } else { rmm::device_uvector distances(n_rows * n_clusters, stream, mr); @@ -216,22 +248,19 @@ auto calc_minibatch_size(const raft::resources& handle, size_t mem_per_row = 0; switch (metric) { case distance::DistanceType::L2Expanded: - case distance::DistanceType::L2SqrtExpanded: { - if (use_fused(handle, n_rows, n_clusters, dim)) { - // fusedL2NN needs a mutex and a key-value pair for each row. - mem_per_row += sizeof(int); - mem_per_row += sizeof(raft::KeyValuePair); - } else { - // unfused path needs a full GEMM output (distance matrix row). 
- mem_per_row += sizeof(MathT) * n_clusters; - } - } break; + case distance::DistanceType::L2SqrtExpanded: case distance::DistanceType::InnerProduct: { - if (use_cutile_fused_nn(handle, n_rows, n_clusters, dim)) { - mem_per_row += sizeof(int); - mem_per_row += sizeof(raft::KeyValuePair); - } else { - mem_per_row += sizeof(MathT) * n_clusters; + switch (use_fused(handle, n_rows, n_clusters, dim, metric)) { + case FusedDistancePath::FusedCutile: break; + case FusedDistancePath::FusedCutlass: + // fusedDistanceNNMinReduce CUTLASS fallback: mutex workspace + scratch KVP per row. + mem_per_row += sizeof(int); + mem_per_row += sizeof(raft::KeyValuePair); + break; + case FusedDistancePath::Unfused: + // unfused / GEMM+argmin path needs a full distance matrix row. + mem_per_row += sizeof(MathT) * n_clusters; + break; } } break; // Other metrics require storing a distance matrix. diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 0606d77dec..f0d9bde801 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -62,20 +62,42 @@ template inline constexpr bool is_cutile_fused_data_type_v = std::is_same_v || std::is_same_v; +/** Which fused-distance implementation minCluster* will use (or Unfused). */ +enum class FusedDistancePath : std::uint8_t { + /** unfusedDistanceNNMinReduce or batched pairwise distance. */ + Unfused = 0, + /** fusedDistanceNNMinReduce via cuTile; no CUTLASS mutex / KVP scratch. */ + FusedCutile, + /** fusedDistanceNNMinReduce via legacy CUTLASS; needs mutex workspace + KVP scratch. 
*/ + FusedCutlass, +}; + +inline constexpr bool uses_fused_distance_nn(FusedDistancePath path) +{ + return path != FusedDistancePath::Unfused; +} + +inline constexpr bool needs_cutlass_kvp_scratch(FusedDistancePath path) +{ + return path == FusedDistancePath::FusedCutlass; +} + +inline constexpr bool needs_fused_mutex_workspace(FusedDistancePath path) +{ + return path == FusedDistancePath::FusedCutlass; +} + /** - * @brief Returns true if the fused distance NN implementation should be used (CUTLASS and/or - * cuTile). + * @brief Selects the fused-distance assignment path for KMeans. * - * Float/half: use fused whenever cuTile can launch (any architecture and problem size). If cuTile - * is unavailable, fall back to legacy CUTLASS fused on Ampere and Hopper only. Double and other - * types never use cuTile; they keep the historical CUTLASS/unfused heuristics on pre-Blackwell + * Float/half: cuTile when the build and device support it. Otherwise L2/L2Sqrt/Cosine may use + * legacy CUTLASS fused on Ampere/Hopper (large enough problems). InnerProduct without cuTile uses + * Unfused. Double never uses cuTile; keeps historical CUTLASS/unfused heuristics on pre-Blackwell * GPUs. - * - * Callers route through fusedDistanceNNMinReduce when this returns true; cuTile dispatch inside - * that API is gated separately by dtype (see fusedDistanceNNImpl). 
*/ template -bool use_fused(const raft::resources& handle, IdxT m, IdxT n, IdxT k) +FusedDistancePath use_fused( + const raft::resources& handle, IdxT m, IdxT n, IdxT k, cuvs::distance::DistanceType metric) { (void)k; cudaDeviceProp prop; @@ -83,24 +105,20 @@ bool use_fused(const raft::resources& handle, IdxT m, IdxT n, IdxT k) if constexpr (is_cutile_fused_data_type_v) { if constexpr (cuvs::detail::jit_lto::library_built_with_cutile()) { - if (cuvs::detail::jit_lto::cutile_launch_available_on_current_device()) { return true; } + if (cuvs::detail::jit_lto::cutile_launch_available_on_current_device()) { + return FusedDistancePath::FusedCutile; + } } - return prop.major <= 9; + if (metric == cuvs::distance::DistanceType::InnerProduct) { return FusedDistancePath::Unfused; } + if (prop.major <= 8) { return FusedDistancePath::FusedCutlass; } + if (prop.major == 9 && (m >= 4096 || n >= 4096)) { return FusedDistancePath::FusedCutlass; } + return FusedDistancePath::Unfused; } - if (prop.major >= 10) { return false; } - if (prop.major <= 8) { return true; } - if (prop.major == 9 && (m >= 4096 || n >= 4096)) { return true; } - return false; -} - -/** True when assignment should use the cuTile fused 1-NN kernel (float/half only). 
*/ -template -bool use_cutile_fused_nn(const raft::resources& /*handle*/, IdxT /*m*/, IdxT /*n*/, IdxT /*k*/) -{ - if constexpr (!is_cutile_fused_data_type_v) { return false; } - if constexpr (!cuvs::detail::jit_lto::library_built_with_cutile()) { return false; } - return cuvs::detail::jit_lto::cutile_launch_available_on_current_device(); + if (prop.major >= 10) { return FusedDistancePath::Unfused; } + if (prop.major <= 8) { return FusedDistancePath::FusedCutlass; } + if (prop.major == 9 && (m >= 4096 || n >= 4096)) { return FusedDistancePath::FusedCutlass; } + return FusedDistancePath::Unfused; } template @@ -391,33 +409,32 @@ void shuffleAndGather(raft::resources const& handle, stream); } -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' +// Calculates nearest centroid index and distance for every sample in input 'X'. template -void minClusterAndDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace); - -#define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ - extern template void minClusterAndDistanceCompute( \ - raft::resources const& handle, \ - raft::device_matrix_view X, \ - raft::device_matrix_view centroids, \ - raft::device_vector_view, IndexT> minClusterAndDistance, \ - raft::device_vector_view L2NormX, \ - rmm::device_uvector& L2NormBuf_OR_DistBuf, \ - cuvs::distance::DistanceType metric, \ - int batch_samples, \ - int batch_centroids, \ +void minClusterAndDistanceCompute(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, 
+ raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace); + +#define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ + extern template void minClusterAndDistanceCompute( \ + raft::resources const& handle, \ + raft::device_matrix_view X, \ + raft::device_matrix_view centroids, \ + raft::device_vector_view nearest_idx, \ + raft::device_vector_view nearest_dist, \ + raft::device_vector_view L2NormX, \ + rmm::device_uvector& L2NormBuf_OR_DistBuf, \ + cuvs::distance::DistanceType metric, \ + int batch_samples, \ + int batch_centroids, \ rmm::device_uvector& workspace); EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) @@ -476,22 +493,16 @@ void countSamplesInCluster(raft::resources const& handle, // stores (key, value) pair corresponding to each sample where // - key is the index of nearest cluster // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store distance matrix, destructor releases the resource + auto nearest_idx = raft::make_device_vector(handle, n_samples); + auto nearest_dist = raft::make_device_vector(handle, n_samples); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, X, (raft::device_matrix_view)centroids, - minClusterAndDistance.view(), + nearest_idx.view(), + nearest_dist.view(), L2NormX, L2NormBuf_OR_DistBuf, params.metric, @@ -499,12 +510,8 @@ void 
countSamplesInCluster(raft::resources const& handle, params.batch_centroids, workspace); - cuda::transform_iterator itr(minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}); - - // count # of samples in each cluster countLabels(handle, - itr, + nearest_idx.data_handle(), sampleCountInCluster.data_handle(), (IndexT)n_samples, (IndexT)n_clusters, @@ -689,7 +696,8 @@ __device__ void check_convergence(raft::device_scalar_view clusteri * @param[in] batch_samples_param Batch-samples param forwarded to minClusterAndDistanceCompute * @param[in] batch_centroids_param Batch-centroids param forwarded to * minClusterAndDistanceCompute - * @param[inout] minClusterAndDistance Work buffer [batch_size] + * @param[inout] nearest_idx Nearest cluster index per sample [batch_size] + * @param[inout] nearest_dist Nearest distance per sample [batch_size] * @param[in] L2NormBatch Precomputed data norms [batch_size] * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch * @param[inout] workspace Resizable scratch @@ -698,29 +706,30 @@ __device__ void check_convergence(raft::device_scalar_view clusteri * @param[inout] clustering_cost Running cost scalar (device) (added into) */ template -void process_batch( - raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view batch_weights, - raft::device_matrix_view centroids, - cuvs::distance::DistanceType metric, - int batch_samples_param, - int batch_centroids_param, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormBatch, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace, - raft::device_matrix_view centroid_sums, - raft::device_vector_view weight_per_cluster, - raft::device_scalar_view clustering_cost, - rmm::device_uvector& batch_workspace) +void process_batch(raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view batch_weights, + raft::device_matrix_view 
centroids, + cuvs::distance::DistanceType metric, + int batch_samples_param, + int batch_centroids_param, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormBatch, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + raft::device_scalar_view clustering_cost, + rmm::device_uvector& batch_workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); minClusterAndDistanceCompute(handle, batch_data, centroids, - minClusterAndDistance, + nearest_idx, + nearest_dist, L2NormBatch, L2NormBuf_OR_DistBuf, metric, @@ -728,36 +737,30 @@ void process_batch( batch_centroids_param, workspace); - KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> - labels_itr(minClusterAndDistance.data_handle(), conversion_op); - compute_centroid_adjustments(handle, batch_data, batch_weights, - labels_itr, + nearest_idx.data_handle(), static_cast(centroid_sums.extent(0)), centroid_sums, weight_per_cluster, batch_workspace, /*reset_sums=*/false); - raft::linalg::map( - handle, - minClusterAndDistance, - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }, - raft::make_const_mdspan(minClusterAndDistance), - batch_weights); + auto weighted_dist = raft::make_device_vector(handle, nearest_dist.extent(0)); + raft::linalg::map(handle, + weighted_dist.view(), + raft::mul_op{}, + raft::make_const_mdspan(nearest_dist), + raft::make_const_mdspan(batch_weights)); auto batch_cost = raft::make_device_scalar(handle, DataT{0}); - computeClusterCost( - handle, minClusterAndDistance, workspace, batch_cost.view(), raft::value_op{}, raft::add_op{}); + computeClusterCost(handle, + weighted_dist.view(), + workspace, + batch_cost.view(), + raft::identity_op{}, + raft::add_op{}); 
raft::linalg::add(clustering_cost.data_handle(), clustering_cost.data_handle(), batch_cost.data_handle(), diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index ce3ca5a1fe..955d51c2a9 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -539,11 +539,8 @@ void fit(const raft::resources& handle, THROW("unknown initialization method to select initial centers"); } - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); + auto nearest_idx = raft::make_device_vector(handle, n_samples); + auto nearest_dist = raft::make_device_vector(handle, n_samples); // temporary buffer to store L2 norm of centroids or distance matrix, // destructor releases the resource @@ -577,15 +574,11 @@ void fit(const raft::resources& handle, auto const_centroids = raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' cuvs::cluster::kmeans::min_cluster_and_distance(handle, X, const_centroids, - minClusterAndDistance.view(), + nearest_idx.view(), + nearest_dist.view(), L2NormX.view(), L2NormBuf_OR_DistBuf, params.metric, @@ -595,9 +588,7 @@ void fit(const raft::resources& handle, workspace.resize(n_samples, stream); - cuda::transform_iterator keys_itr( - minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}); + const IndexT* keys_itr = nearest_idx.data_handle(); raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), X.extent(1), keys_itr, @@ -696,35 +687,24 @@ void fit(const 
raft::resources& handle, raft::make_device_vector_view(centroids.data_handle(), newCentroids.size()), raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); - bool done = false; - rmm::device_scalar> clusterCostD(stream); + bool done = false; + auto clusterCostD = raft::make_device_scalar(handle, DataT{0}); // calculate cluster cost phi_x(C) cuvs::cluster::kmeans::cluster_cost( handle, - minClusterAndDistance.view(), + nearest_dist.view(), workspace, - raft::make_device_scalar_view(clusterCostD.data()), - cuda::proclaim_return_type>( - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - })); + clusterCostD.view(), + cuda::proclaim_return_type( + [] __device__(const DataT& a, const DataT& b) { return a + b; })); // Cluster cost phi_x(C) from all ranks - comm.allreduce(&(clusterCostD.data()->value), - &(clusterCostD.data()->value), - 1, - raft::comms::op_t::SUM, - stream); + comm.allreduce( + clusterCostD.data_handle(), clusterCostD.data_handle(), 1, raft::comms::op_t::SUM, stream); DataT curClusteringCost = 0; - raft::copy(handle, - raft::make_host_scalar_view(&curClusteringCost), - raft::make_device_scalar_view(&(clusterCostD.data()->value))); + raft::copy(handle, raft::make_host_scalar_view(&curClusteringCost), clusterCostD.view()); ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, "An error occurred in the distributed operation. This can result " diff --git a/cpp/src/cluster/detail/kmeans_mg_batched.cuh b/cpp/src/cluster/detail/kmeans_mg_batched.cuh index 98fed41636..ccc89991a6 100644 --- a/cpp/src/cluster/detail/kmeans_mg_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_mg_batched.cuh @@ -157,9 +157,9 @@ void mnmg_fit(const raft::resources& handle, auto sqrd_norm_error_dev = raft::make_device_scalar(dev_res, DataT{0}); IndexT alloc_batch_size = has_data ? 
streaming_batch_size : IndexT{1}; auto batch_weights = raft::make_device_vector(dev_res, alloc_batch_size); - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(dev_res, alloc_batch_size); - auto L2NormBatch = raft::make_device_vector(dev_res, alloc_batch_size); + auto nearest_idx = raft::make_device_vector(dev_res, alloc_batch_size); + auto nearest_dist = raft::make_device_vector(dev_res, alloc_batch_size); + auto L2NormBatch = raft::make_device_vector(dev_res, alloc_batch_size); rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); rmm::device_uvector workspace(0, stream); rmm::device_uvector batch_workspace(0, stream); @@ -353,9 +353,10 @@ void mnmg_fit(const raft::resources& handle, auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch_view); - auto minClusterAndDistance_view = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle(), current_batch_size); + auto nearest_idx_view = raft::make_device_vector_view( + nearest_idx.data_handle(), current_batch_size); + auto nearest_dist_view = raft::make_device_vector_view( + nearest_dist.data_handle(), current_batch_size); cuvs::cluster::kmeans::detail::process_batch( dev_res, @@ -365,7 +366,8 @@ void mnmg_fit(const raft::resources& handle, metric, params.batch_samples, params.batch_centroids, - minClusterAndDistance_view, + nearest_idx_view, + nearest_dist_view, L2NormBatch_const, L2NormBuf_OR_DistBuf, workspace, diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 65678faa08..d01f48fc0a 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -11,39 +11,66 @@ namespace cuvs::cluster::kmeans::detail { -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroids[key]'. 
+namespace { + +template +__global__ void unpack_kvp_to_soa(IndexT* nearest_idx, + DataT* nearest_dist, + const raft::KeyValuePair* kvp, + IndexT n) +{ + IndexT i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + if (nearest_idx != nullptr) { nearest_idx[i] = kvp[i].key; } + if (nearest_dist != nullptr) { nearest_dist[i] = kvp[i].value; } + } +} + +template +void unpack_kvp(raft::resources const& handle, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view, IndexT> kvp) +{ + auto stream = raft::resource::get_cuda_stream(handle); + auto n = static_cast(kvp.extent(0)); + int blks = static_cast((n + 255) / 256); + unpack_kvp_to_soa<<>>( + nearest_idx.data_handle(), nearest_dist.data_handle(), kvp.data_handle(), n); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace + template -void minClusterAndDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) +void minClusterAndDistanceCompute(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - bool is_l2_cos_fused = metric == cuvs::distance::DistanceType::L2Expanded || + const bool is_l2_cos = metric == cuvs::distance::DistanceType::L2Expanded || metric == 
cuvs::distance::DistanceType::L2SqrtExpanded || metric == cuvs::distance::DistanceType::CosineExpanded; - const bool is_ip_cutile = - metric == cuvs::distance::DistanceType::InnerProduct && - use_cutile_fused_nn(handle, n_samples, n_clusters, n_features); + const FusedDistancePath fused_path = + use_fused(handle, n_samples, n_clusters, n_features, metric); - if (is_l2_cos_fused || is_ip_cutile) { + if (uses_fused_distance_nn(fused_path)) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - if (is_l2_cos_fused) { + if (is_l2_cos) { if (metric == cuvs::distance::DistanceType::CosineExpanded) { raft::linalg::norm( handle, centroids, centroidsNorm, raft::sqrt_op{}); @@ -53,20 +80,61 @@ void minClusterAndDistanceCompute( } } - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance, initial_value); + auto centroidsNormConst = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + raft::KeyValuePair* cutlass_kvp_scratch = nullptr; + rmm::device_uvector> temp_kvp(0, stream); + if (needs_cutlass_kvp_scratch(fused_path)) { + temp_kvp.resize(n_samples, stream); + cutlass_kvp_scratch = temp_kvp.data(); + workspace.resize(sizeof(int) * n_samples, stream); + } + + cuvs::distance::fusedDistanceNNMinReduce( + nearest_idx.data_handle(), + nearest_dist.data_handle(), + X.data_handle(), + centroids.data_handle(), + L2NormX.data_handle(), + centroidsNormConst.data_handle(), + n_samples, + n_clusters, + n_features, + needs_fused_mutex_workspace(fused_path) ? 
(void*)workspace.data() : nullptr, + metric != cuvs::distance::DistanceType::L2Expanded, + true, + true, + metric, + 0.0f, + cutlass_kvp_scratch, + stream); + } else if (is_l2_cos) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - const bool should_use_fused = - use_fused(handle, n_samples, n_clusters, n_features); + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, centroids, centroidsNorm, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, centroids, centroidsNorm); + } auto centroidsNormConst = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - if (should_use_fused) { - workspace.resize((sizeof(int)) * n_samples, stream); + workspace.resize(sizeof(DataT) * n_samples * n_clusters, stream); + auto temp_kvp = + raft::make_device_vector, IndexT>(handle, n_samples); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::matrix::fill(handle, temp_kvp.view(), initial_value); - cuvs::distance::fusedDistanceNNMinReduce, IndexT>( - minClusterAndDistance.data_handle(), + cuvs::distance:: + unfusedDistanceNNMinReduce, IndexT>( + handle, + temp_kvp.data_handle(), X.data_handle(), centroids.data_handle(), L2NormX.data_handle(), @@ -81,84 +149,44 @@ void minClusterAndDistanceCompute( metric, 0.0f, stream); - } else { - workspace.resize(sizeof(DataT) * n_samples * n_clusters, stream); - - cuvs::distance:: - unfusedDistanceNNMinReduce, IndexT>( - handle, - minClusterAndDistance.data_handle(), - X.data_handle(), - centroids.data_handle(), - L2NormX.data_handle(), - centroidsNormConst.data_handle(), - n_samples, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - true, - metric, - 0.0f, - stream); - } + unpack_kvp(handle, nearest_idx, nearest_dist, raft::make_const_mdspan(temp_kvp.view())); } else { auto dataBatchSize 
= getDataBatchSize(batch_samples, n_samples); auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + auto temp_kvp = + raft::make_device_vector, IndexT>(handle, n_samples); raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance, initial_value); + raft::matrix::fill(handle, temp_kvp.view(), initial_value); - // tile over the input dataset for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - // datasetView [ns x n_features] - view representing the current batch of - // input dataset auto datasetView = raft::make_device_matrix_view( X.data_handle() + (dIdx * n_features), ns, n_features); - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle() + dIdx, ns); + auto temp_kvp_view = raft::make_device_vector_view, IndexT>( + temp_kvp.data_handle() + dIdx, ns); - // tile over the centroids for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - // centroidsView [nc x n_features] - view representing the current batch - // of centroids auto centroidsView = raft::make_device_matrix_view( centroids.data_handle() + (cIdx * n_features), nc, n_features); - // 
pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch auto pairwiseDistanceView = raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - // calculate pairwise distance between current tile of cluster centroids - // and input dataset pairwise_distance_kmeans( handle, datasetView, centroidsView, pairwiseDistanceView, metric); - // argmin reduction returning pair - // calculates the closest centroid and the distance to the closest - // centroid raft::linalg::coalescedReduction( - minClusterAndDistanceView.data_handle(), + temp_kvp_view.data_handle(), pairwiseDistanceView.data_handle(), pairwiseDistanceView.extent(1), pairwiseDistanceView.extent(0), @@ -175,20 +203,23 @@ void minClusterAndDistanceCompute( raft::identity_op{}); } } + + unpack_kvp(handle, nearest_idx, nearest_dist, raft::make_const_mdspan(temp_kvp.view())); } } -#define INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ - template void minClusterAndDistanceCompute( \ - raft::resources const& handle, \ - raft::device_matrix_view X, \ - raft::device_matrix_view centroids, \ - raft::device_vector_view, IndexT> minClusterAndDistance, \ - raft::device_vector_view L2NormX, \ - rmm::device_uvector& L2NormBuf_OR_DistBuf, \ - cuvs::distance::DistanceType metric, \ - int batch_samples, \ - int batch_centroids, \ +#define INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ + template void minClusterAndDistanceCompute( \ + raft::resources const& handle, \ + raft::device_matrix_view X, \ + raft::device_matrix_view centroids, \ + raft::device_vector_view nearest_idx, \ + raft::device_vector_view nearest_dist, \ + raft::device_vector_view L2NormX, \ + rmm::device_uvector& L2NormBuf_OR_DistBuf, \ + cuvs::distance::DistanceType metric, \ + int batch_samples, \ + int batch_centroids, \ rmm::device_uvector& workspace); INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) @@ -215,13 +246,17 @@ void minClusterDistanceCompute(raft::resources const& handle, auto 
n_features = X.extent(1); auto n_clusters = centroids.extent(0); - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded || - metric == cuvs::distance::DistanceType::CosineExpanded; + const bool is_l2_cos = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); - if (is_fused) { + const FusedDistancePath fused_path = + is_l2_cos ? use_fused(handle, n_samples, n_clusters, n_features, metric) + : FusedDistancePath::Unfused; + + if (uses_fused_distance_nn(fused_path)) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); auto centroidsNorm = raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); @@ -241,9 +276,16 @@ void minClusterDistanceCompute(raft::resources const& handle, centroidsNorm); } - workspace.resize(sizeof(int) * n_samples, stream); + raft::KeyValuePair* cutlass_kvp_scratch = nullptr; + rmm::device_uvector> temp_kvp(0, stream); + if (needs_cutlass_kvp_scratch(fused_path)) { + temp_kvp.resize(n_samples, stream); + cutlass_kvp_scratch = temp_kvp.data(); + workspace.resize(sizeof(int) * n_samples, stream); + } - cuvs::distance::fusedDistanceNNMinReduce( + cuvs::distance::fusedDistanceNNMinReduce( + nullptr, minClusterDistance.data_handle(), X.data_handle(), centroids.data_handle(), @@ -252,12 +294,13 @@ void minClusterDistanceCompute(raft::resources const& handle, n_samples, n_clusters, n_features, - (void*)workspace.data(), + needs_fused_mutex_workspace(fused_path) ? 
(void*)workspace.data() : nullptr, metric != cuvs::distance::DistanceType::L2Expanded, - false, + true, true, metric, 0.0f, + cutlass_kvp_scratch, stream); } else { auto dataBatchSize = getDataBatchSize(batch_samples, n_samples); @@ -268,8 +311,6 @@ void minClusterDistanceCompute(raft::resources const& handle, auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); @@ -279,7 +320,6 @@ void minClusterDistanceCompute(raft::resources const& handle, auto minClusterDistanceView = raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - // tile over the centroids for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index 06da1fc1de..003604769b 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -486,22 +486,23 @@ void cluster_cost( * */ template -void min_cluster_and_distance( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) +void min_cluster_and_distance(raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view nearest_idx, + raft::device_vector_view nearest_dist, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + 
rmm::device_uvector& workspace) { cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute(handle, X, centroids, - minClusterAndDistance, + nearest_idx, + nearest_dist, L2NormX, L2NormBuf_OR_DistBuf, metric, diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh index 476ab9c2be..a2ec5422dd 100644 --- a/cpp/src/distance/detail/fused_distance_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -14,6 +14,7 @@ #include "fused_distance_nn/simt_kernel.cuh" #include "pairwise_distance_base.cuh" // PairwiseDistances #include +#include #include // raft::KeyValuePair #include // raft::identity_op #include // Policy @@ -28,13 +29,9 @@ namespace distance { namespace detail { -template -void fusedDistanceNNImpl(OutT* min, +template +void fusedDistanceNNImpl(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -50,49 +47,77 @@ void fusedDistanceNNImpl(OutT* min, bool isRowMajor, cuvs::distance::DistanceType metric, float metric_arg, + raft::KeyValuePair* cutlass_kvp_scratch, cudaStream_t stream) { - // The kernel policy is determined by fusedDistanceNN. typedef Policy P; + typedef raft::KeyValuePair KVP; + constexpr auto maxVal = std::numeric_limits::max(); - // Callers (e.g. use_fused) enable this API for CUTLASS fused as well as cuTile; only try cuTile - // for float/half KVP output so double and other types never instantiate cuTile symbols here. 
if constexpr (is_fused_1nn_cutile_data_v) { - if constexpr (cuvs::detail::jit_lto::library_built_with_cutile() && - is_fused_1nn_kvp_output_v) { - if (try_fused_1nn_tile(min, x, y, xn, yn, m, n, k, metric, sqrt, stream)) { + if constexpr (cuvs::detail::jit_lto::library_built_with_cutile()) { + if (try_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, metric, sqrt, stream)) { return; } } } - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef raft::KeyValuePair KVPair; + RAFT_EXPECTS(cutlass_kvp_scratch != nullptr, "CUTLASS fused 1-NN requires a scratch KVP buffer"); - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); + initFused1nnOutput(nearest_idx, nearest_dist, m, std::numeric_limits::max(), stream); } + MinAndDistanceReduceOpImpl cutlass_redOp; + cutlass_redOp.out_kvp = cutlass_kvp_scratch; + initialize( + cutlass_kvp_scratch, m, maxVal, cutlass_redOp, stream); + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + switch (metric) { case cuvs::distance::DistanceType::CosineExpanded: - fusedCosineNN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + fusedCosineNN(nearest_idx, + nearest_dist, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + cutlass_redOp, + pairRedOp, + sqrt, + cutlass_kvp_scratch, + stream); break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Expanded: - fusedL2NNImpl( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); - break; - case cuvs::distance::DistanceType::InnerProduct: - // cuTile is the only fused InnerProduct implementation; callers must gate on availability. 
+ fusedL2NNImpl(nearest_idx, + nearest_dist, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + cutlass_redOp, + pairRedOp, + sqrt, + false, + cutlass_kvp_scratch, + stream); break; + case cuvs::distance::DistanceType::InnerProduct: break; default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; } + + unpackFused1nnKvpToSoa(nearest_idx, nearest_dist, cutlass_kvp_scratch, m, stream); } } // namespace detail diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py index fcefe7a027..1211456c9e 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/export_fused_1nn.py @@ -19,7 +19,14 @@ export_kernel, ) -from fused_1nn_kernel import METRICS, kernel_symbol, make_kernel, metric_abbrev +from fused_1nn_kernel import ( + INDEX_TYPES, + METRICS, + index_abbrev, + kernel_symbol, + make_kernel, + metric_abbrev, +) DEFAULT_TILEIR_BYTECODE_VERSION = "13.1" # cuTile requires a gpu_code even for TileIR bytecode export: it selects the compilation @@ -39,53 +46,82 @@ def _data_abbrev(data_type: str) -> str: return {"half": "h", "float": "f"}[data_type] -def _relaxed_matrix_constraint(elem_dtype): - """Array constraints matching the relaxed TMA-friendly layout from gemm_nn_cutile.""" +def _elem_stride_divisible_for_tma(elem_dtype) -> tuple[int, int]: + """Row stride (dim 0) divisible enough for 16-byte TMA access; last dim stride 1.""" + bytes_per_elem = 2 if elem_dtype == ct.float16 else 4 + return (16 // bytes_per_elem, 1) + + +def _cuvs_matrix_constraint(elem_dtype): + """Row-major device matrices for cuVS KMeans benchmarks. + + Assumes raft/cupy-style contiguous layout: stride[-1]==1, stride[0]==D, + 16-byte base alignment, and row pitch 16-byte aligned (float32 D%4==0, + float16 D%8==0). Applies to both points and centroids matrices. 
+ + shape_divisible_by is (1, 1); tail tiles are masked in the kernel. + Odd D or general layouts need a separate relaxed export profile. + """ return ArrayConstraint( elem_dtype, ndim=2, - index_dtype=ct.int64, + index_dtype=ct.int32, stride_lower_bound_incl=(0, None), alias_groups=(), may_alias_internally=False, stride_constant=(None, 1), - stride_divisible_by=(8, 1), + stride_divisible_by=_elem_stride_divisible_for_tma(elem_dtype), shape_divisible_by=(1, 1), base_addr_divisible_by=16, ) -def _relaxed_vector_constraint(elem_dtype, *, tma_friendly: bool = False): - base_div = 16 if tma_friendly else 1 +def _cuvs_vector_constraint(elem_dtype): + """1-D device vectors: contiguous, 16-byte base. Length need not be divisible by 16.""" return ArrayConstraint( elem_dtype, ndim=1, - index_dtype=ct.int64, + index_dtype=ct.int32, stride_lower_bound_incl=(None,), alias_groups=(), may_alias_internally=False, stride_constant=(1,), stride_divisible_by=(1,), shape_divisible_by=(1,), - base_addr_divisible_by=base_div, + base_addr_divisible_by=16, ) +def _relaxed_matrix_constraint(elem_dtype): + """Deprecated alias; use _cuvs_matrix_constraint.""" + return _cuvs_matrix_constraint(elem_dtype) + + +def _relaxed_vector_constraint(elem_dtype, *, tma_friendly: bool = False): + """Deprecated alias; use _cuvs_vector_constraint.""" + del tma_friendly + return _cuvs_vector_constraint(elem_dtype) + + def _kernel_signature( data_type: str, metric: str, + index_type: str, tile_m: int, tile_n: int, tile_k: int, ) -> KernelSignature: elem = _dtype_for(data_type) - matrix = _relaxed_matrix_constraint(elem) - norm_array = _relaxed_vector_constraint(elem, tma_friendly=True) - idx_array = _relaxed_vector_constraint(ct.int64) - dist_array = _relaxed_vector_constraint(ct.float32) + matrix = _cuvs_matrix_constraint(elem) + norm_array = _cuvs_vector_constraint(elem) + idx_elem = ct.int32 if index_type == "int32" else ct.int64 + idx_array = _cuvs_vector_constraint(idx_elem) + dist_array = 
_cuvs_vector_constraint(elem) abbrev = _data_abbrev(data_type) - symbol = kernel_symbol(abbrev, metric_abbrev(metric)) + symbol = kernel_symbol( + abbrev, metric_abbrev(metric), index_abbrev(index_type) + ) return KernelSignature( parameters=[ @@ -98,6 +134,8 @@ def _kernel_signature( ScalarConstraint(ct.int64), ScalarConstraint(ct.int64), ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), + ScalarConstraint(ct.int64), ConstantConstraint(tile_m), ConstantConstraint(tile_n), ConstantConstraint(tile_k), @@ -112,14 +150,19 @@ def export_binary( output_format: Literal["cubin", "tileir_bytecode"], data_type: str, metric: str, + index_type: str, tile_m: int, tile_n: int, tile_k: int, gpu_code: str, bytecode_version: str | None = None, ) -> str: - kernel = make_kernel(data_type, metric, tile_m, tile_n, tile_k) - signature = _kernel_signature(data_type, metric, tile_m, tile_n, tile_k) + kernel = make_kernel( + data_type, metric, tile_m, tile_n, tile_k, index_type=index_type + ) + signature = _kernel_signature( + data_type, metric, index_type, tile_m, tile_n, tile_k + ) export_kwargs = { "kernel": kernel, @@ -148,6 +191,7 @@ def main() -> int: "--data-type", choices=("half", "float"), required=True ) parser.add_argument("--metric", choices=METRICS, required=True) + parser.add_argument("--index-type", choices=INDEX_TYPES, required=True) parser.add_argument("--tile-m", type=int, required=True) parser.add_argument("--tile-n", type=int, required=True) parser.add_argument("--tile-k", type=int, required=True) @@ -167,6 +211,7 @@ def main() -> int: output_format=args.format, data_type=args.data_type, metric=args.metric, + index_type=args.index_type, tile_m=args.tile_m, tile_n=args.tile_n, tile_k=args.tile_k, diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json index 3aa9dffd8a..7d9b723b39 100644 --- 
a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_cutile_matrix.json @@ -24,18 +24,28 @@ "metric_abbrev": "cos" } ], + "_index": [ + { + "index_type": "int32", + "index_abbrev": "i32" + }, + { + "index_type": "int64", + "index_abbrev": "i64" + } + ], "_tile": [ { - "tile_m": 128, - "tile_n": 128, - "tile_k": 64 + "tile_m": 256, + "tile_n": 64, + "tile_k": 32 } ], "_export": [ { "output_format": "cubin", "artifact_ext": "cubin", - "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_80", "cc_major": 8, @@ -45,7 +55,7 @@ { "output_format": "cubin", "artifact_ext": "cubin", - "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_86", "cc_major": 8, @@ -55,7 +65,7 @@ { "output_format": "cubin", "artifact_ext": "cubin", - "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_90", "cc_major": 9, @@ -65,7 +75,7 @@ { "output_format": "cubin", "artifact_ext": "cubin", - "artifact_basename": "@data_type@_@metric_abbrev@_@gpu_code@", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@_@gpu_code@", "register": "cubin", "gpu_code": "sm_120", "cc_major": 12, @@ -75,7 +85,7 @@ { "output_format": "tileir_bytecode", "artifact_ext": "tilebc", - "artifact_basename": "@data_type@_@metric_abbrev@", + "artifact_basename": "@data_type@_@metric_abbrev@_@index_abbrev@", "register": "tileir", "gpu_code": "sm_80", "bytecode_version": "13.1" diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py index 
7d78525869..b2ff25555b 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_kernel.py @@ -9,11 +9,20 @@ ConstInt = ct.Constant[int] # Default tile geometry; overridden per export via make_kernel(..., tile_m, tile_n, tile_k). -DEFAULT_TILE_M = 128 -DEFAULT_TILE_N = 128 -DEFAULT_TILE_K = 64 +DEFAULT_TILE_M = 256 +DEFAULT_TILE_N = 64 +DEFAULT_TILE_K = 32 METRICS = ("inner_product", "l2_expanded", "cosine_expanded") +INDEX_TYPES = ("int32", "int64") + + +def _idx_dtype(index_type: str): + if index_type == "int32": + return ct.int32 + if index_type == "int64": + return ct.int64 + raise ValueError(f"Unsupported index_type {index_type!r}") def make_kernel( @@ -22,14 +31,20 @@ def make_kernel( tile_m: int = DEFAULT_TILE_M, tile_n: int = DEFAULT_TILE_N, tile_k: int = DEFAULT_TILE_K, + *, + index_type: str = "int32", ): - """Build a cuTile kernel with metric and tile sizes baked in at compile time.""" + """Build a cuTile kernel with metric, index width, and tile sizes baked in at compile time.""" if data_type not in ("half", "float"): raise ValueError(f"Unsupported data_type {data_type!r}") if metric not in METRICS: raise ValueError(f"Unsupported metric {metric!r}") + if index_type not in INDEX_TYPES: + raise ValueError(f"Unsupported index_type {index_type!r}") acc_dtype = ct.float32 + idx_dtype = _idx_dtype(index_type) + out_dist_dtype = ct.float16 if data_type == "half" else ct.float32 is_ip = metric == "inner_product" is_l2 = metric == "l2_expanded" is_cos = metric == "cosine_expanded" @@ -45,6 +60,8 @@ def fused_1nn_kernel( M, N, K, + apply_sqrt, + store_idx, tm: ConstInt, tn: ConstInt, tk: ConstInt, @@ -55,7 +72,7 @@ def fused_1nn_kernel( best_dist = ct.full((tm,), -3.4e38, acc_dtype) else: best_dist = ct.full((tm,), 3.4e38, acc_dtype) - best_idx = ct.zeros((tm,), ct.int64) + best_idx = ct.zeros((tm,), idx_dtype) num_tiles_k = ct.num_tiles(A, axis=1, shape=(tm, tk)) num_tiles_n 
= ct.num_tiles(B, axis=0, shape=(tn, tk)) @@ -65,12 +82,15 @@ def fused_1nn_kernel( accumulator = ct.full((tm, tn), 0, dtype=acc_dtype) for k in range(num_tiles_k): + dtype = ct.tfloat32 if A.dtype == ct.float32 else A.dtype + a = ct.load( A, index=(bidm, k), shape=(tm, tk), padding_mode=zero_pad - ) + ).astype(dtype) b_T = ct.load( B, index=(n, k), shape=(tn, tk), padding_mode=zero_pad - ) + ).astype(dtype) + accumulator = ct.mma(a, ct.transpose(b_T), accumulator) if is_ip: @@ -89,14 +109,13 @@ def fused_1nn_kernel( ) elif is_cos: # Cosine expanded distance: 1 - dot / (||x|| * ||y||); norms are L2 (not squared). - # No sqrt during the reduction — only arithmetic on stored distance if needed. denom = a_norm[:, None] * b_norm[None, :] score = 1.0 - (accumulator / denom) # Only the final N-tile can include zero-padded centroid columns. if n == num_tiles_n - 1: - col = ct.arange(tn, dtype=ct.int64) - global_col = n * tn + col + col = ct.arange(tn, dtype=idx_dtype) + global_col = (n * tn + col).astype(idx_dtype) valid = global_col < N if is_ip: score = ct.where(valid[None, :], score, -3.4e38) @@ -114,17 +133,25 @@ def fused_1nn_kernel( update = curr_best < best_dist best_dist = ct.where(update, curr_best, best_dist) - best_idx = ct.where(update, n * tn + curr_idx, best_idx) + best_idx = ct.where( + update, (n * tn + curr_idx).astype(idx_dtype), best_idx + ) - ct.store(OutIdx, index=(bidm,), tile=best_idx) - ct.store(OutDist, index=(bidm,), tile=best_dist) + out_dist = best_dist + if is_l2: + out_dist = ct.where(apply_sqrt != 0, ct.sqrt(best_dist), best_dist) + if store_idx != 0: + ct.store(OutIdx, index=(bidm,), tile=best_idx) + ct.store(OutDist, index=(bidm,), tile=out_dist.astype(out_dist_dtype)) return fused_1nn_kernel -def kernel_symbol(data_abbrev: str, metric_abbrev: str) -> str: +def kernel_symbol( + data_abbrev: str, metric_abbrev: str, index_abbrev: str +) -> str: """Must stay in sync with fused_1nn_kernel_entrypoint() in fused_1nn_planner.hpp.""" - return 
f"fused_1nn_{data_abbrev}_{metric_abbrev}" + return f"fused_1nn_{data_abbrev}_{metric_abbrev}_{index_abbrev}" def metric_abbrev(metric: str) -> str: @@ -133,3 +160,7 @@ def metric_abbrev(metric: str) -> str: "l2_expanded": "l2", "cosine_expanded": "cos", }[metric] + + +def index_abbrev(index_type: str) -> str: + return {"int32": "i32", "int64": "i64"}[index_type] diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp index c70ab3f87b..017fb72d48 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_planner.hpp @@ -17,42 +17,48 @@ namespace cuvs::distance::detail { /** Must match kernel_symbol() in fused_1nn_kernel.py (export uses with_symbol). */ -template +template inline const char* fused_1nn_kernel_entrypoint() { + constexpr bool is_i32 = std::is_same_v; + constexpr bool is_i64 = std::is_same_v; + static_assert(is_i32 || is_i64, "unsupported fused 1-NN cuTile index width"); + if constexpr (std::is_same_v && std::is_same_v) { - return "fused_1nn_f_ip"; + return is_i32 ? "fused_1nn_f_ip_i32" : "fused_1nn_f_ip_i64"; } else if constexpr (std::is_same_v && std::is_same_v) { - return "fused_1nn_f_l2"; + return is_i32 ? "fused_1nn_f_l2_i32" : "fused_1nn_f_l2_i64"; } else if constexpr (std::is_same_v && std::is_same_v) { - return "fused_1nn_f_cos"; + return is_i32 ? "fused_1nn_f_cos_i32" : "fused_1nn_f_cos_i64"; } else if constexpr (std::is_same_v && std::is_same_v) { - return "fused_1nn_h_ip"; + return is_i32 ? "fused_1nn_h_ip_i32" : "fused_1nn_h_ip_i64"; } else if constexpr (std::is_same_v && std::is_same_v) { - return "fused_1nn_h_l2"; + return is_i32 ? "fused_1nn_h_l2_i32" : "fused_1nn_h_l2_i64"; } else if constexpr (std::is_same_v && std::is_same_v) { - return "fused_1nn_h_cos"; + return is_i32 ? 
"fused_1nn_h_cos_i32" : "fused_1nn_h_cos_i64"; } else { static_assert(sizeof(DataTag) == 0, "unsupported fused 1-NN cuTile data/metric combination"); return ""; } } -template +template struct Fused1nnTilePlanner : TileAlgorithmPlanner { using DataTag = fused_1nn_data_tag_t; using MetricTag = fused_1nn_metric_tag_t; + using IndexTag = fused_1nn_index_tag_t; inline static LauncherJitCache launcher_jit_cache{}; Fused1nnTilePlanner() - : TileAlgorithmPlanner(fused_1nn_kernel_entrypoint(), launcher_jit_cache) + : TileAlgorithmPlanner(fused_1nn_kernel_entrypoint(), + launcher_jit_cache) { } @@ -65,20 +71,32 @@ struct Fused1nnTilePlanner : TileAlgorithmPlanner { using cuvs::detail::jit_lto::cutile_arch_8_6; using cuvs::detail::jit_lto::cutile_arch_9_0; - this->add_static_fragment< - fragment_tag_fused_1nn_cubin>(); - this->add_static_fragment< - fragment_tag_fused_1nn_cubin>(); - this->add_static_fragment< - fragment_tag_fused_1nn_cubin>(); - this->add_static_fragment< - fragment_tag_fused_1nn_cubin>(); + this->add_static_fragment>(); + this->add_static_fragment>(); + this->add_static_fragment>(); + this->add_static_fragment>(); } void add_tileir_fallback() { this->add_static_tileir_fragment< - fragment_tag_fused_1nn_tileir>(); + fragment_tag_fused_1nn_tileir>(); } }; diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu index d292f0522b..cf01c12ce1 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.cu @@ -8,6 +8,7 @@ #include "fused_1nn_planner.hpp" #include +#include #include namespace cuvs { @@ -16,25 +17,13 @@ namespace detail { namespace { -template -__global__ void pack_fused_1nn_kvp( - OutT* out, const int64_t* idx, const float* dist, IdxT len, bool apply_sqrt) -{ - IdxT i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < len) { - out[i].key = static_cast(idx[i]); - float 
value = dist[i]; - if (apply_sqrt) { value = sqrtf(value); } - out[i].value = static_cast(value); - } -} - -template -bool launch_fused_1nn_tile(const DataT* x, +template +bool launch_fused_1nn_tile(IdxT* nearest_idx, + DataT* nearest_dist, + const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, - OutT* out, IdxT m, IdxT n, IdxT k, @@ -43,7 +32,9 @@ bool launch_fused_1nn_tile(const DataT* x, { if constexpr (!std::is_same_v && !std::is_same_v) { return false; } - Fused1nnTilePlanner planner; + if (nearest_dist == nullptr) { return false; } + + Fused1nnTilePlanner planner; planner.add_entrypoint(); planner.add_tileir_fallback(); const CutileTileConfig tile_cfg = planner.tile_config(); @@ -52,11 +43,6 @@ bool launch_fused_1nn_tile(const DataT* x, const bool apply_sqrt = fused_1nn_apply_sqrt_at_pack(is_sqrt); - int64_t* d_idx = nullptr; - float* d_dist = nullptr; - RAFT_CUDA_TRY(cudaMallocAsync(&d_idx, m * sizeof(int64_t), stream)); - RAFT_CUDA_TRY(cudaMallocAsync(&d_dist, m * sizeof(float), stream)); - int64_t shape_x[2] = {m, k}; int64_t stride_x[2] = {k, 1}; int64_t shape_y[2] = {n, k}; @@ -72,19 +58,21 @@ bool launch_fused_1nn_tile(const DataT* x, int64_t M = m, N = n, K = k; - void* x_ptr = const_cast(x); - void* y_ptr = const_cast(y); - void* xn_ptr = const_cast(xn); - void* yn_ptr = const_cast(yn); - void* idx_ptr = d_idx; - void* dist_ptr = d_dist; + void* x_ptr = const_cast(x); + void* y_ptr = const_cast(y); + void* xn_ptr = const_cast(xn); + void* yn_ptr = const_cast(yn); + // OutIdx must be a valid device pointer for the launch ABI; when store_idx is 0 the kernel + // does not write it (dist-only callers pass nearest_dist as a stand-in). + const int64_t store_idx = nearest_idx != nullptr ? 1 : 0; + void* idx_ptr = + nearest_idx != nullptr ? 
static_cast(nearest_idx) : static_cast(nearest_dist); + void* dist_ptr = nearest_dist; const int64_t tile_m = tile_cfg.tile_m; dim3 grid((m + tile_m - 1) / tile_m, 1, 1); dim3 block(1, 1, 1); - // cutile_python_v1: 2D array (ptr, shape0, shape1, stride0, stride1); - // 1D array (ptr, shape, stride); tile sizes are embedded constants. using fused_1nn_cutile_kernel_t = void(void*, int64_t, int64_t, @@ -109,8 +97,9 @@ bool launch_fused_1nn_tile(const DataT* x, int64_t, int64_t, int64_t, + int64_t, + int64_t, int64_t); - std::cout << "Launching cuTile kernel" << std::endl; launcher->template dispatch(stream, grid, block, @@ -139,18 +128,16 @@ bool launch_fused_1nn_tile(const DataT* x, stride_dist, M, N, - K); - - pack_fused_1nn_kvp - <<<(m + 255) / 256, 256, 0, stream>>>(out, d_idx, d_dist, m, apply_sqrt); + K, + static_cast(apply_sqrt), + store_idx); RAFT_CUDA_TRY(cudaGetLastError()); - RAFT_CUDA_TRY(cudaFreeAsync(d_idx, stream)); - RAFT_CUDA_TRY(cudaFreeAsync(d_dist, stream)); return true; } -template -bool try_fused_1nn_tile_dispatch(OutT* min, +template +bool try_fused_1nn_tile_dispatch(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -164,26 +151,27 @@ bool try_fused_1nn_tile_dispatch(OutT* min, { switch (metric) { case cuvs::distance::DistanceType::InnerProduct: - return launch_fused_1nn_tile( - x, y, xn, yn, min, m, n, k, is_sqrt, stream); + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); case cuvs::distance::DistanceType::L2Expanded: - return launch_fused_1nn_tile( - x, y, xn, yn, min, m, n, k, is_sqrt, stream); + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); case cuvs::distance::DistanceType::L2SqrtExpanded: - return launch_fused_1nn_tile( - x, y, xn, yn, min, m, n, k, is_sqrt, stream); + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); case 
cuvs::distance::DistanceType::CosineExpanded: - return launch_fused_1nn_tile( - x, y, xn, yn, min, m, n, k, is_sqrt, stream); + return launch_fused_1nn_tile( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, is_sqrt, stream); default: return false; } } } // namespace -template - requires Fused1nnKvpOutput -bool try_fused_1nn_tile(OutT* min, +template + requires is_fused_1nn_cutile_data_v +bool try_fused_1nn_tile(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -196,35 +184,28 @@ bool try_fused_1nn_tile(OutT* min, cudaStream_t stream) { if (!cuvs::detail::jit_lto::cutile_launch_available_on_current_device()) { return false; } - return try_fused_1nn_tile_dispatch( - min, x, y, xn, yn, m, n, k, metric, is_sqrt, stream); + return try_fused_1nn_tile_dispatch( + nearest_idx, nearest_dist, x, y, xn, yn, m, n, k, metric, is_sqrt, stream); } -using kvp_i_f = raft::KeyValuePair; -using kvp_i64_f = raft::KeyValuePair; -using kvp_i_h = raft::KeyValuePair; -using kvp_i64_h = raft::KeyValuePair; - -#define CUVS_INST_TRY_FUSED_1NN_TILE(DataT, OutT, IdxT) \ - template CUVS_EXPORT bool try_fused_1nn_tile(OutT*, \ - const DataT*, \ - const DataT*, \ - const DataT*, \ - const DataT*, \ - IdxT, \ - IdxT, \ - IdxT, \ - cuvs::distance::DistanceType, \ - bool, \ - cudaStream_t) - -// int and int64_t are the same on LP64; one instantiation covers both. 
-CUVS_INST_TRY_FUSED_1NN_TILE(float, kvp_i_f, int); -CUVS_INST_TRY_FUSED_1NN_TILE(float, kvp_i64_f, int64_t); -CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i_f, int); -CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i64_f, int64_t); -CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i_h, int); -CUVS_INST_TRY_FUSED_1NN_TILE(half, kvp_i64_h, int64_t); +#define CUVS_INST_TRY_FUSED_1NN_TILE(DataT, IdxT) \ + template CUVS_EXPORT bool try_fused_1nn_tile(IdxT*, \ + DataT*, \ + const DataT*, \ + const DataT*, \ + const DataT*, \ + const DataT*, \ + IdxT, \ + IdxT, \ + IdxT, \ + cuvs::distance::DistanceType, \ + bool, \ + cudaStream_t) + +CUVS_INST_TRY_FUSED_1NN_TILE(float, int); +CUVS_INST_TRY_FUSED_1NN_TILE(float, int64_t); +CUVS_INST_TRY_FUSED_1NN_TILE(half, int); +CUVS_INST_TRY_FUSED_1NN_TILE(half, int64_t); #undef CUVS_INST_TRY_FUSED_1NN_TILE diff --git a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp index 563d2583d8..4c0964631c 100644 --- a/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp +++ b/cpp/src/distance/detail/fused_distance_nn/cutile/fused_1nn_tile.hpp @@ -5,11 +5,9 @@ #pragma once -#include #include #include -#include #include #include @@ -26,18 +24,11 @@ template inline constexpr bool is_fused_1nn_cutile_data_v = std::is_same_v || std::is_same_v; -template -inline constexpr bool is_fused_1nn_kvp_output_v = - is_fused_1nn_cutile_data_v && (std::is_same_v> || - std::is_same_v>); - -template -concept Fused1nnKvpOutput = is_fused_1nn_kvp_output_v; - #if CUVS_CUTILE_ENABLED -template - requires Fused1nnKvpOutput -bool try_fused_1nn_tile(OutT* min, +template + requires is_fused_1nn_cutile_data_v +bool try_fused_1nn_tile(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -49,9 +40,9 @@ bool try_fused_1nn_tile(OutT* min, bool is_sqrt, cudaStream_t stream); #else -template - requires Fused1nnKvpOutput -bool try_fused_1nn_tile(OutT*, 
+template +bool try_fused_1nn_tile(IdxT*, + DataT*, const DataT*, const DataT*, const DataT*, @@ -67,23 +58,6 @@ bool try_fused_1nn_tile(OutT*, } #endif -template - requires(!Fused1nnKvpOutput) -bool try_fused_1nn_tile(OutT*, - const DataT*, - const DataT*, - const DataT*, - const DataT*, - IdxT, - IdxT, - IdxT, - cuvs::distance::DistanceType, - bool, - cudaStream_t) -{ - return false; -} - } // namespace detail } // namespace distance } // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh index 12f4f17cac..cc16d8a2e1 100644 --- a/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -1,11 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "../distance_ops/cosine.cuh" // ops::l2_exp_distance_op +#include "../distance_ops/cosine.cuh" // ops::cosine_distance_op #include "../pairwise_distance_base.cuh" // PairwiseDistances #include "cutlass_base.cuh" #include "helper_structs.cuh" @@ -24,13 +24,9 @@ namespace distance { namespace detail { -template -void fusedCosineNN(OutT* min, +template +void fusedCosineNN(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -42,15 +38,20 @@ void fusedCosineNN(OutT* min, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + raft::KeyValuePair* cutlass_out, cudaStream_t stream) { - // The kernel policy is determined by fusedL2NN. 
typedef Policy P; dim3 blk(P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); typedef raft::KeyValuePair KVPair; + if (cutlass_out == nullptr) { + initFused1nnOutput(nearest_idx, nearest_dist, m, maxVal, stream); + RAFT_CUDA_TRY(cudaGetLastError()); + } + namespace arch = raft::util::arch; using AccT = DataT; ops::cosine_distance_op distance_op{}; @@ -58,7 +59,7 @@ void fusedCosineNN(OutT* min, raft::identity_op fin_op{}; auto kernel = fusedDistanceNNkernel; - // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 void* kernel_ptr = reinterpret_cast(kernel); auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. using cosineOp = cuvs::distance::detail::ops::cosine_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; kvp_cg_min_reduce_op_ cg_reduce_op; cosineOp cosine_dist_op; @@ -86,7 +82,7 @@ void fusedCosineNN(OutT* min, cutlassFusedDistanceNN(m, n, shmemSize, kernel); kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + cutlass_out, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } } diff --git a/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh index f1aad72110..8c532e2932 100644 --- a/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. 
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -24,13 +24,9 @@ namespace distance { namespace detail { -template -void fusedL2NNImpl(OutT* min, +template +void fusedL2NNImpl(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -43,19 +39,17 @@ void fusedL2NNImpl(OutT* min, KVPReduceOpT pairRedOp, bool sqrt, bool initOutBuffer, + raft::KeyValuePair* cutlass_out, cudaStream_t stream) { - // The kernel policy is determined by fusedL2NN. typedef Policy P; dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); typedef raft::KeyValuePair KVPair; - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); + if (initOutBuffer && cutlass_out == nullptr) { + initFused1nnOutput(nearest_idx, nearest_dist, m, maxVal, stream); RAFT_CUDA_TRY(cudaGetLastError()); } @@ -66,7 +60,7 @@ void fusedL2NNImpl(OutT* min, raft::identity_op fin_op{}; auto kernel = fusedDistanceNNkernel; - // Get pointer to fp32 SIMT kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 void* kernel_ptr = reinterpret_cast(kernel); auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. 
using L2Op = cuvs::distance::detail::ops::l2_exp_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; kvp_cg_min_reduce_op_ cg_reduce_op; L2Op L2_dist_op(sqrt); @@ -95,7 +83,7 @@ void fusedL2NNImpl(OutT* min, cutlassFusedDistanceNN(m, n, shmemSize, kernel); kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + cutlass_out, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } } diff --git a/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh index 762c720568..3bd78ba5ab 100644 --- a/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh +++ b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -32,20 +32,43 @@ struct KVPMinReduceImpl { }; // KVPMinReduce +/** Writes fused 1-NN results to separate idx/dist arrays (dist may be null). */ template struct MinAndDistanceReduceOpImpl { typedef typename raft::KeyValuePair KVP; + LabelT* out_idx{nullptr}; + DataT* out_dist{nullptr}; + /** When set, CUTLASS/SIMT global merge writes here instead of SoA (caller unpacks). */ + KVP* out_kvp{nullptr}; + + DI void merge(LabelT rid, const KVP& other) const + { + if (out_kvp != nullptr) { + if (other.value < out_kvp[rid].value) { out_kvp[rid] = other; } + } else if (out_dist != nullptr) { + if (other.value < out_dist[rid]) { + out_dist[rid] = other.value; + if (out_idx != nullptr) { out_idx[rid] = other.key; } + } + } else if (out_idx != nullptr) { + // Idx-only output: dist must still be tracked for multi-tile merge; caller must provide + // out_dist or use a single-pass backend (cuTile). 
KMeans always passes both buffers. + out_idx[rid] = other.key; + } + } + DI void operator()(LabelT rid, KVP* out, const KVP& other) const { - if (other.value < out->value) { + if (out != nullptr && other.value < out->value) { out->key = other.key; out->value = other.value; } } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const { - if (other.value < out->value) { + if (out != nullptr && other.value < out->value) { out->key = other.key; out->value = other.value; } @@ -53,35 +76,41 @@ struct MinAndDistanceReduceOpImpl { DI void operator()(LabelT rid, DataT* out, const KVP& other) const { - if (other.value < *out) { *out = other.value; } + if (out != nullptr && other.value < *out) { *out = other.value; } } DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const { - if (other.value < *out) { *out = other.value; } + if (out != nullptr && other.value < *out) { *out = other.value; } } DI void operator()(LabelT rid, DataT* out, const DataT& other) const { - if (other < *out) { *out = other; } + if (out != nullptr && other < *out) { *out = other; } } DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const { - if (other < *out) { *out = other; } + if (out != nullptr && other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const + { + if (out != nullptr) { *out = maxVal; } } - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; - out->key = 0xfffffff0; + out->key = LabelT(0); } - DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(DataT& /*out*/, LabelT /*idx*/) const {} + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } }; @@ -96,6 +125,53 @@ struct MinReduceOpImpl { DI void init(DataT* out, DataT maxVal) { *out = maxVal; } }; +template +RAFT_KERNEL 
initFused1nnOutputKernel(IdxT* nearest_idx, DataT* nearest_dist, IdxT m, DataT maxVal) +{ + IdxT tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { + if (nearest_idx != nullptr) { nearest_idx[tid] = IdxT(0); } + if (nearest_dist != nullptr) { nearest_dist[tid] = maxVal; } + } +} + +template +void initFused1nnOutput( + IdxT* nearest_idx, DataT* nearest_dist, IdxT m, DataT maxVal, cudaStream_t stream) +{ + if (nearest_idx == nullptr && nearest_dist == nullptr) { return; } + auto blks = raft::ceildiv(m, 256); + initFused1nnOutputKernel + <<>>(nearest_idx, nearest_dist, m, maxVal); +} + +template +RAFT_KERNEL unpackFused1nnKvpToSoaKernel(IdxT* nearest_idx, + DataT* nearest_dist, + const raft::KeyValuePair* kvp, + IdxT n) +{ + IdxT i = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (i < n) { + if (nearest_idx != nullptr) { nearest_idx[i] = kvp[i].key; } + if (nearest_dist != nullptr) { nearest_dist[i] = kvp[i].value; } + } +} + +template +void unpackFused1nnKvpToSoa(IdxT* nearest_idx, + DataT* nearest_dist, + const raft::KeyValuePair* kvp, + IdxT m, + cudaStream_t stream) +{ + if (nearest_idx == nullptr && nearest_dist == nullptr) { return; } + auto blks = raft::ceildiv(m, 256); + unpackFused1nnKvpToSoaKernel + <<>>(nearest_idx, nearest_dist, kvp, m); + RAFT_CUDA_TRY(cudaGetLastError()); +} + template RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { @@ -106,15 +182,13 @@ RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) template void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) { - auto blks = raft::ceildiv(m, 256); - initKernel<<>>(min, m, maxVal, redOp); + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); } // cg::reduce functor for FusedDistanceNN used in its cutlass version // to output the min distance value & key(loc id). 
-// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h -// store_with_byte_offset() passed to cg::reduce() & select_reduce. -template +template struct kvp_cg_min_reduce_op { typedef typename raft::KeyValuePair KVP; @@ -122,7 +196,6 @@ struct kvp_cg_min_reduce_op { using AccTypeT = AccType; using IndexT = Index; - // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } diff --git a/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index caa6a36d53..0d9f5333af 100644 --- a/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -437,10 +437,8 @@ class PredicatedTileIteratorReducedVec { __syncthreads(); if (row < total_rows) { - volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - if ((block_start_row_first_tile_ + row) < extent_row_) { - user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); + user_params.red_op_.merge(block_start_row_first_tile_ + row, row_local_min); } } diff --git a/cpp/src/distance/fused_distance_nn-inl.cuh b/cpp/src/distance/fused_distance_nn-inl.cuh index 3fa80a9b60..13c4faa472 100644 --- a/cpp/src/distance/fused_distance_nn-inl.cuh +++ b/cpp/src/distance/fused_distance_nn-inl.cuh @@ -28,48 +28,10 @@ namespace distance { * \ingroup fused_l2_nn * @{ */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. 
- * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. 
- * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream - */ -template -void fusedDistanceNN(OutT* min, + +template +void fusedDistanceNN(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -85,12 +47,10 @@ void fusedDistanceNN(OutT* min, bool isRowMajor, cuvs::distance::DistanceType metric, float metric_arg, + raft::KeyValuePair* cutlass_kvp_scratch, cudaStream_t stream) { ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. bool is_skinny = k < 32; size_t bytes = sizeof(DataT) * k; @@ -100,10 +60,10 @@ void fusedDistanceNN(OutT* min, if (is_skinny) { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -119,14 +79,15 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } else { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -142,16 +103,17 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { if (is_skinny) { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename raft::linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -167,14 +129,15 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } else { detail::fusedDistanceNNImpl< DataT, - OutT, IdxT, typename 
raft::linalg::Policy4x4::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -190,15 +153,16 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } } else { if (is_skinny) { detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -214,13 +178,14 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } else { detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, + ReduceOpT>(nearest_idx, + nearest_dist, x, y, xn, @@ -236,44 +201,23 @@ void fusedDistanceNN(OutT* min, isRowMajor, metric, metric_arg, + cutlass_kvp_scratch, stream); } } } /** - * @brief Wrapper around fusedDistanceNN with minimum reduction operators. - * - * fusedDistanceNN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). + * @brief Fused GEMM + 1-NN minimum reduction. * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. 
- * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream + * @param[out] nearest_idx Nearest neighbor index per row, length `m` (required). + * @param[out] nearest_dist Minimum distance per row, length `m` (optional, may be null). + * @param[in] cutlass_kvp_scratch Temp KVP buffer, length `m`; required when CUTLASS/SIMT runs. + * Unused when cuTile handles the launch. */ -template -void fusedDistanceNNMinReduce(OutT* min, +template +void fusedDistanceNNMinReduce(IdxT* nearest_idx, + DataT* nearest_dist, const DataT* x, const DataT* y, const DataT* xn, @@ -287,28 +231,33 @@ void fusedDistanceNNMinReduce(OutT* min, bool isRowMajor, cuvs::distance::DistanceType metric, float metric_arg, + raft::KeyValuePair* cutlass_kvp_scratch, cudaStream_t stream) { MinAndDistanceReduceOp redOp; + redOp.out_idx = nearest_idx; + redOp.out_dist = nearest_dist; KVPMinReduce pairRedOp; - fusedDistanceNN(min, - x, - y, - xn, - yn, - m, - n, - k, - workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); + fusedDistanceNN(nearest_idx, + nearest_dist, + x, + y, + xn, + yn, + m, + n, + k, + workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + isRowMajor, + metric, + metric_arg, + cutlass_kvp_scratch, + stream); } /** @} */ diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 9b96f94bf0..d4e0099035 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -120,7 +120,7 @@ ConfigureTest( ConfigureTest( NAME CLUSTER_TEST PATH cluster/kmeans.cu cluster/kmeans_balanced.cu cluster/kmeans_find_k.cu cluster/linkage.cu - cluster/connect_knn.cu cluster/spectral.cu + cluster/connect_knn.cu cluster/spectral.cu cluster/soa_unpack_trace.cu GPUS 1 PERCENT 100 ) diff --git a/cpp/tests/neighbors/distance_nn.cu b/cpp/tests/neighbors/distance_nn.cu index f5efaa5bec..6b17fc646b 
100644 --- a/cpp/tests/neighbors/distance_nn.cu +++ b/cpp/tests/neighbors/distance_nn.cu @@ -42,7 +42,7 @@ __global__ void fill_int8(int8_t* buff, int len, int seed_offset) template class NNTest : public ::testing::TestWithParam> { public: - using OutT = raft::KeyValuePair; + using RefOutT = raft::KeyValuePair; NNTest() : params_{::testing::TestWithParam>::GetParam()}, m{params_.m}, @@ -55,8 +55,10 @@ class NNTest : public ::testing::TestWithParam> { y{raft::make_device_matrix(handle, n, k)}, x_norm{raft::make_device_vector(handle, m)}, y_norm{raft::make_device_vector(handle, n)}, - out{raft::make_device_vector(handle, m)}, - ref_out{raft::make_device_vector(handle, m)} + out_idx{raft::make_device_vector(handle, m)}, + out_dist{raft::make_device_vector(handle, m)}, + out_kvp{raft::make_device_vector(handle, m)}, + ref_out{raft::make_device_vector(handle, m)} { } @@ -92,15 +94,11 @@ class NNTest : public ::testing::TestWithParam> { workspace_size = m * n * sizeof(AccT); } - // Reset buffer - if constexpr (std::is_same_v>) { - // OutT is a RAFT KeyValuePair - raft::matrix::fill( - handle, raft::make_device_matrix_view(out.data_handle(), m, 1), OutT{0, 0}); - } else { - // OutT is a scalar type - raft::matrix::fill(handle, raft::make_device_matrix_view(out.data_handle(), m, 1), OutT{0}); - } + raft::matrix::fill(handle, raft::make_device_matrix_view(out_idx.data_handle(), m, 1), IdxT{0}); + raft::matrix::fill( + handle, raft::make_device_matrix_view(out_dist.data_handle(), m, 1), AccT{0}); + raft::matrix::fill( + handle, raft::make_device_matrix_view(ref_out.data_handle(), m, 1), RefOutT{0, 0}); raft::resource::sync_stream(handle, stream); } @@ -109,34 +107,36 @@ class NNTest : public ::testing::TestWithParam> { raft::device_vector workspace = raft::make_device_vector(handle, workspace_size); - ref_nn( + ref_nn( ref_out.data_handle(), x.data_handle(), y.data_handle(), m, n, k, sqrt, metric, stream); if constexpr (impl == ImplType::fused) { if constexpr 
(std::is_same_v) { - cuvs::distance::fusedDistanceNNMinReduce(out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - m, - n, - k, - (void*)workspace.data_handle(), - sqrt, - true, - true, - metric, - 0.0, - stream); + cuvs::distance::fusedDistanceNNMinReduce(out_idx.data_handle(), + out_dist.data_handle(), + x.data_handle(), + y.data_handle(), + x_norm.data_handle(), + y_norm.data_handle(), + m, + n, + k, + (void*)workspace.data_handle(), + sqrt, + true, + true, + metric, + 0.0, + out_kvp.data_handle(), + stream); } else { static_assert(sizeof(DataT) == 0, "fusedDistanceNNMinReduce is not implemented for datatype other than float"); } } else if constexpr (impl == ImplType::unfused) { - cuvs::distance::unfusedDistanceNNMinReduce( + cuvs::distance::unfusedDistanceNNMinReduce( handle, - out.data_handle(), + out_kvp.data_handle(), x.data_handle(), y.data_handle(), x_norm.data_handle(), @@ -156,7 +156,12 @@ class NNTest : public ::testing::TestWithParam> { void compare() { - vector_compare(handle, ref_out.data_handle(), out.data_handle(), m, summary); + if constexpr (impl == ImplType::fused) { + vector_compare_soa( + handle, ref_out.data_handle(), out_idx.data_handle(), out_dist.data_handle(), m, summary); + } else { + vector_compare(handle, ref_out.data_handle(), out_kvp.data_handle(), m, summary); + } ASSERT_TRUE(summary.max_diff < params_.tol) << summary; } @@ -174,8 +179,10 @@ class NNTest : public ::testing::TestWithParam> { raft::device_matrix y; raft::device_vector x_norm; raft::device_vector y_norm; - raft::device_vector out; - raft::device_vector ref_out; + raft::device_vector out_idx; + raft::device_vector out_dist; + raft::device_vector out_kvp; + raft::device_vector ref_out; size_t workspace_size; }; diff --git a/cpp/tests/neighbors/distance_nn_helper.cuh b/cpp/tests/neighbors/distance_nn_helper.cuh index ea440387b4..dfa71f71a4 100644 --- a/cpp/tests/neighbors/distance_nn_helper.cuh +++ 
b/cpp/tests/neighbors/distance_nn_helper.cuh @@ -209,6 +209,34 @@ class ComparisonSummary { } }; +template +void vector_compare_soa(raft::resources const& handle, + const raft::KeyValuePair* ref, + const IdxT* out_idx, + const AccT* out_dist, + IdxT n, + ComparisonSummary& summary) +{ + auto ref_h = raft::make_host_vector, IdxT>(n); + auto idx_h = raft::make_host_vector(n); + auto dist_h = raft::make_host_vector(n); + + raft::copy(ref_h.data_handle(), ref, n, raft::resource::get_cuda_stream(handle)); + raft::copy(idx_h.data_handle(), out_idx, n, raft::resource::get_cuda_stream(handle)); + raft::copy(dist_h.data_handle(), out_dist, n, raft::resource::get_cuda_stream(handle)); + raft::resource::sync_stream(handle, raft::resource::get_cuda_stream(handle)); + + summary.init(); + + for (IdxT i = 0; i < n; i++) { + const double a_val = double(dist_h(i)); + const double b_val = double(ref_h(i).value); + const bool missed = idx_h(i) != ref_h(i).key; + const double diff = std::abs(a_val - b_val); + summary.update(diff, i, a_val, b_val, missed); + } +} + template void vector_compare( raft::resources const& handle, const OutT* a, const OutT* b, IdxT n, ComparisonSummary& summary) From 0a3da06d706a05a676f288288a78e9d3070fcbf7 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 30 Jun 2026 04:38:36 +0000 Subject: [PATCH 10/10] add reproducible benchmark scripts --- benchmark_kmeans.py | 410 ++++++++++++++++++++++++++++++++++++++++ run_benchmark_kmeans.sh | 47 +++++ 2 files changed, 457 insertions(+) create mode 100644 benchmark_kmeans.py create mode 100755 run_benchmark_kmeans.sh diff --git a/benchmark_kmeans.py b/benchmark_kmeans.py new file mode 100644 index 0000000000..cfabe7570a --- /dev/null +++ b/benchmark_kmeans.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 +r"""KMeans fit+predict benchmark: baseline / cuTile / flash-kmeans. 
+ +Single impl (activate the target conda env first): + python benchmark_kmeans.py --impl baseline|cutile|flash --n N --d D --k K \\ + --max-iter 5 --tol 1e-4 --seed 42 \\ + --warmup-fit 1 --iters-fit 3 --warmup-pred 1 --iters-pred 3 + +Compare (subprocess per impl; export env vars, then --compare): + export BENCH_CONDA=/path/to/miniforge3 + export BENCH_ENV_BASE=cuvs_2608_base + export BENCH_ENV_CUTILE=cuvs_2608 + export BENCH_ENV_FLASH=cuvs_2608_base + python benchmark_kmeans.py --compare --n 33554432 --d 32 --k 64 \\ + --max-iter 5 --tol 1e-4 --seed 42 \\ + --warmup-fit 1 --iters-fit 3 --warmup-pred 1 --iters-pred 3 + +Smoke test (small shape, single impl): + conda activate cuvs_2608 + python benchmark_kmeans.py --impl cutile --n 10000 --d 32 --k 8 \\ + --max-iter 2 --tol 1e-4 --seed 42 \\ + --warmup-fit 0 --iters-fit 1 --warmup-pred 0 --iters-pred 1 + +Required for --compare (no defaults): + BENCH_CONDA path to miniforge/conda root + BENCH_ENV_BASE conda env name for baseline libcuvs + BENCH_ENV_CUTILE conda env name for cuTile libcuvs + BENCH_ENV_FLASH conda env name for flash-kmeans +""" + +from __future__ import annotations + +import argparse +import os +import re +import subprocess +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +ROOT = Path(__file__).resolve().parent +IMPLS = ("baseline", "cutile", "flash") + + +def _require_env(name: str) -> str: + val = os.environ.get(name) + if not val: + raise SystemExit(f"required environment variable {name} is not set") + return val + + +def _impl_config() -> dict[str, dict]: + conda = Path(_require_env("BENCH_CONDA")) + return { + "baseline": { + "bench_mode": "cuvs_base", + "conda": conda, + "conda_env": _require_env("BENCH_ENV_BASE"), + }, + "cutile": { + "bench_mode": "cuvs_cutile", + "conda": conda, + "conda_env": _require_env("BENCH_ENV_CUTILE"), + }, + "flash": { + "bench_mode": "flash", + "conda": conda, + "conda_env": _require_env("BENCH_ENV_FLASH"), + }, + } + + 
+@dataclass +class BenchResult: + impl: str + fit_median_ms: float | None = None + predict_median_ms: float | None = None + n_iter: int | None = None + inertia: float | None = None + error: str | None = None + + +def median(xs: list[float]) -> float: + import numpy as np + + return float(np.median(xs)) + + +def run_benchmark( + bench_mode: str, + n: int, + d: int, + k: int, + *, + max_iter: int, + tol: float, + seed: int, + warmup_fit: int, + iters_fit: int, + warmup_pred: int, + iters_pred: int, +) -> BenchResult: + import numpy as np + + rng = np.random.default_rng(seed) + init_centroids_host = rng.standard_normal((k, d), dtype=np.float32) + x_host = rng.standard_normal((n, d), dtype=np.float32) + input_gib = n * d * 4 / (1024**3) + + label = { + "cuvs_base": "baseline", + "cuvs_cutile": "cutile", + "flash": "flash", + }[bench_mode] + print( + f"=== N={n:,} D={d} K={k:,} iters={max_iter} input={input_gib:.2f} GiB ===", + flush=True, + ) + + if bench_mode in ("cuvs_base", "cuvs_cutile"): + from cuda.bindings import runtime as cudart + from pylibraft.common import device_ndarray + + from cuvs.cluster.kmeans import KMeansParams, fit, predict + + def sync(): + cudart.cudaDeviceSynchronize() + + x = device_ndarray(x_host) + params = KMeansParams( + n_clusters=k, + max_iter=max_iter, + tol=tol, + metric="sqeuclidean", + hierarchical=False, + init_method="Array", + n_init=1, + ) + + for _ in range(warmup_fit): + fit( + params, x, centroids=device_ndarray(init_centroids_host.copy()) + ) + sync() + + fit_times: list[float] = [] + n_iter = 0 + inertia = 0.0 + for _ in range(iters_fit): + t0 = time.perf_counter() + _, inertia, n_iter = fit( + params, x, centroids=device_ndarray(init_centroids_host.copy()) + ) + sync() + fit_times.append((time.perf_counter() - t0) * 1e3) + + centroids, _, _ = fit( + params, x, centroids=device_ndarray(init_centroids_host.copy()) + ) + sync() + + for _ in range(warmup_pred): + predict(params, x, centroids) + sync() + + pred_times: list[float] 
= [] + for _ in range(iters_pred): + t0 = time.perf_counter() + predict(params, x, centroids) + sync() + pred_times.append((time.perf_counter() - t0) * 1e3) + + print(f"impl={label} init=Array", flush=True) + print(f"fit_median_ms={median(fit_times):.2f}", flush=True) + print(f"predict_median_ms={median(pred_times):.2f}", flush=True) + print(f"n_iter={n_iter} inertia={inertia:.6g}", flush=True) + return BenchResult( + impl=label, + fit_median_ms=median(fit_times), + predict_median_ms=median(pred_times), + n_iter=n_iter, + inertia=inertia, + ) + + if bench_mode == "flash": + import torch + from flash_kmeans.assign_euclid_triton import euclid_assign_triton + from flash_kmeans.kmeans_triton_impl import batch_kmeans_Euclid + + def sync(): + torch.cuda.synchronize() + + x = torch.from_numpy(x_host).cuda() + init_c = ( + torch.from_numpy(init_centroids_host.copy()).cuda().unsqueeze(0) + ) + + def run_fit(init): + x_b = x.unsqueeze(0) + _, centroids_b, _ = batch_kmeans_Euclid( + x_b, + k, + max_iters=max_iter, + tol=tol, + init_centroids=init, + verbose=False, + ) + return centroids_b + + for _ in range(warmup_fit): + run_fit(init_c.clone()) + sync() + + fit_times = [] + for _ in range(iters_fit): + t0 = time.perf_counter() + run_fit(init_c.clone()) + sync() + fit_times.append((time.perf_counter() - t0) * 1e3) + + centroids_b = run_fit(init_c.clone()) + sync() + + x_b = x.unsqueeze(0) + x_sq = (x_b**2).sum(dim=-1) + + for _ in range(warmup_pred): + euclid_assign_triton(x_b, centroids_b, x_sq) + sync() + + pred_times = [] + for _ in range(iters_pred): + t0 = time.perf_counter() + euclid_assign_triton(x_b, centroids_b, x_sq) + sync() + pred_times.append((time.perf_counter() - t0) * 1e3) + + print("impl=flash-kmeans init=Array", flush=True) + print(f"fit_median_ms={median(fit_times):.2f}", flush=True) + print(f"predict_median_ms={median(pred_times):.2f}", flush=True) + return BenchResult( + impl="flash", + fit_median_ms=median(fit_times), + 
predict_median_ms=median(pred_times), + ) + + raise ValueError(f"unknown bench_mode={bench_mode!r}") + + +def _parse_output(text: str, impl: str) -> BenchResult: + fit_m = re.search(r"^fit_median_ms=([0-9.]+)", text, re.M) + pred_m = re.search(r"^predict_median_ms=([0-9.]+)", text, re.M) + if not fit_m or not pred_m: + return BenchResult(impl=impl, error=text.strip() or "no output") + n_iter_m = re.search(r"^n_iter=([0-9]+)", text, re.M) + inertia_m = re.search(r"^inertia=([0-9.eE+-]+)", text, re.M) + return BenchResult( + impl=impl, + fit_median_ms=float(fit_m.group(1)), + predict_median_ms=float(pred_m.group(1)), + n_iter=int(n_iter_m.group(1)) if n_iter_m else None, + inertia=float(inertia_m.group(1)) if inertia_m else None, + ) + + +def _run_subprocess( + impl: str, + n: int, + d: int, + k: int, + args: argparse.Namespace, +) -> BenchResult: + cfg = _impl_config()[impl] + conda = cfg["conda"] + env_exports = " ".join( + f'export {key}="{val}"' + for key, val in ( + ( + "CUDA_VISIBLE_DEVICES", + os.environ.get("CUDA_VISIBLE_DEVICES", ""), + ), + ("MAX_ITER", args.max_iter), + ("TOL", args.tol), + ("SEED", args.seed), + ("WARMUP_FIT", args.warmup_fit), + ("ITERS_FIT", args.iters_fit), + ("WARMUP_PRED", args.warmup_pred), + ("ITERS_PRED", args.iters_pred), + ) + if val != "" + ) + cmd = f""" +set -eo pipefail +source "{conda}/etc/profile.d/conda.sh" +conda activate "{cfg["conda_env"]}" +{env_exports} +python3 "{ROOT / "benchmark_kmeans.py"}" --impl {impl} --n {n} --d {d} --k {k} \\ + --max-iter {args.max_iter} --tol {args.tol} --seed {args.seed} \\ + --warmup-fit {args.warmup_fit} --iters-fit {args.iters_fit} \\ + --warmup-pred {args.warmup_pred} --iters-pred {args.iters_pred} +""" + proc = subprocess.run(["bash", "-lc", cmd], capture_output=True, text=True) + out = proc.stdout + proc.stderr + if proc.returncode != 0: + return BenchResult( + impl=impl, error=out.strip() or f"exit {proc.returncode}" + ) + return _parse_output(out, impl) + + +def _speedup(base: 
float, other: float) -> str: + if other <= 0: + return "n/a" + return f"{base / other:.2f}x" + + +def print_compare_table( + results: list[BenchResult], n: int, d: int, k: int +) -> None: + print(f"\n######## compare N={n} D={d} K={k} ########") + print(f"{'impl':<10} {'fit_ms':>10} {'pred_ms':>10} {'notes'}") + print("-" * 50) + by_impl = {r.impl: r for r in results} + for impl in IMPLS: + r = by_impl.get(impl) + if r is None: + print(f"{impl:<10} {'—':>10} {'—':>10} missing") + continue + if r.error: + print( + f"{impl:<10} {'FAIL':>10} {'FAIL':>10} {r.error.splitlines()[-1][:40]}" + ) + continue + print( + f"{impl:<10} {r.fit_median_ms:10.2f} {r.predict_median_ms:10.2f}" + ) + + flash = by_impl.get("flash") + cutile = by_impl.get("cutile") + if flash and cutile and flash.fit_median_ms and cutile.fit_median_ms: + if not flash.error and not cutile.error: + print( + f"\nflash vs cutile fit: {_speedup(cutile.fit_median_ms, flash.fit_median_ms)}" + f" predict: {_speedup(cutile.predict_median_ms, flash.predict_median_ms)}" + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--compare", action="store_true", help="run baseline, cutile, flash" + ) + parser.add_argument("--impl", choices=IMPLS, help="single impl") + parser.add_argument("--n", type=int, required=True) + parser.add_argument("--d", type=int, required=True) + parser.add_argument("--k", type=int, required=True) + parser.add_argument("--max-iter", type=int, required=True) + parser.add_argument("--tol", type=float, required=True) + parser.add_argument("--seed", type=int, required=True) + parser.add_argument("--warmup-fit", type=int, required=True) + parser.add_argument("--iters-fit", type=int, required=True) + parser.add_argument("--warmup-pred", type=int, required=True) + parser.add_argument("--iters-pred", type=int, required=True) + args = parser.parse_args() + + if args.compare: + if args.impl: + parser.error("--compare and --impl are mutually 
exclusive") + _impl_config() # validate required env before launching subprocesses + results = [ + _run_subprocess(impl, args.n, args.d, args.k, args) + for impl in IMPLS + ] + print_compare_table(results, args.n, args.d, args.k) + return 0 if all(r.error is None for r in results) else 1 + + if not args.impl: + parser.error("set --impl for single-run mode, or use --compare") + + bench_mode = { + "baseline": "cuvs_base", + "cutile": "cuvs_cutile", + "flash": "flash", + }[args.impl] + + try: + run_benchmark( + bench_mode, + args.n, + args.d, + args.k, + max_iter=args.max_iter, + tol=args.tol, + seed=args.seed, + warmup_fit=args.warmup_fit, + iters_fit=args.iters_fit, + warmup_pred=args.warmup_pred, + iters_pred=args.iters_pred, + ) + except Exception as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/run_benchmark_kmeans.sh b/run_benchmark_kmeans.sh new file mode 100755 index 0000000000..37291b5518 --- /dev/null +++ b/run_benchmark_kmeans.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Compare baseline / cuTile / flash-kmeans for one shape. +# +# Usage: +# export BENCH_CONDA=/path/to/miniforge3 +# export BENCH_ENV_BASE=... +# export BENCH_ENV_CUTILE=... +# export BENCH_ENV_FLASH=... +# export MAX_ITER=5 TOL=1e-4 SEED=42 +# export WARMUP_FIT=1 ITERS_FIT=3 WARMUP_PRED=1 ITERS_PRED=3 +# ./run_benchmark_kmeans.sh N D K +# + +set -u + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ $# -ne 3 ]]; then + echo "usage: $0 N D K" >&2 + echo "See script header for required env vars and examples." 
>&2 + exit 2 +fi + +: "${BENCH_CONDA:?set BENCH_CONDA to conda/miniforge root}" +: "${BENCH_ENV_BASE:?set BENCH_ENV_BASE}" +: "${BENCH_ENV_CUTILE:?set BENCH_ENV_CUTILE}" +: "${BENCH_ENV_FLASH:?set BENCH_ENV_FLASH}" +: "${MAX_ITER:?set MAX_ITER}" +: "${SEED:?set SEED}" +: "${WARMUP_FIT:?set WARMUP_FIT}" +: "${ITERS_FIT:?set ITERS_FIT}" +: "${WARMUP_PRED:?set WARMUP_PRED}" +: "${ITERS_PRED:?set ITERS_PRED}" +: "${TOL:?set TOL}" + +N=$1 +D=$2 +K=$3 + +exec python3 "$SCRIPT_DIR/benchmark_kmeans.py" --compare \ + --n "$N" --d "$D" --k "$K" \ + --max-iter "$MAX_ITER" --tol "$TOL" --seed "$SEED" \ + --warmup-fit "$WARMUP_FIT" --iters-fit "$ITERS_FIT" \ + --warmup-pred "$WARMUP_PRED" --iters-pred "$ITERS_PRED"