diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 252ea7281d981..66141c58e7a01 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -78,8 +78,8 @@ jobs: run: | set -e -x BINARY_SIZE_THRESHOLD_ARGS="" - echo "Binary size threshold in bytes: 1436672" - BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672" + echo "Binary size threshold in bytes: 1440768" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1440768" # Ensure ANDROID_NDK_HOME is available and get its real path if [ -z "$ANDROID_NDK_HOME" ]; then diff --git a/.github/workflows/linux_cuda_plugin_ci.yml b/.github/workflows/linux_cuda_plugin_ci.yml index a9197b3732dd8..2369af53621b2 100644 --- a/.github/workflows/linux_cuda_plugin_ci.yml +++ b/.github/workflows/linux_cuda_plugin_ci.yml @@ -141,3 +141,31 @@ jobs: cd /onnxruntime_src/onnxruntime/test/python/transformers python test_cuda_plugin_ep.py " + + # --- Run the CUDA plugin EP C++ GoogleTest binary --- + # onnxruntime_provider_test is built into the artifact and links the plugin tests + # (gated by ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP). These tests register the plugin .so via + # GetSharedLibraryFileName("onnxruntime_providers_cuda_plugin"), which returns the + # platform-specific filename without a directory component. Run from /build/Release/Release + # so that filename resolves to the plugin .so built there. + # The filter covers every CUDA plugin EP suite linked into this binary: + # CudaPlugin* -> CudaPluginUserStreamGraphTest, CudaPluginArenaTest, + # CudaPluginPartitioningTest, CudaPluginProfilingTest + # CudaResourcePartitioning* -> CudaResourcePartitioningTest + - name: Run CUDA Plugin EP C++ Tests + run: | + docker run --rm --gpus all \ + -v ${{ github.workspace }}:/onnxruntime_src \ + -v ${{ runner.temp }}/Release:/build/Release \ + -e NVIDIA_VISIBLE_DEVICES=all \ + ${{ steps.build_docker_image_step.outputs.full-image-name }} \ + bash -c " + set -ex + export PATH=/opt/python/cp312-cp312/bin:\$PATH + # Make libcudart.so.13 (and the plugin's CUDA deps) findable; see note above. + export LD_LIBRARY_PATH=/build/Release/Release:/usr/local/cuda-13.0/lib64:\${LD_LIBRARY_PATH:-} + + cd /build/Release/Release + ls -la onnxruntime_provider_test libonnxruntime_providers_cuda_plugin.so + ./onnxruntime_provider_test --gtest_filter='CudaPlugin*:CudaResourcePartitioning*' + " diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index c6346d9c4e932..9776c22fedabb 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -148,7 +148,7 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows GPU CUDA CI Pipeline Test Job @@ -260,4 +260,4 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e diff --git a/.github/workflows/windows_cuda_plugin.yml b/.github/workflows/windows_cuda_plugin.yml index 538c1d783cd68..c2d84d7be482a 100644 --- a/.github/workflows/windows_cuda_plugin.yml +++ b/.github/workflows/windows_cuda_plugin.yml @@ -118,7 +118,7 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows CUDA Plugin EP Test @@ -214,4 +214,4 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e diff --git a/.github/workflows/windows_gpu_doc_gen.yml b/.github/workflows/windows_gpu_doc_gen.yml index 5e50a970875fc..b41de5542b0c6 100644 --- a/.github/workflows/windows_gpu_doc_gen.yml +++ b/.github/workflows/windows_gpu_doc_gen.yml @@ -44,7 +44,7 @@ jobs: setVcvars: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e runs-on: [ "self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index 52581c7d0a5f5..d87a4919a86fd 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -26,7 +26,7 @@ jobs: timeout-minutes: 240 env: AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e OnnxRuntimeBuildDirectory: ${{ github.workspace }} DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index 35c620ca6f650..5560de89040de 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -29,7 +29,7 @@ jobs: QnnLibKind: [shared_lib, static_lib] env: AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index d5710795942d1..3ad0076de6d52 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -154,7 +154,7 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows GPU TensorRT CI Pipeline Test Job @@ -265,4 +265,4 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e diff --git a/cmake/deps.txt b/cmake/deps.txt index e303ccd9f8a98..d6a5b71221dc8 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -50,7 +50,7 @@ protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/downlo psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2 pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v3.0.2.zip;a064e663b4d7a337ac291d1bef7337ef4e60a1ae -pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/403d652dca4c1046e8145950b1c0997a9f748b57.zip;30b2a07fe4bae8574f89176e56274cacdd6d135b +pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/4628dc060ce4e82345dc166bbac875609db4ff69.zip;e58d4b47c16a982111c897e669ae4f1821a393d7 re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 1a1e4921a41e6..ed3b0aa8192a7 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -371,9 +371,7 @@ if (CPUINFO_SUPPORTED) PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch && # https://github.com/pytorch/cpuinfo/pull/324 - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch && - # https://github.com/pytorch/cpuinfo/pull/348 - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/win_arm_fp16_detection_fallback.patch + ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch FIND_PACKAGE_ARGS NAMES cpuinfo ) elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b8ab7142b6b35..d55a4d49fb455 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -56,6 +56,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/flashattn_qkv.cpp + ${MLAS_SRC_DIR}/flashattn_gqa.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp ${MLAS_SRC_DIR}/layernorm.cpp diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index f692f1f5e0a57..2aa31276cc395 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -208,7 +208,7 @@ target_compile_definitions(onnxruntime_providers_cuda PRIVATE FILE_NAME=\"onnxruntime_providers_cuda.dll\") endif() - # Work around a CUDA 13.x cudafe++ (EDG front-end) regression that mis-parses CCCL's + # Work around a CUDA 13.3 cudafe++ (EDG front-end) regression that mis-parses CCCL's # global-qualified partial specializations, e.g. in : # template # struct ::cuda::proclaims_copyable_arguments<...> : ::cuda::std::true_type {}; @@ -218,7 +218,7 @@ # corrected copies of the affected headers into the build tree and place that directory # ahead of the toolkit cccl include path. This is a no-op on toolkits whose headers do not # contain the offending pattern (e.g. once NVIDIA fixes it), so it is safe to keep enabled. - function(ort_cuda13_patch_cccl_header src dst) + function(ort_cuda133_patch_cccl_header src dst) if (NOT EXISTS "${src}") return() endif() @@ -412,19 +412,21 @@ if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) foreach(inc_dir ${CUDAToolkit_INCLUDE_DIRS}) if (EXISTS "${inc_dir}/cccl") - # Generate cudafe++-parseable copies of the CCCL headers that contain global-qualified - # partial specializations (see ort_cuda13_patch_cccl_header above) and put the fixed - # directory ahead of the toolkit cccl include so the corrected headers win. - set(_ort_cccl_fix_dir "${CMAKE_CURRENT_BINARY_DIR}/cccl_cuda13_fix") - ort_cuda13_patch_cccl_header( - "${inc_dir}/cccl/cub/device/device_transform.cuh" - "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh") - ort_cuda13_patch_cccl_header( - "${inc_dir}/cccl/cub/device/dispatch/tuning/tuning_transform.cuh" - "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") - if (EXISTS "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh" OR - EXISTS "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") - target_include_directories(${target} BEFORE PRIVATE "${_ort_cccl_fix_dir}") + if (UNIX AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.3 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 13.4) + # Generate cudafe++-parseable copies of the CCCL headers that contain global-qualified + # partial specializations (see ort_cuda133_patch_cccl_header above) and put the fixed + # directory ahead of the toolkit cccl include so the corrected headers win. + set(_ort_cccl_fix_dir "${CMAKE_CURRENT_BINARY_DIR}/cccl_cuda13_fix") + ort_cuda133_patch_cccl_header( + "${inc_dir}/cccl/cub/device/device_transform.cuh" + "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh") + ort_cuda133_patch_cccl_header( + "${inc_dir}/cccl/cub/device/dispatch/tuning/tuning_transform.cuh" + "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") + if (EXISTS "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh" OR + EXISTS "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") + target_include_directories(${target} BEFORE PRIVATE "${_ort_cccl_fix_dir}") + endif() endif() # Add the cccl subdirectory to the include path so can be found diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 551e877d5f6d8..86e5579eb6761 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -88,10 +88,11 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") # in the CPU provider and is not linked into the plugin. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") -# Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. -# shape_op.cc inherits from onnxruntime::OpKernel (framework) -# which cannot convert to ep::adapter::OpKernel in the plugin build. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$") +# shape_op.cc is INCLUDED in the plugin build. It provides an adapter-based +# Shape kernel under #ifdef BUILD_CUDA_EP_AS_PLUGIN (the CPU onnxruntime::Shape +# class, which derives from the framework OpKernel, is only used in the +# non-plugin build). Registering Shape on the EP keeps it off the CPU EP and +# avoids Memcpy nodes that would otherwise break CUDA Graph capture. # Exclude contrib training ops (shrunken_gather depends on provider_api.h in header). list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 23eccb22476df..bb171b056b400 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -603,6 +603,7 @@ set (onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_data_copy.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ep_context_data_api.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_experimental_api.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc @@ -1609,8 +1610,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() - - if(onnxruntime_USE_QNN) + # Build ep_weight_sharing_ctx_gen for all supported EPs (QNN, TensorRT, OpenVINO, VitisAI) + if(onnxruntime_USE_QNN OR onnxruntime_USE_TENSORRT OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_VITISAI) #qnn ctx generator set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) set(ep_weight_sharing_ctx_gen_src_patterns @@ -2174,6 +2175,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND # file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.cc" + "${TEST_SRC_DIR}/autoep/library/ep_context_data_utils.h" "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h") onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src}) target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) diff --git a/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch b/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch index 005cd458fdd2b..47a1054e25107 100644 --- a/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch +++ b/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch @@ -1,10 +1,19 @@ diff --git a/src/linux/processors.c b/src/linux/processors.c -index 47bee76..d0c5569 100644 +index fd040a3..2ca8ec4 100644 --- a/src/linux/processors.c +++ b/src/linux/processors.c -@@ -2,0 +3 @@ +@@ -3,6 +3,7 @@ + #include + #include + #include +#include -@@ -291,0 +293,22 @@ + + #if !defined(__ANDROID__) + /* +@@ -289,6 +290,28 @@ static bool max_processor_number_parser(uint32_t processor_list_start, uint32_t + return true; + } + +static uint32_t cpuinfo_linux_get_max_processor_from_sysconf( + uint32_t max_processors_count, + const char* processor_list_name) { @@ -27,13 +36,31 @@ index 47bee76..d0c5569 100644 + return max_processor; +} + -@@ -301 +324 @@ + uint32_t cpuinfo_linux_get_max_possible_processor(uint32_t max_processors_count) { + uint32_t max_possible_processor = 0; + if (!cpuinfo_linux_parse_cpulist( +@@ -298,7 +321,7 @@ uint32_t cpuinfo_linux_get_max_possible_processor(uint32_t max_processors_count) + #else + cpuinfo_log_warning("failed to parse the list of possible processors in %s", POSSIBLE_CPULIST_FILENAME); + #endif - return UINT32_MAX; + return cpuinfo_linux_get_max_processor_from_sysconf(max_processors_count, POSSIBLE_CPULIST_FILENAME); -@@ -323 +346 @@ + } + if (max_possible_processor >= max_processors_count) { + cpuinfo_log_warning( +@@ -320,7 +343,7 @@ uint32_t cpuinfo_linux_get_max_present_processor(uint32_t max_processors_count) + #else + cpuinfo_log_warning("failed to parse the list of present processors in %s", PRESENT_CPULIST_FILENAME); + #endif - return UINT32_MAX; + return cpuinfo_linux_get_max_processor_from_sysconf(max_processors_count, PRESENT_CPULIST_FILENAME); -@@ -357,0 +381,31 @@ + } + if (max_present_processor >= max_processors_count) { + cpuinfo_log_warning( +@@ -355,6 +378,37 @@ static bool detect_processor_parser(uint32_t processor_list_start, uint32_t proc + return true; + } + +static bool cpuinfo_linux_detect_processors_from_sysconf( + uint32_t max_processors_count, + uint32_t* processor0_flags, @@ -65,7 +92,13 @@ index 47bee76..d0c5569 100644 + return true; +} + -@@ -373 +427,6 @@ + bool cpuinfo_linux_detect_possible_processors( + uint32_t max_processors_count, + uint32_t* processor0_flags, +@@ -370,7 +424,12 @@ bool cpuinfo_linux_detect_possible_processors( + return true; + } else { + cpuinfo_log_warning("failed to parse the list of possible processors in %s", POSSIBLE_CPULIST_FILENAME); - return false; + return cpuinfo_linux_detect_processors_from_sysconf( + max_processors_count, @@ -73,7 +106,13 @@ index 47bee76..d0c5569 100644 + processor_struct_size, + possible_flag, + POSSIBLE_CPULIST_FILENAME); -@@ -392 +451,6 @@ + } + } + +@@ -389,7 +448,12 @@ bool cpuinfo_linux_detect_present_processors( + return true; + } else { + cpuinfo_log_warning("failed to parse the list of present processors in %s", PRESENT_CPULIST_FILENAME); - return false; + return cpuinfo_linux_detect_processors_from_sysconf( + max_processors_count, @@ -81,3 +120,6 @@ index 47bee76..d0c5569 100644 + processor_struct_size, + present_flag, + PRESENT_CPULIST_FILENAME); + } + } + diff --git a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch index af0f039b6c2a3..18ed80f7944f8 100644 --- a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 +index 072c987..e43d6ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am @@ -7,7 +7,7 @@ index aedc983..dab589e 100644 IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") +ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID for non-VS generators (e.g. Ninja) with MSVC. + IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") + SET(CPUINFO_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") diff --git a/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch b/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch deleted file mode 100644 index 44ac0f13f5466..0000000000000 --- a/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch +++ /dev/null @@ -1,19 +0,0 @@ -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index 5c0a5f3..a07fbe4 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -249,6 +249,14 @@ static void set_cpuinfo_isa_fields(void) { - // guarantee that, but it holds in practice. - cpuinfo_isa.rdm = dotprod; - -+ // PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE may not be available in older -+ // Windows versions. If fp16arith was not detected with -+ // IsProcessorFeaturePresent(PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE), fall -+ // back to using the value of dotprod. -+ if (!cpuinfo_isa.fp16arith) { -+ cpuinfo_isa.fp16arith = dotprod; -+ } -+ - /* Windows API reports all or nothing for cryptographic instructions. */ - const bool crypto = IsProcessorFeaturePresent(PF_ARM_V8_CRYPTO_INSTRUCTIONS_AVAILABLE) != 0; - cpuinfo_isa.aes = crypto; diff --git a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch index af0f039b6c2a3..18ed80f7944f8 100644 --- a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 +index 072c987..e43d6ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am @@ -7,7 +7,7 @@ index aedc983..dab589e 100644 IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") +ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID for non-VS generators (e.g. Ninja) with MSVC. + IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") + SET(CPUINFO_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index 67bd18e61cc28..9140a233e2ccd 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -6,13 +6,12 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO pytorch/cpuinfo - REF 403d652dca4c1046e8145950b1c0997a9f748b57 - SHA512 f7cd6dc44bd1120af610cae1337ed4c0f557ba78d2de9c73fed350fa3dfe9512643a1619ae55f5a540c6316a87d641856cca27297bb8766e48f39b7b7a59da1f - HEAD_REF master + REF 4628dc060ce4e82345dc166bbac875609db4ff69 + SHA512 db7a93279f2f6daaf825fbd8552935d8ed671d276b65ad614e11f722b6a6848e663850d65180d33b554d67ef1a36aae842feb368699f90be8f21172a1af1924e + HEAD_REF main PATCHES patch_cpuinfo_h_for_arm64ec.patch patch_vcpkg_arm64ec_support.patch # https://github.com/pytorch/cpuinfo/pull/324 - win_arm_fp16_detection_fallback.patch # https://github.com/pytorch/cpuinfo/pull/348 ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch b/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch deleted file mode 100644 index 44ac0f13f5466..0000000000000 --- a/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch +++ /dev/null @@ -1,19 +0,0 @@ -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index 5c0a5f3..a07fbe4 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -249,6 +249,14 @@ static void set_cpuinfo_isa_fields(void) { - // guarantee that, but it holds in practice. - cpuinfo_isa.rdm = dotprod; - -+ // PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE may not be available in older -+ // Windows versions. If fp16arith was not detected with -+ // IsProcessorFeaturePresent(PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE), fall -+ // back to using the value of dotprod. -+ if (!cpuinfo_isa.fp16arith) { -+ cpuinfo_isa.fp16arith = dotprod; -+ } -+ - /* Windows API reports all or nothing for cryptographic instructions. */ - const bool crypto = IsProcessorFeaturePresent(PF_ARM_V8_CRYPTO_INSTRUCTIONS_AVAILABLE) != 0; - cpuinfo_isa.aes = crypto; diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index e5a211c9fd11a..ffce29682ca42 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -17,7 +17,12 @@ Quantized KV-cache GEMM helpers are implemented in MLAS: - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp` -- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel) +- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (quantized-KV flash attention tiled kernel) + +The non-quantized flash attention tiled kernel is implemented in MLAS: + +- `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (FP32-KV flash attention tiled kernel) +- `onnxruntime/core/mlas/inc/mlas.h` (`MlasFlashAttentionGQA` declaration and `MlasFlashAttentionGQAArgs`) The operator schema itself is defined in: @@ -48,12 +53,14 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages: The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs. -The quantized path has two execution strategies: +Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, online-softmax, and flash-decoding structure. + +The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). ## Supported Cache Modes @@ -144,9 +151,9 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with: As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly. -## Flash Attention Path +## Quantized Flash Attention Path -The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. +The quantized flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. ### Algorithm @@ -204,6 +211,93 @@ The partials buffer is allocated alongside the per-thread scratch in a single al - Per-thread scratch: `scores[Bc]` (one float per KV block element) - Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk) +## Non-Quantized Flash Attention Path + +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, masking, and flash-decoding structure. + +### Differences from the Quantized Path + +- **Cache element type**: The present K/V cache is FP32, laid out as BNSH (`[batch, kv_num_heads, seqlen_present, head_size]`). There is no quantize-on-write or dequantize-on-read step. +- **QK GEMM**: Uses the single-threaded SGEMM primitive `MlasSgemmOperation(CblasNoTrans, CblasTrans, ...)` on an FP32 K block instead of `MlasQKGemm`. +- **SV accumulate**: Uses `MlasSgemmOperation(CblasNoTrans, CblasNoTrans, ..., beta)` with `beta = 0` for the first KV block and `beta = 1` afterwards (accumulate) instead of `MlasSVGemm`. +- **Cache concat**: New K/V tokens are appended into the FP32 present cache with `ConcatStateChunkGQA` before the tiled loop runs. + +### Algorithm + +For each (batch, head, q_block) tile: + +1. **QK GEMM** — `MlasSgemmOperation` of the query tile against a block slice of the FP32 K cache (Bc rows at a time) +1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores +2. **Causal + local window masking** — Set masked positions to −∞ before softmax +3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)` +4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile +5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed + +#### Causal early-termination + +During prefill, every KV block whose start index is at or beyond the largest global query +position in the current q_block is fully causally masked and contributes nothing. The kernel +computes a per-q_block bound +`kv_causal_limit = past_seqlen + q_idx + row_size_q` and breaks out of the KV loop once +`ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle +QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the +main reason the FP32 flash path is faster than naive even at short sequence lengths +(see the benchmark results below). Decode (q_block of size 1 at the cache tail) attends to all +KV positions, so the bound equals `total_seqlen` and nothing is skipped. + +### Activation Conditions + +The non-quantized flash path is selected when ALL of the following hold: + +- The kernel specialization is `float` (FP16 uses the naive path) +- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) +- `total_sequence_length > 1` +- No softcap +- No smooth softmax +- No head sink +- No output QK capture +- `present_key` and `present_value` are provided + +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Sizes, Threading, and Flash Decoding + +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. + +#### Decode uses a dedicated GEMV kernel (`sequence_length == 1`) + +The tiled online-softmax SGEMM kernel (`MlasFlashAttentionGQAThreaded`) is used **only for +prefill** (`sequence_length > 1`), where each KV tile is reused across the `q_block_size` +query rows and tiling delivers real cache-locality and SGEMM packing benefits. + +For single-token decode the query tile has `M = 1`, so every K/V element is streamed +exactly once with no reuse across query rows. Tiling provides **no** cache-locality +benefit, and routing the `1 × T × H` work through `MlasSgemmOperation` pays the SGEMM +B-packing/setup cost on every call — which previously made the flash decode path *slower* +than the naive path (≈0.4–0.6x) for short-to-medium total sequence lengths. + +Decode is therefore handled by a dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`), +dispatched whenever `sequence_length == 1` and flash decoding is not active. It +parallelizes over `(batch, head)` and, per head, computes the attention directly with two +matrix-vector products and a two-pass softmax: + +- **QK GEMV** — `scores[t] = scale · dot(q, K[t])` for `t ∈ [0, total_seqlen)`. +- two-pass softmax over `scores` using the dispatched `ReduceMaximumF32Kernel` / + `ComputeSumExpF32Kernel` helpers. +- **SV GEMV** — `out[h] = Σ_t probs[t] · V[t][h]`, then normalize by `1/Σ probs`. + +Both GEMV helpers (`MlasGQADecodeQK`, `MlasGQADecodeSV`) live in the baseline-ISA MLAS +translation unit, so their inner loops use independent accumulator lanes / map-style +updates that vectorize under SSE2 without `-ffast-math`. Decode needs no causal mask (the +single new token is the most recent position and attends to every cached token); only +optional local-window masking and additive attention bias are applied. The kernel streams +K and V exactly once each, so it is memory-bandwidth bound. + +The two-phase flash-decoding path (active when `batch × heads < threads`, KV partitioned +across idle threads) now also uses these GEMV helpers for its per-chunk QK and SV products +instead of `M = 1` SGEMM calls, removing the same packing overhead. + + ## MLAS Dispatch Paths MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table. @@ -428,7 +522,57 @@ Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle | 4096 (N=32) | +2131 | +87 | 24.5x | **Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences. +### Non-Quantized (FP32) Flash Attention vs Naive benchmark results +Measured on an AMD EPYC 7763 (32 logical / 16 physical cores), threads=8, FP32 KV cache, +`B=1, num_heads=16, kv_num_heads=8, head_size=128`. Operator-level, measured with: + +```bash +python onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py \ + --fp32 --prompt_only --warmup 10 --repeats 30 +``` + +#### Latency — Prefill (S = T, prompt phase) + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 512 | 5.8\u20138.4 | 4.2\u20135.3 | 1.4\u20131.6x | +| 1024 | 25\u201329 | 13\u201318 | 1.6\u20132.0x | +| 2048 | 87\u2013118 | 52\u201365 | 1.5\u20132.0x | +| 4096 | 365\u2013380 | 213\u2013234 | 1.6\u20131.7x | + +The FP32 flash path is faster than naive across all measured prefill lengths. With the causal +early-termination described above, roughly half of the QK/SV work (the causally masked +upper triangle of the square prefill attention matrix) is skipped entirely, which more than +offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/output rescale). +The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is +algorithmic rather than purely from threading. + +#### Latency — Decode (S = 1, token generation) + +For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so +flash decoding KV-partitioning is not active), the workload per `Run` is tiny (a `1 × T × H` +GEMV pair per head) and operator-level latency is dominated by fixed per-`Run` overhead +(session dispatch, KV-cache concatenation), so operator-level measurements on the EPYC dev +box are extremely noisy. The numbers below come from a min-of-many-repeats MLAS-path harness +to suppress that jitter. + +| Total Seqlen | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 513 | 0.50 | 0.42 | ~1.0\u20131.2x (noisy) | +| 1025 | 0.78 | 0.69 | ~1.0\u20131.1x (noisy) | +| 2049 | 1.89 | 1.73 | ~1.0\u20131.1x (noisy) | +| 4097 | 6.1 | 4.5 | 1.35\u20131.5x | + +Decode is now handled by the dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`) instead of +the prefill tiling kernel; see *Decode uses a dedicated GEMV kernel* above. Replacing the +per-head `M = 1` `MlasSgemmOperation` QK/SV calls with direct GEMVs removes the SGEMM +B-packing overhead that previously made flash decode noticeably **slower** than naive +(measured ≈0.4\u20130.6x across all lengths before the change). Flash decode is now at parity +for short/medium sequences (where the work is memory-bandwidth bound and overhead-dominated) +and consistently ahead for long contexts (T≥4097, ~1.4\u20131.5x) where the streamed +single-pass KV access wins. Short decode remains overhead-bound rather than algorithm-bound, +so it is not the target of the prefill-oriented causal early-termination optimization. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/docs/cuda_plugin_ep/arena_allocator_migration_design.md b/docs/cuda_plugin_ep/arena_allocator_migration_design.md index 285aa3e60ed5c..d4ff21e713f85 100644 --- a/docs/cuda_plugin_ep/arena_allocator_migration_design.md +++ b/docs/cuda_plugin_ep/arena_allocator_migration_design.md @@ -62,7 +62,18 @@ if (!factory.arena_allocator_) { **Stream-aware allocation.** `ArenaImpl::AllocOnStream(size, stream)` tracks which chunks are assigned to which stream. `ResetChunksUsingStream(stream_impl)` is called from `OrtSyncStreamImpl::OnSessionRunEnd` to release chunk-to-stream assignments when a session run completes. -**Read-only allocator bypasses arena.** The factory creates a plain `CustomAllocator` (no arena) for `OrtReadOnlyAllocator` (initializers), since initializer memory doesn't benefit from arena allocation. +**Kernel-side consumption of the arena.** Migrated CUDA kernels obtain scratch/workspace memory from this arena through `CudaKernel::GetScratchBuffer`, which calls `Info().GetAllocator(OrtMemTypeDefault)`. Inside the plugin build that allocator is exposed to internal code as an `IAllocatorWrappingOrtAllocator` (`include/onnxruntime/ep/adapter/allocator.h`), which implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` (ORT >= 1.23), falling back to plain `Alloc` otherwise. `GetScratchBuffer` uses the framework `OrtSyncStream*` exposed through `KernelContext_GetSyncStream` to stream-tag scratch chunks, while kernels continue to use the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream` for launches and library handles. This keeps allocation bookkeeping on the same framework stream wrapper that the arena stores in `chunk->stream` and later queries through the EP stream API (`SyncStream_GetImpl`/`SyncStream_GetSyncId`). If the negotiated ORT API version does not include `KernelContext_GetSyncStream`, the adapter falls back to a null stream tag and the EP does not advertise concurrent run support. + +#### Scratch buffer stream tagging + +A common review question is: *"Shouldn't the scratch buffer use the same stream as the kernel?"* The short answer is yes for concurrent multi-stream runs, but the allocator must receive the framework stream wrapper, not the raw CUDA handle. + +- **The `stream` argument is bookkeeping, not execution.** The stream passed to a stream-aware arena's `AllocOnStream()` is only metadata the arena uses to decide whether a *freed* chunk may be reused on a *different* stream without an intervening synchronization. It does **not** change where the kernel runs: the returned buffer is consumed by the kernel on its raw CUDA compute stream. +- **Raw CUDA stream and framework stream are different objects.** `KernelContext_GetGPUComputeStream` returns the raw `cudaStream_t` used for CUDA calls. The stream-aware arena needs the framework `OrtSyncStream*` (`struct OrtSyncStream : public onnxruntime::Stream` in `core/framework/plugin_ep_stream.h`) because that stable wrapper is what it persists in each chunk. `CudaSyncStream::FromCudaStream()` can recover the plugin-side `CudaSyncStream` (`OrtSyncStreamImpl`), but that is not the ORT-core `OrtSyncStream*` the arena expects. +- **How the plugin bridges them.** `KernelContext_GetSyncStream` exposes the framework stream for the current kernel dispatch. The CUDA plugin adapter records the mapping from raw `cudaStream_t` to framework stream when migrated kernels call `GetComputeStream(ctx)`, and `GetScratchBuffer` uses the framework stream for `AllocOnStream`. This preserves the existing migrated-kernel pattern while making scratch chunks safe for cross-stream reuse decisions. +- **Compatibility fallback.** When the negotiated ORT API version does not include `KernelContext_GetSyncStream`, scratch allocations use a null stream tag. A null tag is correct for serialized runs and single-unified-stream CUDA graph capture, but it is not safe for overlapping runs on different CUDA streams, so the plugin EP only advertises concurrent `Session::Run()` when `KernelContext_GetSyncStream` is available. + + ### 2.2 How ORT Core Calls the Factory diff --git a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md index 8092a15e26973..9035ce91bb3bb 100644 --- a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md +++ b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md @@ -34,6 +34,7 @@ Session::Run() **Key design choices:** - Each thread gets its own dedicated graph `cudaStream_t`, `CudaGraphManager`, and capture bookkeeping for the EP instance. `CudaSyncStream::InitHandlesWithExternalStream()` wraps the thread's graph stream so graph capture sees the same stream as kernels. The manager stores captured `cudaGraphExec_t` executables keyed by annotation ID, allowing multiple graphs (e.g., different input shapes) for that thread. +- When a `user_compute_stream` is supplied together with graph capture, the per-thread context adopts that user-owned stream as its graph stream instead of creating one, so capture/replay run on the same stream the caller drives. The context records that it does not own the stream and never destroys it. See [User Compute Stream + CUDA Graph](#user-compute-stream--cuda-graph). - Warm-up runs (default: 2) allow memory allocations to stabilize before capture begins. - Graph annotation IDs are parsed from `OrtRunOptions` key `"gpu_graph_id"`. ID `-1` skips capture; `0` is the default. @@ -53,25 +54,51 @@ Session::Run() Legacy aliases `ep.cuda.enable_cuda_graph` and `enable_cuda_graph` are also supported. For the warm-up count, `ep.cuda.min_num_runs_before_cuda_graph_capture` is also accepted. +The provider option `user_compute_stream` (a `cudaStream_t` passed as a pointer) may be combined with `enable_cuda_graph`. See [User Compute Stream + CUDA Graph](#user-compute-stream--cuda-graph). + --- +## User Compute Stream + CUDA Graph + +A caller can supply its own CUDA stream through the `user_compute_stream` provider option and enable CUDA graph capture at the same time. This combination was previously rejected with `ORT_INVALID_ARGUMENT`; it is now supported and matches the bundled (non-plugin) CUDA EP. + +When both options are set: + +- `CudaEpFactory::CreateEpImpl` no longer rejects the pair. Setting `user_compute_stream` still forces unified-stream mode (matching the bundled EP). +- `CudaEp::CreateSyncStreamForDeviceImpl` wraps the user stream via `InitHandlesWithUserStream()`, attaching full cuBLAS/cuDNN/cuBLASLt handles to it. +- `CudaEp::GetPerThreadContext()` builds the thread's `PerThreadContext` around the user stream (`external_graph_stream`) instead of creating an EP-owned graph stream. Capture and replay therefore run on the same stream the kernels are issued to. +- The context records `owns_graph_stream = false`, so it tears down captured graph execs on destruction but never calls `cudaStreamDestroy` on the user-owned stream. Stream lifetime stays with the caller. + +Because the user supplies one stream, this mode is inherently single-stream; the per-thread graph isolation still applies if the same session is driven from multiple threads, but each thread must drive its own captures on the stream it provides. + +### `user_compute_stream` is not limited to the CUDA graph case + +A natural question when reading `GetPerThreadContext()` is why `use_external_stream` is gated on `has_user_compute_stream && enable_cuda_graph` — does that restrict a user compute stream to graph-enabled runs? It does not. + +- A user compute stream is honored for kernels in **both** graph and non-graph runs. That happens in `CudaEp::CreateSyncStreamForDeviceImpl`, whose first branch wraps `config_.user_compute_stream` via `InitHandlesWithUserStream()` **independently of `enable_cuda_graph`**. +- The `enable_cuda_graph` term in `use_external_stream` only governs the `PerThreadContext`'s *graph stream*. `PerThreadContext` is a graph-capture-only object: `GetPerThreadContext()` is reached exclusively from the graph path (the `enable_cuda_graph` branch of `CreateSyncStreamForDeviceImpl`, `OnRunStart`/`OnRunEnd`, `IsGraphCaptured`, `ReplayGraph`). With graph disabled, no `PerThreadContext` is ever constructed, so its stream-ownership flag is irrelevant. +- The flag therefore answers a narrower question — *"should the per-thread capture/replay graph stream adopt (and not destroy) the user's stream?"* — which is only meaningful when a graph is actually being captured. + ## Implementation Summary ### Files Changed | File | Change | |------|--------| -| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implemented graph capture callbacks (`OnRunStartImpl`, `OnRunEndImpl`, `IsGraphCaptureEnabledImpl`, `IsGraphCapturedImpl`, `ReplayGraphImpl`, `IsConcurrentRunSupportedImpl`), updated `CreateSyncStreamForDeviceImpl` to use the current thread's graph stream when graph capture is enabled, added per-thread graph state, preserved `sync_stream` synchronization, and added a `cudaMemGetInfo` defensive allocation check | +| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implemented graph capture callbacks (`OnRunStartImpl`, `OnRunEndImpl`, `IsGraphCaptureEnabledImpl`, `IsGraphCapturedImpl`, `ReplayGraphImpl`, `IsConcurrentRunSupportedImpl`), updated `CreateSyncStreamForDeviceImpl` to wrap a `user_compute_stream` or otherwise use the current thread's graph stream when graph capture is enabled, made `PerThreadContext` adopt the user stream as its (non-owned) graph stream when `user_compute_stream` + `enable_cuda_graph` are combined, added per-thread graph state, preserved `sync_stream` synchronization, and added a `cudaMemGetInfo` defensive allocation check | | `onnxruntime/core/providers/cuda/plugin/cuda_ep.h` | Added `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` config fields, graph callback declarations, and a per-thread graph context cache | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc` | **NEW** — Complete `CudaGraphSet` and `CudaGraphManager` implementation | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h` | **NEW** — Header for graph manager types and constants | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Added `InitHandlesWithExternalStream()`, updated destructor for `owns_stream_` | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Added `InitHandlesWithExternalStream()` declaration, `owns_stream_` member | -| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` | +| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture`; removed the validation that rejected `user_compute_stream` + `enable_cuda_graph` (the combination is now supported) | +| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | `CudaKernel::GetScratchBuffer` now allocates through `Info().GetAllocator()` (the EP arena) and stream-tags scratch chunks with the framework stream exposed by `KernelContext_GetSyncStream`, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call | +| `include/onnxruntime/ep/adapter/allocator.h` | Implemented `IAllocatorWrappingOrtAllocator::IsStreamAware`/`AllocOnStream` (previously `ORT_NOT_IMPLEMENTED`) so plugin adapters can forward stream-aware allocations when a framework stream is available | +| `include/onnxruntime/core/session/onnxruntime_c_api.h` | Added `KernelContext_GetSyncStream` so plugin kernels can obtain the framework `OrtSyncStream*` for stream-aware allocation bookkeeping while still using `KernelContext_GetGPUComputeStream` for raw CUDA work | | `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Added `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` callbacks and `OrtGraphCaptureNodeAssignmentPolicy` enum to `OrtEp` | | `include/onnxruntime/core/framework/execution_provider.h` | Added `GetGraphCaptureNodeAssignmentPolicy()` virtual to `IExecutionProvider` | | `onnxruntime/core/session/inference_session.cc` | Replaced hard-coded EP name list with policy-driven graph capture validation loop; added bounded recursion via `RunImpl()` with `kMaxGraphCaptureWarmupRuns`; graph-enabled runs now reacquire stream collections through ORT core's thread-affine pool across internal warm-up/capture recursion | -| `onnxruntime/core/framework/session_state.cc` | Sharded the `DeviceStreamCollection` cache by caller thread using per-thread lifetime tokens, so stream wrappers are only reused on the creating thread | +| `onnxruntime/core/framework/session_state.cc` | Sharded the `DeviceStreamCollection` cache by caller thread using per-thread lifetime tokens, so stream wrappers are only reused on the creating thread; added a fallback in the PrePack loop to resolve the kernel's default-memory allocator (`Info().GetAllocator()`) when the device-keyed initializer-allocator lookup returns null for a separately-registered plugin EP | | `onnxruntime/core/framework/session_state.h` | Added thread-affine stream pool bucket state for `DeviceStreamCollection` reuse | | `onnxruntime/core/session/inference_session.h` | Added `RunImpl()` private method and `kMaxGraphCaptureWarmupRuns` constant | | `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | Added version-gated `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` bridge implementations | @@ -83,7 +110,8 @@ Legacy aliases `ep.cuda.enable_cuda_graph` and `enable_cuda_graph` are also supp - **Thread safety**: Mutable graph state and graph streams are stored per thread. ORT core's `DeviceStreamCollection` cache is also thread-affine, so graph-enabled runs can recycle stream wrappers without exposing them to a different thread. - **Scope**: Capture/replay pipeline plus allocator compatibility. Arena integration is complete — see the [Arena Allocator Integration](#arena-allocator-integration) section. - **Callback assignment**: `IsGraphCaptureEnabled` and `GetGraphCaptureNodeAssignmentPolicy` are always set. `OnRunStart`, `OnRunEnd` are conditional on `enable_cuda_graph`. `IsGraphCaptured` and `ReplayGraph` are always set (return false/error when disabled). -- **Stream management**: `CreateSyncStreamForDevice` remains unconditional — it branches internally to use the current thread's graph stream (via `InitHandlesWithExternalStream`) when graph capture is enabled, or creates an owned stream when disabled. +- **Stream management**: `CreateSyncStreamForDevice` remains unconditional — it branches internally: it wraps a user-provided `user_compute_stream` (via `InitHandlesWithUserStream`) when one is set, otherwise uses the current thread's graph stream (via `InitHandlesWithExternalStream`) when graph capture is enabled, or creates an owned stream when both are disabled. +- **User compute stream + CUDA graph**: These options can now be combined (previously rejected at factory creation). When both are set, `CudaEp::GetPerThreadContext()` builds the `PerThreadContext` around the user's stream (`external_graph_stream`) so capture and replay run on the same stream the kernels use, and the context never destroys the user-owned stream (`owns_graph_stream = false`). - **Run-end synchronization**: `OnRunEndImpl` honors the `sync_stream` flag without double-synchronizing replayed graphs, preserving the normal EP completion contract. - **Stream collection reuse**: ORT core now recycles `DeviceStreamCollection` objects into a thread-affine session pool keyed by a per-thread lifetime token. Warm-up, capture, replay, and later user-visible `Run()` calls on the same thread can reuse the same stream wrappers, while dead-thread buckets are pruned before they can be reused by another thread. - **Per-thread context lifecycle**: Thread-local caches hold the strong `PerThreadContext` references, so CUDA streams and captured graph executables are released when the owning thread exits. The EP tracks weak references to those cache maps to remove stale entries during EP destruction without keeping the contexts alive. @@ -101,6 +129,7 @@ CUDA graph capture requires that all memory allocations happen during warmup, no **Arena integration details (now implemented):** - Default CUDA device allocations come from the plugin-hosted arena (`CudaArenaAllocator`). During warmup runs, the arena grows to accommodate all needed chunks; during capture and replay, the same chunks are reused without `cudaMalloc` calls. +- Kernel scratch/workspace allocations (`CudaKernel::GetScratchBuffer`) also flow through the EP arena via `Info().GetAllocator()`, rather than issuing a fresh `cudaMallocAsync`/`cudaMalloc` per call. After warmup the arena has reached its steady-state working set, so the capture run serves every scratch request from an already-reserved chunk and the device free-memory footprint stays stable across the capture window. This is what makes the `cudaMemGetInfo` allocation-during-capture detector pass for graphs that use scratch buffers, and it matches the bundled CUDA EP (which also obtains scratch from `Info().GetAllocator()`). `GetScratchBuffer` stream-tags scratch chunks with the framework `OrtSyncStream*` exposed by `KernelContext_GetSyncStream`. The raw `cudaStream_t` from `KernelContext_GetGPUComputeStream` is still used for CUDA launches and library calls; the framework stream is used only for the arena's cross-stream reuse bookkeeping. - When `arena.use_cuda_mempool=1` is configured, CUDA device allocations come from `CudaMempoolOrtAllocator`, which wraps `cudaMallocFromPoolAsync`/`cudaFreeAsync`. These async allocation/free operations are CUDA-graph-safe since CUDA 11.4+ and become part of the captured graph topology. - Pinned allocations are also arena-backed, but remain non-stream-aware. - The graph stream created by `CudaEp::PerThreadContext` flows through `CudaSyncStream::InitHandlesWithExternalStream()` so stream-aware arena allocation uses the same `cudaStream_t` during warm-up, capture, and replay. @@ -109,14 +138,12 @@ CUDA graph capture requires that all memory allocations happen during warmup, no ### Concurrent Run Support -Concurrent `Session::Run()` is supported with CUDA graph enabled: +Concurrent `Session::Run()` is advertised by the CUDA plugin EP when the host ORT runtime exposes `KernelContext_GetSyncStream` and the session is not forced into EP-level unified-stream mode. -- `CudaEp::PerThreadContext` owns the graph stream, graph manager, warm-up run counts, and memory watermark for the current thread. -- The current thread's cache owns the `PerThreadContext`; new threads get independent contexts, and exited threads release their contexts automatically. -- `CreateSyncStreamForDeviceImpl()` wraps the current thread's graph stream, so warm-up, capture, and replay all use the same stream for that thread. -- `CudaGraphManager::CaptureBegin()` uses `cudaStreamCaptureModeThreadLocal`, allowing overlapping capture scopes on different threads. -- ORT core recycles graph-enabled `DeviceStreamCollection` objects into a thread-affine session pool, so internal warm-up/capture recursion and later top-level `Run()` calls on the same thread reuse the same stream wrappers without cross-thread leakage. -- `IsGraphCaptured()` and `ReplayGraph()` resolve the current thread's graph context. If a new thread runs a graph-enabled session for the first time, that thread performs its own warm-up and capture before replaying. +- `CudaEp::PerThreadContext` still owns graph stream, graph manager, warm-up run counts, and memory watermark state per thread. This keeps graph bookkeeping thread-local and avoids sharing captured graph executables across threads. +- Plugin kernels now obtain the framework `OrtSyncStream*` through `KernelContext_GetSyncStream` and use it only for scratch/workspace allocation bookkeeping. CUDA work still launches on the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream`. +- Stream-tagged scratch chunks let the shared arena apply its normal cross-stream reuse rules for overlapping runs on different CUDA streams. +- When the negotiated ORT API version does not include `KernelContext_GetSyncStream`, `CudaKernel::GetScratchBuffer` falls back to a null stream tag and `CudaEp::IsConcurrentRunSupportedImpl()` returns false. ## Verification @@ -128,6 +155,7 @@ Concurrent `Session::Run()` is supported with CUDA graph enabled: - `test_cuda_graph_with_mempool` — graph capture with `arena.use_cuda_mempool=1` - `test_cuda_graph_annotation_id` — multiple graphs via `gpu_graph_id` run config - `test_cuda_graph_add_model` — graph capture with Add op (arena-backed) +4. `onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc` is a C++ test (gated by `ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP`) covering `user_compute_stream` combined with `enable_cuda_graph`: it verifies session creation succeeds with both options set (regression for the removed validation), capture + replay on the user stream produce correct results, and replay after an in-place input update on the user stream is correct. ## Future Work diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 2f61da90233b4..8f9a0d388d1af 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -97,10 +97,11 @@ Because the plugin binary may load into an older runtime, every `OrtApi`/`OrtEpA | API surface | Newest `\since` used | Representative functions | | --- | --- | --- | | `OrtApi` — direct calls (`ort_api_.*`, `Ort::GetApi().*`) | **1.23** | `SyncStream_GetHandle`, `GetTensorSizeInBytes`, `GetRunConfigEntry`, `CreateMemoryInfo_V2`, `Graph_GetNumNodes`/`Graph_GetNodes` (older: `CreateStatus`, `Logger_LogMessage`, `*KeyValuePairs`, `HardwareDevice_*`, `MemoryInfoGet*`, `GetSessionConfigEntry`) | +| `OrtApi` — optional gated kernel-context capability | **1.28** | `KernelContext_GetSyncStream` (called from the adapter only when `CurrentOrtApiVersion() >= 28`; otherwise scratch allocation uses a null stream tag and concurrent run support is not advertised) | | `OrtEpApi` — direct calls (`ep_api_.*`, `Ort::GetEpApi().*`) | **1.24** | `CreateKernelRegistry`, `KernelRegistry_AddKernel`, `ReleaseKernelRegistry`, `CreateIfKernel`/`CreateLoopKernel`/`CreateScanKernel`, `EpGraphSupportInfo_LookUpKernel` (older: `MemoryDevice_*`, `MemoryInfo_GetMemoryDevice`, `SyncStream_*`, `EpDevice_AddAllocatorInfo`, `EpGraphSupportInfo_AddSingleNode`, `CreateEpDevice`/`ReleaseEpDevice`) | | EP profiler API (only when built with `ENABLE_CUDA_PROFILING`) | **1.25** | `CreateProfilingEvent`, `ProfilingEventsContainer_AddEvents`, `ReleaseProfilingEvent` (called from `cuda_profiler_plugin.cc` via the `Ort::ProfilingEvent` / `Ort::UnownedProfilingEventsContainer` wrappers) | -`provider_api_shims.cc` uses only internal helpers (`GetEnvironmentVar`, `MLFloat16` conversions), and the plugin uses no Model Editor, Model Package, or Compile API. **Apart from the optional EP profiler, every API the plugin calls is `\since 1.24` or older**, so the true compatibility floor is `1.24.4`. +`provider_api_shims.cc` uses only internal helpers (`GetEnvironmentVar`, `MLFloat16` conversions), and the plugin uses no Model Editor, Model Package, or Compile API. **Apart from optional gated capabilities such as EP profiling and stream-tagged scratch allocation, every API the plugin calls is `\since 1.24` or older**, so the true compatibility floor is `1.24.4`. **Defensive capability gating.** Reading a struct field is safe because the field is append-only and ORT only reads fields it knows about. The real hazard is *calling* an `OrtApi`/`OrtEpApi` function that the (possibly older) runtime does not provide. The correct guard for that is the runtime API version, `onnxruntime::ep::CurrentOrtApiVersion()`, not `ort_version_supported`. The `CudaEp` constructor (`cuda_ep.cc`) therefore reads `const uint32_t ort_version = onnxruntime::ep::CurrentOrtApiVersion();` and only installs an `OrtEp` callback when that runtime version is new enough to provide both the callback field and every API its implementation calls: @@ -113,7 +114,9 @@ Because the plugin binary may load into an older runtime, every `OrtApi`/`OrtEpA All other `OrtEp` and `OrtEpFactory` callbacks are `\since 1.24` or older and are installed unconditionally. Gating `CreateProfiler` is what makes the three `\since 1.25` profiler functions unreachable on an older runtime: when the profiler is never created, ORT never drives the `OrtEpProfilerImpl` callbacks that call them. -The gates use **graceful degradation rather than throwing**: the gated callbacks are all optional capabilities (per-run sync, EP-level GPU profiling, CUDA-graph capture/replay, device-memory budgeting), so disabling them on an older runtime still yields a fully functional EP — inference runs, just without that specific feature. This was validated by loading the plugin (built against the latest headers) into both the latest runtime (full test suite passes) and an `onnxruntime==1.24.4` runtime (the EP registers, enumerates devices, and runs inference correctly with the newer callbacks left null). +`KernelContext_GetSyncStream` is guarded at the adapter call site rather than through an `OrtEp` callback field: `OpKernelContext::GetSyncStream()` returns null when `CurrentOrtApiVersion() < 28`, and `CudaEp::IsConcurrentRunSupportedImpl()` only advertises concurrent runs when that API is available. Older runtimes therefore keep the previous serialized-run behavior while still using the same plugin binary. + +The gates use **graceful degradation rather than throwing**: the gated callbacks and adapter capabilities are optional features (per-run sync, EP-level GPU profiling, CUDA-graph capture/replay, device-memory budgeting, stream-tagged scratch for concurrent runs), so disabling them on an older runtime still yields a fully functional EP — inference runs, just without that specific feature. This was validated by loading the plugin (built against the latest headers) into both the latest runtime (full test suite passes) and an `onnxruntime==1.24.4` runtime (the EP registers, enumerates devices, and runs inference correctly with the newer callbacks left null). --- @@ -459,9 +462,18 @@ The NHWC rollout is effectively in a "runtime enabled, cleanup remaining" state: | 2 | Cache the shim provider pointer in the adapter `OpKernelInfo` | Implemented; fixes the observed NHWC runtime crash | | 3 | Consolidate allowlists, improve internal-domain diagnostics, and strengthen structural NHWC assertions | Recommended follow-up work | +#### 5.3.2 Allocator Resolution for Kernels (Scratch and PrePack) + +Migrated kernels need a valid device allocator in two places: scratch/workspace buffers during `Compute()`, and one-time weight conversion or packing during `PrePack()`. Both now resolve the allocator the same way the bundled CUDA EP does, through the kernel's own `OpKernelInfo`. + +- **Scratch buffers.** `CudaKernel::GetScratchBuffer` allocates through `Info().GetAllocator(OrtMemTypeDefault)` (the EP arena) and stream-tags scratch chunks with the framework `OrtSyncStream*` from `KernelContext_GetSyncStream`, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call. The adapter `OpKernelInfo::GetAllocator` resolves the EP's default-memory (device) allocator and is always valid for a migrated kernel, so no plugin-only scratch path is needed. Routing through the arena is also what keeps the device free-memory footprint stable during CUDA graph capture (see [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#arena-allocator-integration)). CUDA launches still use the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream`; the framework stream is used only for stream-aware arena bookkeeping. +- **PrePack.** The framework prepack loop (`SessionState::PrepackConstantInitializedTensors`) resolves the allocator with `GetInitializerAllocator(kernel->Info().GetDevice(OrtMemTypeDefault))`, a session map keyed by device. For a plugin EP registered as a separate library, that device-keyed lookup can miss and return null. The loop now falls back to `kernel->Info().GetAllocator(OrtMemTypeDefault)` when the lookup is null, so every `PrePack` implementation receives a valid allocator at the single framework call site. This replaces the earlier approach of adding a per-kernel `if (!alloc) alloc = Info().GetAllocator(...)` guard to each prepacking op (which only covered the few ops that were touched and risked missing future ones). The fallback is behavior-neutral for in-tree EPs, whose device-keyed lookup already succeeds, and it does **not** force `is_packed`/`prepacked_weights` handling \u2014 ops such as `QMoE` and `MatMulNBits` still set `is_packed = true` and populate prepacked weights normally. + +The enabling adapter changes are in [`include/onnxruntime/ep/adapter/allocator.h`](../../include/onnxruntime/ep/adapter/allocator.h) and [`include/onnxruntime/ep/adapter/op_kernel.h`](../../include/onnxruntime/ep/adapter/op_kernel.h): `IAllocatorWrappingOrtAllocator` implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` when it is available (ORT >= 1.23), and `OpKernelContext::GetSyncStream()` exposes the framework stream when the negotiated ORT API version includes `KernelContext_GetSyncStream`. The CUDA plugin uses that framework stream for `GetScratchBuffer`; if it is unavailable, allocation falls back to a null stream tag and concurrent `Session::Run()` is not advertised. + ### 5.4 CUDA Graph Support -CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and concurrent `Session::Run()` support. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and concurrent run details — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. +CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and combining a caller-supplied `user_compute_stream` with capture/replay. Concurrent `Session::Run()` is supported when the host runtime exposes `KernelContext_GetSyncStream` and the session is not forced into EP-level unified-stream mode. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and user-stream mode — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. #### 5.4.1 OrtEp C API Extensions (v1.26) @@ -488,6 +500,10 @@ Session-level changes in `inference_session.cc`: - **Bounded recursion**: After each normal run when graph capture is enabled, the session recursively calls `RunImpl()` (bounded by `kMaxGraphCaptureWarmupRuns = 8`) until the graph is captured. From the user's perspective, a single `Run()` call handles the entire warm-up + capture sequence. - **Stream collection lifetime**: ORT core now caches `DeviceStreamCollection` objects in thread-affine session buckets keyed by a per-thread lifetime token. Graph-enabled runs recycle and reacquire stream wrappers only on the creating thread, which preserves warm-up/capture reuse without cross-thread leakage. +#### 5.4.3 User Compute Stream with CUDA Graph + +A caller-provided `user_compute_stream` may be combined with `enable_cuda_graph` (the factory previously rejected this pair). When both are set, `CudaEp::GetPerThreadContext()` builds the per-thread graph context around the user-owned stream rather than an EP-owned one, so capture and replay run on the same stream the kernels are issued to (matching the bundled CUDA EP). The context marks the stream as not owned and never destroys it. Details are in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#user-compute-stream--cuda-graph). + --- ## 6. EP Adapter Layer (`include/onnxruntime/ep/adapter/`) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 815fc6aa69a60..a6d8eaecad0c0 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1608,6 +1608,34 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return *prepacked_weights_for_graph_; } + // Tags a fusion-generated initializer (whose name is not stable across sessions) with a stable, + // content-derived identity that SessionState uses to key cross-session pre-pack sharing. + // + // Single-consumer invariant: a MatMulNBits packed buffer folds in the *consuming* node's + // scales/zero_points/attributes, not B alone, so this id is meaningful only for a B initializer that + // has exactly one consumer. The DQ->MatMulNBits producers guarantee that -- each generated B has a + // unique name with a single consumer, and the fusion bails when the source weight/scale is shared (the + // DQMatMulNotConvertedToMatMulNBits_SharedWeight case). If a future change ever tags a multi-consumer + // initializer whose consumers differ in scales/zp/attrs, they would compute different ids for the same + // name and the last writer would silently mis-share. Enforce that a name is never re-tagged with a + // conflicting id so the invariant survives later refactors. + void SetSharedPrepackInitializerId(const std::string& initializer_name, std::string share_id) { + auto it = generated_shared_prepack_ids_.find(initializer_name); + if (it != generated_shared_prepack_ids_.end()) { + ORT_ENFORCE(it->second == share_id, "MatMulNBits pre-pack sharing id for initializer '", + initializer_name, "' was re-tagged with a different id; the single-consumer invariant ", + "is violated (a multi-consumer weight whose consumers differ in scales/zp/attrs)."); + return; + } + generated_shared_prepack_ids_.emplace(initializer_name, std::move(share_id)); + } + + // Returns the sharing identity for a generated initializer, or nullptr if it was not tagged. + const std::string* GetSharedPrepackInitializerId(const std::string& initializer_name) const { + auto it = generated_shared_prepack_ids_.find(initializer_name); + return it == generated_shared_prepack_ids_.end() ? nullptr : &it->second; + } + /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */ const Node* ParentNode() const { return parent_node_; } @@ -2011,6 +2039,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // This is optional due to delayed construction. std::optional prepacked_weights_for_graph_; + // Maps a fusion-generated initializer name to its cross-session sharing identity. + // See SetSharedPrepackInitializerId. + InlinedHashMap generated_shared_prepack_ids_; + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // Runtime optimization storage. // Note: runtime_optimizations_ == *runtime_optimizations_ptr_ and must be initialized diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5dd53a8cf45c0..a73868527771a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7496,6 +7496,27 @@ struct OrtApi { * \since Version 1.28. */ ORT_API_T(OrtExperimentalFnPtr, GetExperimentalFunction, _In_ const char* name); + + /** \brief Get the framework synchronization stream associated with a kernel context. + * + * This returns the framework stream wrapper for the execution provider stream used by this kernel invocation. + * It is intended for APIs that need a stable framework stream object for stream-aware allocation and + * synchronization bookkeeping. Use KernelContext_GetGPUComputeStream when launching native GPU work. + * + * \param[in] context OrtKernelContext instance. + * \param[out] out Returns the framework synchronization stream, or nullptr if the kernel has no stream. + * Do not free or mutate the returned pointer. It is owned by the underlying session. + * The pointer may be stored and used for stream-aware allocation and synchronization + * bookkeeping beyond the Compute call (e.g. an allocator may persist it in arena + * chunks); it remains valid until the owning Session::Run() completes its teardown. + * Do not retain or dereference it after the run that produced this kernel context ends. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.28. + */ + ORT_API2_STATUS(KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, + _Outptr_result_maybenull_ OrtSyncStream** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 4798d3d4ad1b8..55a4e36167e86 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3020,6 +3020,7 @@ struct KernelContext { UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; void* GetGPUComputeStream() const; + OrtSyncStream* GetSyncStream() const; Logger GetLogger() const; Ort::Allocator GetAllocator(const OrtMemoryInfo& memory_info) const; OrtKernelContext* GetOrtKernelContext() const { return ctx_; } diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d7439e7b356c6..ed3abc0961be6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2876,6 +2876,12 @@ inline void* KernelContext::GetGPUComputeStream() const { return out; } +inline OrtSyncStream* KernelContext::GetSyncStream() const { + OrtSyncStream* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetSyncStream(ctx_, &out)); + return out; +} + inline Ort::Allocator KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const { OrtAllocator* out = nullptr; Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out)); diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h index e5b9bd4713a1c..ce76bd385cd85 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h @@ -46,6 +46,68 @@ ORT_RUNTIME_CLASS(ModelPackageOptions); ORT_RUNTIME_CLASS(ModelPackageContext); ORT_RUNTIME_CLASS(ModelPackageComponentContext); +// Opaque handle holding the EPContext callbacks and opaque state extracted from an OrtSessionOptions instance. Used by +// the experimental OrtEpApi_* EPContext data functions. Create via OrtEpApi_SessionOptions_GetEpContextConfig and +// release with OrtEpApi_ReleaseEpContextConfig. +ORT_RUNTIME_CLASS(EpContextConfig); + +/** \brief Function called to write named binary data. + * + * This callback is currently used for EPContext binary data, but its contract is intentionally generic so future APIs + * can reuse it for other named data payloads. The callback is called synchronously by the component that receives it. + * ORT does not own or retain buffer after the callback returns. ORT does not serialize invocations made by different + * EP instances or worker threads. + * + * Each callback invocation represents one complete write operation for name. The callback signature does not + * provide an offset, sequence number, or final-chunk marker, so the component invoking the callback must define any + * chunked ordering and completion contract with the application. Current EPContext use should prefer a single callback + * invocation per EPContext binary unless chunking semantics are documented by the EP. + * + * The application's implementation can process the data in any way (e.g., encrypt and store, upload to cloud storage, + * or compress) before persisting it. + * + * \param[in] state Opaque pointer holding the user's state. ORT does not own or manage this pointer. The application + * must keep it valid for the duration required by the API that accepted the callback and must provide + * any synchronization required if it can be used concurrently. + * \param[in] name The file name or logical data identifier as a null-terminated UTF-8 string. + * \param[in] buffer The buffer containing data to write. + * \param[in] buffer_num_bytes The size of the buffer in bytes. + * + * \return OrtStatus* Write status. Return nullptr on success. + * On failure, use CreateStatus to provide error info with an appropriate OrtErrorCode + * (e.g., ORT_FAIL); ORT propagates the returned code. ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtWriteNamedBufferFunc)(_In_ void* state, + _In_ const char* name, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes); + +/** \brief Function called to read named binary data. + * + * This callback is currently used for EPContext binary data, but its contract is intentionally generic so future APIs + * can reuse it for other named data payloads. The application reads, processes (e.g., decrypts, decompresses, + * downloads), and returns the requested data. ORT provides an allocator so the application can allocate the output + * buffer directly. The callback is called synchronously by the component that receives it. ORT does not serialize + * invocations made by different EP instances or worker threads. + * + * \param[in] state Opaque pointer holding the user's state. ORT does not own or manage this pointer. The application + * must keep it valid for the duration required by the API that accepted the callback and must provide + * any synchronization required if it can be used concurrently. + * \param[in] name The file name or logical data identifier to read as a null-terminated UTF-8 string. + * \param[in] allocator ORT-provided allocator. The application must use this to allocate the output buffer. + * \param[out] buffer Set by the implementation to the allocated buffer containing the output data. + * \param[out] data_size Set by the implementation to the size of the output data in bytes. + * + * \return OrtStatus* Read status. Return nullptr on success. + * On failure, use CreateStatus to provide error info with an appropriate OrtErrorCode + * (e.g., ORT_FAIL); ORT propagates the returned code. ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtReadNamedBufferFunc)(_In_ void* state, + _In_ const char* name, + _In_ OrtAllocator* allocator, + _Outptr_ void** buffer, + _Out_ size_t* data_size); + // // C function pointer typedefs and name constants // diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc index 57a4e472b6f6d..b1140485f7ff1 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc @@ -282,3 +282,117 @@ ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateSession, _In_ OrtModelPackageComponentContext* context, _In_opt_ const OrtSessionOptions* session_options, _Outptr_ OrtSession** session) + +/** \brief Registers a callback to provide EPContext binary data during session load. + * + * When loading a compiled model with external (non-embedded) EPContext binary data, an execution provider can + * retrieve this callback from OrtEpContextConfig and call it instead of reading the binary data from disk. + * + * Reading happens at session load, so this callback is configured on OrtSessionOptions. The corresponding write + * callback runs only at compile time and is configured on OrtModelCompilationOptions via + * OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc. + * + * The state pointer is stored as-is and is not owned by ORT. It must remain valid while any session or EP created + * from these options may call the callback. If the same state may be used by multiple EPs or threads, the application + * is responsible for synchronization. + * + * \param[in] options The OrtSessionOptions instance. + * \param[in] read_func The OrtReadNamedBufferFunc callback. Pass NULL to clear a previously set callback (any + * previously set state is cleared as well). + * \param[in] state Opaque state passed to read_func. Can be NULL. Ignored when read_func is NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_SessionOptions_SetEpContextDataReadFunc, + _Inout_ OrtSessionOptions* options, _In_opt_ OrtReadNamedBufferFunc read_func, _In_opt_ void* state) + +/** \brief Sets a callback for writing EPContext binary data during compilation. + * + * When EPContext embed mode is disabled, execution providers can retrieve this callback from OrtEpContextConfig and + * call it instead of writing EPContext binary data directly to disk. + * + * This callback may be used together with OrtCompileApi::ModelCompilationOptions_SetEpContextBinaryInformation. The + * binary information still describes the compiled model/output location that EPs may use to generate stable logical + * EPContext data names or as a file-fallback location. If this callback is configured, EPs should call it for + * EPContext binary data instead of writing that data to the fallback file path. + * + * Writing happens only at compile time, so this callback is configured on OrtModelCompilationOptions. The + * corresponding read callback runs at session load and is configured on OrtSessionOptions via + * OrtApi_SessionOptions_SetEpContextDataReadFunc. + * + * The state pointer is stored as-is and is not owned by ORT. It must remain valid for the duration of the compile + * operation that may call the callback. If the same state may be used by multiple EPs or threads, the application is + * responsible for synchronization. + * + * Like OrtApi_SessionOptions_SetEpContextDataReadFunc, passing a NULL write_func clears any previously set callback + * (any previously set state is cleared as well). Calling this multiple times overwrites the previously configured + * callback. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] write_func The OrtWriteNamedBufferFunc callback used to write EPContext bytes. Pass NULL to clear a + * previously set callback (any previously set state is cleared as well). + * \param[in] state Opaque state passed to write_func. Can be NULL. Ignored when write_func is NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_opt_ OrtWriteNamedBufferFunc write_func, _In_opt_ void* state) + +/** \brief Extracts the EPContext configuration (callbacks and state) from an OrtSessionOptions instance. + * + * The EP should call this during CreateEp() while session_options is still valid, and store the returned handle for + * use during Compile(). On success, `*config` is set to a non-NULL handle that must be released with + * OrtEpApi_ReleaseEpContextConfig. On failure, an error status is returned and `*config` is not modified. + * + * The returned handle owns only ORT's copy of callback function pointers and opaque state pointer values. It does not + * own the application-provided state. The application is responsible for keeping callback state valid and + * synchronized while an EP may call callbacks retrieved from this config. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[out] config The extracted OrtEpContextConfig. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtEpApi_SessionOptions_GetEpContextConfig, + _In_ const OrtSessionOptions* session_options, _Outptr_ OrtEpContextConfig** config) + +/** \brief Release an OrtEpContextConfig instance. + * + * \param[in] config The OrtEpContextConfig instance to release. May be NULL. + */ +ORT_EXPERIMENTAL_API(28, void, OrtEpApi_ReleaseEpContextConfig, _Frees_ptr_opt_ OrtEpContextConfig* config) + +/** \brief Get the application-provided EPContext data read callback. + * + * Returns the OrtReadNamedBufferFunc and opaque state pointer registered via + * OrtApi_SessionOptions_SetEpContextDataReadFunc. If no callback was registered, *read_func and *state are set to + * NULL. The EP is responsible for calling the callback when present and for using its own normal read path when no + * callback is present. + * + * \param[in] config The OrtEpContextConfig from OrtEpApi_SessionOptions_GetEpContextConfig. + * \param[out] read_func The registered read callback, or NULL if none was registered. + * \param[out] state Opaque state pointer passed to read_func, or NULL if none was registered. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtEpApi_EpContextConfig_GetEpContextDataReadFunc, + _In_ const OrtEpContextConfig* config, _Out_ OrtReadNamedBufferFunc* read_func, + _Out_ void** state) + +/** \brief Get the application-provided EPContext data write callback. + * + * Returns the OrtWriteNamedBufferFunc and opaque state pointer registered via + * OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc. If no callback was registered, *write_func and + * *state are set to NULL. The EP is responsible for calling the callback when present and for using its own normal + * write path when no callback is present. + * + * \param[in] config The OrtEpContextConfig from OrtEpApi_SessionOptions_GetEpContextConfig. + * \param[out] write_func The registered write callback, or NULL if none was registered. + * \param[out] state Opaque state pointer passed to write_func, or NULL if none was registered. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc, + _In_ const OrtEpContextConfig* config, _Out_ OrtWriteNamedBufferFunc* write_func, + _Out_ void** state) diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h index fbd7ba3659435..d3f3a37eca33d 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h @@ -108,5 +108,79 @@ namespace Experimental { // C++ wrapper types or helpers go here in the `Ort::Experimental` namespace. // +// Move-only RAII owner for an OrtEpContextConfig handle, which carries the EPContext read/write callbacks and opaque +// state extracted from an OrtSessionOptions instance. The handle is released via OrtEpApi_ReleaseEpContextConfig when +// the wrapper is destroyed. +// +// Typical EP usage: construct from the session options during CreateEp(), keep the wrapper for the EP's lifetime, and +// query the callbacks via GetReadFunc() / GetWriteFunc(). +class EpContextConfig { + public: + explicit EpContextConfig(std::nullptr_t) noexcept {} + + explicit EpContextConfig(const SessionOptions& session_options) : EpContextConfig{session_options.GetConst()} {} + + // Extracts the EPContext config from `session_options`. Throws Ort::Exception (ORT_NOT_IMPLEMENTED) if the + // experimental functions are not available in this build, or propagates any error from the extraction. + explicit EpContextConfig(ConstSessionOptions session_options) { + const OrtApi* api = &GetApi(); + // Ensure the release function is available before creating a handle, so the handle can always be freed. + Get_OrtEpApi_ReleaseEpContextConfig_SinceV28_FnOrThrow(api); + auto* get_config = Get_OrtEpApi_SessionOptions_GetEpContextConfig_SinceV28_FnOrThrow(api); + ThrowOnError(get_config(static_cast(session_options), &config_)); + } + + EpContextConfig(EpContextConfig&& other) noexcept : config_{other.config_} { other.config_ = nullptr; } + + EpContextConfig& operator=(EpContextConfig&& other) noexcept { + if (this != &other) { + reset(); + config_ = other.config_; + other.config_ = nullptr; + } + return *this; + } + + EpContextConfig(const EpContextConfig&) = delete; + EpContextConfig& operator=(const EpContextConfig&) = delete; + + ~EpContextConfig() { reset(); } + + OrtEpContextConfig* get() const noexcept { return config_; } + explicit operator bool() const noexcept { return config_ != nullptr; } + + // Relinquishes ownership of the handle without releasing it. + OrtEpContextConfig* release() noexcept { + OrtEpContextConfig* released = config_; + config_ = nullptr; + return released; + } + + // Releases any owned handle and resets to empty. + void reset() noexcept { + if (config_ != nullptr) { + if (auto* release_fn = Get_OrtEpApi_ReleaseEpContextConfig_SinceV28_Fn(&GetApi())) { + release_fn(config_); + } + config_ = nullptr; + } + } + + // Returns the configured read callback and opaque state (both nullptr if none was set). Throws on failure. + void GetReadFunc(OrtReadNamedBufferFunc& read_func, void*& state) const { + auto* get_read_func = Get_OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28_FnOrThrow(&GetApi()); + ThrowOnError(get_read_func(config_, &read_func, &state)); + } + + // Returns the configured write callback and opaque state (both nullptr if none was set). Throws on failure. + void GetWriteFunc(OrtWriteNamedBufferFunc& write_func, void*& state) const { + auto* get_write_func = Get_OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28_FnOrThrow(&GetApi()); + ThrowOnError(get_write_func(config_, &write_func, &state)); + } + + private: + OrtEpContextConfig* config_ = nullptr; +}; + } // namespace Experimental } // namespace Ort diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index 0db30f39b3f57..1fb78f81fce19 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -41,21 +41,19 @@ class IAllocatorWrappingOrtAllocator final : public IAllocator { } bool IsStreamAware() const override { - return false; - - // TODO: Enable once AllocOnStream() is implemented. - // static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; - // const OrtAllocator* raw = ort_allocator_; - // return raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr; + static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; + const OrtAllocator* raw = ort_allocator_; + return raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr; } - void* AllocOnStream(size_t /*size*/, Stream* /*stream*/) override { - // TODO: Implement AllocOnStream(). - // The internal `onnxruntime::IAllocator::AllocOnStream` signature takes an internal `onnxruntime::Stream*` - // argument, while the public `::OrtAllocator::AllocOnStream` signature takes an `::OrtSyncStream*` argument. - // We need to properly map from one to the other. - // `::OrtSyncStream*` should be treated as an opaque type from the plugin EP's perspective. - ORT_NOT_IMPLEMENTED("IAllocatorWrappingOrtAllocator::AllocOnStream is not implemented yet."); + void* AllocOnStream(size_t size, Stream* stream) override { + static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; + OrtAllocator* raw = ort_allocator_; + if (raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr) { + return raw->AllocOnStream(raw, size, reinterpret_cast(stream)); + } + + return raw->Alloc(raw, size); } void GetStats(AllocatorStats* stats) override { diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 27a46cc10e306..1f103b64a443e 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -164,6 +164,14 @@ struct OpKernelContext { void* GetGPUComputeStream() const { return context_.GetGPUComputeStream(); } + OrtSyncStream* GetSyncStream() const { + static constexpr uint32_t kOrtKernelContextGetSyncStreamMinVersion = 28; + if (CurrentOrtApiVersion() < kOrtKernelContextGetSyncStreamMinVersion) { + return nullptr; + } + + return context_.GetSyncStream(); + } private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext); diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index df14bc8c57f24..0a06156fe78d8 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -150,25 +150,57 @@ def _extract_cuda_major_version(version_str: str) -> str: return version_str.split(".", maxsplit=1)[0] if version_str else "12" -def _get_cufft_version(cuda_major: str) -> str: +def _get_cufft_version(cuda_major_version: str) -> str: """Get cufft library version based on CUDA major version. Args: - cuda_major: CUDA major version as string (e.g., "12", "13") + cuda_major_version: CUDA major version as string (e.g., "12", "13") Returns: cufft version as string """ # cufft versions: CUDA 12.x -> 11, CUDA 13.x -> 12 - return "12" if cuda_major == "13" else "11" + return "12" if int(cuda_major_version) >= 13 else "11" def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = True): - # Dynamically determine CUDA major version from build info + # Dynamically determine CUDA major version from build info. + # build_cuda_version defaults to the version this package was built with; it is a parameter for testability. cuda_major_version = _extract_cuda_major_version(cuda_version) cufft_version = _get_cufft_version(cuda_major_version) - if is_windows: + # Starting with CUDA 13, NVIDIA consolidated the per-component CUDA Toolkit wheels + # (cublas, cufft, cuda_runtime, cuda_nvrtc, curand, ...) into a single "nvidia/cu{major}" + # package and dropped the "-cuNN" suffix from those package names. On Windows the DLLs + # moved into an architecture sub-folder ("bin/", e.g. "bin/x86_64"); on Linux the + # libraries are placed directly in "lib" (the wheel itself is architecture specific, so + # there is no arch sub-folder). cuDNN keeps its own "nvidia/cudnn" package and layout. + use_consolidated_layout = cuda_major_version.isdigit() and int(cuda_major_version) >= 13 + + if use_consolidated_layout: + cuda_dir = f"cu{cuda_major_version}" + if is_windows: + import platform # noqa: PLC0415 + + arch = "arm64" if platform.machine().lower() in ("arm64", "aarch64") else "x86_64" + cuda_dll_paths = [ + ("nvidia", cuda_dir, "bin", arch, f"cublasLt64_{cuda_major_version}.dll"), + ("nvidia", cuda_dir, "bin", arch, f"cublas64_{cuda_major_version}.dll"), + ("nvidia", cuda_dir, "bin", arch, f"cufft64_{cufft_version}.dll"), + ("nvidia", cuda_dir, "bin", arch, f"cudart64_{cuda_major_version}.dll"), + ] + else: # Linux + # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first. + cuda_dll_paths = [ + ("nvidia", cuda_dir, "lib", f"libcublasLt.so.{cuda_major_version}"), + ("nvidia", cuda_dir, "lib", f"libcublas.so.{cuda_major_version}"), + ("nvidia", cuda_dir, "lib", f"libnvrtc.so.{cuda_major_version}"), + ("nvidia", cuda_dir, "lib", "libcurand.so.10"), + ("nvidia", cuda_dir, "lib", f"libcufft.so.{cufft_version}"), + ("nvidia", cuda_dir, "lib", f"libcudart.so.{cuda_major_version}"), + ] + elif is_windows: + # CUDA 12 and earlier: each component ships its own "nvidia/" package. # Path is relative to site-packages directory. cuda_dll_paths = [ ("nvidia", "cublas", "bin", f"cublasLt64_{cuda_major_version}.dll"), @@ -176,16 +208,6 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru ("nvidia", "cufft", "bin", f"cufft64_{cufft_version}.dll"), ("nvidia", "cuda_runtime", "bin", f"cudart64_{cuda_major_version}.dll"), ] - cudnn_dll_paths = [ - ("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_engines_precompiled64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_heuristic64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_ops64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_adv64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_graph64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_engines_tensor_ir64_9.dll"), - ] else: # Linux # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first. cuda_dll_paths = [ @@ -197,6 +219,19 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru ("nvidia", "cuda_runtime", "lib", f"libcudart.so.{cuda_major_version}"), ] + # cuDNN keeps its own "nvidia/cudnn" package layout in both old and consolidated schemes. + if is_windows: + cudnn_dll_paths = [ + ("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_engines_precompiled64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_heuristic64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_ops64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_adv64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_graph64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_engines_tensor_ir64_9.dll"), + ] + else: # Linux # Do not load cudnn sub DLLs (they will be dynamically loaded later) to be consistent with PyTorch in Linux. cudnn_dll_paths = [ ("nvidia", "cudnn", "lib", "libcudnn.so.9"), @@ -238,15 +273,19 @@ def print_debug_info(): # Print version of installed packages that is related to CUDA or cuDNN DLLs. cuda_major = _extract_cuda_major_version(cuda_version) + # Starting with CUDA 13, NVIDIA dropped the "-cuNN" suffix from the per-component + # CUDA Toolkit packages (cuDNN keeps its suffixed package name). + cuda_pkg_suffix = "" if (cuda_major.isdigit() and int(cuda_major) >= 13) else f"-cu{cuda_major}" + packages = [ "torch", - f"nvidia-cuda-runtime-cu{cuda_major}", + f"nvidia-cuda-runtime{cuda_pkg_suffix}", f"nvidia-cudnn-cu{cuda_major}", - f"nvidia-cublas-cu{cuda_major}", - f"nvidia-cufft-cu{cuda_major}", - f"nvidia-curand-cu{cuda_major}", - f"nvidia-cuda-nvrtc-cu{cuda_major}", - f"nvidia-nvjitlink-cu{cuda_major}", + f"nvidia-cublas{cuda_pkg_suffix}", + f"nvidia-cufft{cuda_pkg_suffix}", + f"nvidia-curand{cuda_pkg_suffix}", + f"nvidia-cuda-nvrtc{cuda_pkg_suffix}", + f"nvidia-nvjitlink{cuda_pkg_suffix}", ] for package in packages: directory_name = "nvidia" if package.startswith("nvidia-") else None diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index ac3531e39eb53..e46075a86f811 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -91,6 +91,12 @@ constexpr bool LAYOUT_BNSH = true; namespace sparse_attention { // Environment variable to enable or disable sparse attention v1 kernel. Default is 0 (enabled). constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_V1"; + +// Environment variable to disable device-side validation of CSR indices and key sequence lengths. +// Default is 0 (validation enabled). Set to 1 to skip the validation kernel launch and stream +// synchronization, which may improve latency when inputs are known to be well-formed. +// Usage: export ORT_DISABLE_SPARSE_ATTENTION_INPUT_VALIDATION=1 +constexpr const char* kDisableInputValidation = "ORT_DISABLE_SPARSE_ATTENTION_INPUT_VALIDATION"; } // namespace sparse_attention namespace attention { diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 59313cf527c91..60fa4f0c4ada1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -909,6 +909,327 @@ class GQAAttentionBase { return Status::OK(); } + // Non-quantized flash attention path. Only supports T = float. + // Concatenates new K/V into the FP32 present cache, then runs the tiled + // online-softmax kernel MlasFlashAttentionGQA (QK^T + softmax + S*V fused). + Status ApplyAttentionFlash( + const float* Q, // Q data [B, N, S, H] BNSH + const float* K, // K data [B, N_kv, L, H] or nullptr for packed_qkv + const float* V, // V data [B, N_kv, L, H] or nullptr for packed_qkv + const Tensor* attention_bias, // additive bias [B|1, N|1, S, T] or nullptr + const Tensor* past_key, // past K (float) + const Tensor* past_value, // past V (float) + Tensor* output, // output [B, S, N*H] float + Tensor* present_key, // present K (float) + Tensor* present_value, // present V (float) + const Tensor* seqlens_k, + GroupQueryAttentionParameters& parameters, + AllocatorPtr allocator, + OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.total_sequence_length; + + if (kv_sequence_length == 0) { + ORT_ENFORCE(parameters.total_sequence_length <= seqlen_past_kv_cache, + "total_seqlen (", parameters.total_sequence_length, ") exceeds past buffer size (", + seqlen_past_kv_cache, ") in shared KV mode"); + } + + ORT_RETURN_IF(present_key == nullptr || present_value == nullptr, + "present_key and present_value must be provided for flash attention"); + + const float* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + float* present_key_data = present_key->MutableData(); + const float* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + float* present_value_data = present_value->MutableData(); + + bool past_present_share_buffer = (past_key_data == present_key_data) && + (past_value_data == present_value_data); + + const int32_t* seqlens_k_data = seqlens_k->Data(); + + // Attention bias setup + const float* attention_bias_data = nullptr; + int attention_bias_seqlen_stride = 0; + bool attention_bias_broadcast_batch = true; + bool attention_bias_broadcast_head = true; + if (attention_bias != nullptr) { + attention_bias_data = attention_bias->Data(); + auto bias_shape = attention_bias->Shape().GetDims(); + attention_bias_seqlen_stride = static_cast(bias_shape[3]); + attention_bias_broadcast_batch = (bias_shape[0] == 1); + attention_bias_broadcast_head = (bias_shape[1] == 1); + } + + // K/V base pointers (FP32, new tokens) + const float* k_base = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + const float* v_base = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const size_t kv_input_chunk_length = SafeInt(kv_sequence_length) * head_size; + const size_t past_buff_chunk_length = SafeInt(seqlen_past_kv_cache) * head_size; + const size_t present_buff_chunk_length = SafeInt(seqlen_present_kv_cache) * head_size; + + // ---- Phase 1: Concat new K/V into present cache ---- + // We must do this first so the flash attention kernel can read the full present cache. + if (present_key_data && !past_present_share_buffer) { + memset(present_key_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + memset(present_value_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + } + + // Concat K and V caches (parallelize over batch * kv_num_heads) + { + const size_t concat_loop_len = batch_size * kv_num_heads_; + TensorOpCost concat_cost; + concat_cost.compute_cycles = static_cast(kv_sequence_length * head_size); + concat_cost.bytes_loaded = static_cast((past_buff_chunk_length + kv_input_chunk_length) * sizeof(float)); + concat_cost.bytes_stored = static_cast(present_buff_chunk_length * sizeof(float)); + + ThreadPool::TryParallelFor(tp, concat_loop_len, concat_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t kv_idx = begin; kv_idx != end; ++kv_idx) { + const size_t batch_index = kv_idx / kv_num_heads_; + const size_t kv_head_index = kv_idx % kv_num_heads_; + const size_t total_seqlen = SafeInt(seqlens_k_data[batch_index]) + 1; + + size_t past_seqlen; + if (past_key == nullptr) { + past_seqlen = 0; + } else if (kv_sequence_length == 0) { + past_seqlen = total_seqlen; + } else if (is_prompt) { + past_seqlen = 0; + } else { + past_seqlen = total_seqlen - sequence_length; + } + const size_t past_chunk_length = past_seqlen * head_size; + + // Concat K + const float* k_new; + if (packed_qkv) { + k_new = k_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + k_new = k_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_key_data, k_new, present_key_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + + // Concat V + const float* v_new; + if (packed_qkv) { + v_new = v_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + v_new = v_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_value_data, v_new, present_value_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + } + }); + } + + // ---- Phase 2: Flash Attention with FP32 KV cache ---- + // Compute L2-aware block sizes (same formula as MHA flash attention). + const auto& env = Env::Default(); + int l2_cache_size = env.GetL2CacheSize(); + + int kv_block_size = l2_cache_size / (static_cast(sizeof(float)) * 4 * (head_size + head_size)); + kv_block_size = std::max(kv_block_size, 1); + int q_block_size = std::min(kv_block_size, 2 * head_size); + + // The flash kernel uses a single (past_seqlen, total_seqlen) pair for all batch items. + // When batch items have different seqlens_k (ragged), fall back to per-batch invocation + // so each batch item gets its own correct causal offset. + int max_total_seqlen = 0; + int min_total_seqlen = std::numeric_limits::max(); + int common_past_seqlen = 0; + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + max_total_seqlen = std::max(max_total_seqlen, total_sl); + min_total_seqlen = std::min(min_total_seqlen, total_sl); + } + const bool ragged_seqlens = (max_total_seqlen != min_total_seqlen); + + if (ragged_seqlens) { + common_past_seqlen = -1; // sentinel: per-batch + } else if (past_key == nullptr || is_prompt) { + common_past_seqlen = 0; + } else if (kv_sequence_length == 0) { + // Shared buffer mode: each batch item has its own past_seqlen. + common_past_seqlen = -1; // sentinel: per-batch + } else { + common_past_seqlen = max_total_seqlen - sequence_length; + } + + // Cap block sizes + kv_block_size = std::min(kv_block_size, max_total_seqlen); + q_block_size = std::min(q_block_size, sequence_length); + + int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + thread_count = std::max(thread_count, 1); + + // Flash decoding: for decode (sequence_length==1), partition KV across threads + // to improve parallelism when batch*heads < thread_count. This KV-split is only + // wired into the unified kernel (common_past_seqlen >= 0); the ragged/per-batch + // fallback runs the single-pass decode kernel instead, which needs a larger + // per-thread scratch (scores[total_seqlen] + temp_output[head_size]). Gating on + // common_past_seqlen >= 0 keeps the per-thread buffer sizing below consistent + // with the kernel that actually runs. + const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; + const bool use_flash_decoding = (sequence_length == 1 && + common_past_seqlen >= 0 && + batch_size * num_heads_ < thread_count && + kv_chunk_count > 1); + + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + // Flash decoding: per-thread scratch only needs scores[kv_block_size] + buffer_size_per_thread = SafeInt(kv_block_size) * sizeof(float); + // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats + partials_buffer_bytes = SafeInt(batch_size) * num_heads_ * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else if (sequence_length == 1) { + // Decode (GEMV kernel, no Q/KV tiling): per-thread scratch holds the full + // score row scores[total_seqlen] plus a temp output accumulator[head_size]. + buffer_size_per_thread = + (SafeInt(max_total_seqlen) + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output + sizeof(float); + } + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count + partials_buffer_bytes; + auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); + BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + + // Partials buffer is placed after per-thread scratch + float* partials_ptr = use_flash_decoding + ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + + buffer_size_per_thread * thread_count) + : nullptr; + + const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + // If all batch items share the same past_seqlen, use the unified flash kernel. + // Otherwise, fall back to per-batch invocation. + if (common_past_seqlen >= 0) { + MlasFlashAttentionGQAArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = max_total_seqlen; + args.head_size = head_size; + args.past_seqlen = common_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.q_batch_stride = packed_qkv + ? static_cast(packed_batch_stride) + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.k_cache = present_key_data; + args.v_cache = present_value_data; + args.output = output->MutableData(); + args.attention_bias = attention_bias_data; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; + + MlasFlashAttentionGQA(&args, tp); + } else { + // Per-batch handling for variable past_seqlen (shared KV buffer mode or ragged seqlens) + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + int batch_past_seqlen = (past_key == nullptr || is_prompt) + ? 0 + : std::max(0, total_sl - sequence_length); + + MlasFlashAttentionGQAArgs args; + args.batch_size = 1; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = total_sl; + args.head_size = head_size; + args.past_seqlen = batch_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = std::min(kv_block_size, total_sl); + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + + // Offset Q and output for this batch + const ptrdiff_t q_batch_stride_elems = packed_batch_stride > 0 + ? packed_batch_stride + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.query = Q + static_cast(b) * static_cast(q_batch_stride_elems); + args.q_batch_stride = SafeInt(num_heads_) * sequence_length * head_size; + args.k_cache = present_key_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.v_cache = present_value_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.output = output->MutableData() + + static_cast(b) * sequence_length * hidden_size; + + // Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside). + // Bias shape is [batch|1, num_heads|1, S, T]; the batch stride uses the actual head + // extent (1 when the head dim is broadcast). + const float* batch_bias = attention_bias_data; + if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) { + const size_t bias_head_extent = attention_bias_broadcast_head ? 1 : static_cast(num_heads_); + batch_bias += static_cast(b) * bias_head_extent * sequence_length * attention_bias_seqlen_stride; + } + args.attention_bias = batch_bias; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = true; // batch offset handled above + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding + args.kv_chunk_count = 0; + + MlasFlashAttentionGQA(&args, tp); + } + } + + return Status::OK(); + } + private: // Helper function to compute the attention probs. It does 2 things: // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index e36bdb2de263a..debda282eb4f1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -343,6 +343,31 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V const T* k_data = packed_qkv ? nullptr : k_rotary; const T* v_data = packed_qkv ? nullptr : V.Get().Data(); + + // Non-quantized flash attention path (float only). Uses the tiled online-softmax + // kernel to avoid materializing the full attention score matrix. Falls back to the + // naive path when an unsupported feature is requested (softcap, smooth softmax, + // head sink, or QK output). + // + // Prefill (sequence_length > 1) uses the tiled kernel; single-token decode + // (sequence_length == 1 with total_sequence_length > 1) uses the dedicated GEMV + // decode kernel. Both are reached when total_sequence_length > 1. + if constexpr (std::is_same_v) { + const bool use_flash = !disable_gqa_flash_ && + parameters.total_sequence_length > 1 && + softcap_ == 0.0f && + !use_smooth_softmax_ && + head_sink_data == nullptr && + output_qk == nullptr && + present_k != nullptr && present_v != nullptr; + if (use_flash) { + return ApplyAttentionFlash(q_rotary, k_data, v_data, + attention_bias, past_key, past_value, + output, present_k, present_v, seqlens_k, + parameters, allocator, context); + } + } + return ApplyAttention(q_rotary, k_data, v_data, head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, parameters, allocator, context); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 6bd1690fca815..162d7257d0a4c 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" #include +#include #include #include @@ -162,6 +163,13 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + // True once PrePack(InputIndex::B) has folded the scales and (constant) zero points into packed_b_, + // leaving the CompInt8 buffer fully packed and compute-ready. Pre-packed weight sharing + // content-hashes the buffer right after the B PrePack returns, so everything that affects the + // packed bytes (in particular the block sum / BZpCorr, which depend on the zero points) must be + // folded in by then. Once set, the later scales/zero_point PrePack calls must not pack again: the + // CompInt8 packing is single-shot, and the buffer may by then be one shared from another session. + bool packed_b_finalized_{false}; IAllocatorUniquePtr scales_fp32_{}; IAllocatorUniquePtr bias_fp32_{}; @@ -227,7 +235,6 @@ template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { - ORT_UNUSED_PARAMETER(prepacked_weights); is_packed = false; if (has_g_idx_) { return Status::OK(); @@ -308,10 +315,12 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All static_cast(packed_b_.get()), threadpool_ptr); - if (prepacked_weights != nullptr) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } + // Do not append packed_b_ here. Both the LUT and non-LUT branches share the single append + // after this if/else, so each records exactly one buffer. Appending here as well would move + // packed_b_ out now and then have the shared append record a second, moved-from/null buffer + // with a non-zero packed_b_size_. PrePackedWeights::GetHash() skips null buffers so sharing + // appears to work, but the prepacked-blob save path writes buffer_sizes_[i] bytes from + // buffers_[i].get() and would dereference that null pointer. } else { // For HQNBIT_CompInt8, route through SQNBIT_CompInt8 for sizing and packing. // This gets KleidiAI-sized buffer when available for 4-bit and packs B+scales correctly. @@ -341,24 +350,64 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + // The framework content-hashes this packed buffer to deduplicate pre-packed weights, both + // within a session and across sessions (the shared container). The session-state prepack pass + // (SessionState::PrepackConstantInitializedTensors) passes a non-null prepacked_weights on both + // the container and the default single-session paths, so this zero-fill runs on essentially + // every prepack at load, not only when a sharing container is configured -- the guard below + // only skips a caller that asks for no cacheable buffer. The pack routines need not write every + // byte (alignment padding between the CompInt8 sub-regions; any layout could gain padding) and + // the reserve allocation is not zero-filled, so the hash would otherwise depend on uninitialized + // bytes. Zeroing the whole buffer is a one-time O(packed_b_size_) load cost (the pack overwrites + // the data regions, leaving only padding zeroed); inference is unaffected. + if (prepacked_weights != nullptr) { + std::memset(packed_b_.get(), 0, packed_b_size_); + } MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, effective_compute_type, qptr, packed_b_.get(), scale_ptr, has_zp_input_, nullptr, threadpool_ptr, &mlas_backend_kernel_selector_config_); -#if defined(MLAS_TARGET_ARM64) - // For KleidiAI asymmetric 4-bit path: compute BZpCorr now while scales and zero_points are accessible. - if (compute_type_ == HQNBIT_CompInt8 && nbits_ == 4 && has_zp_input_ && scales_fp32_ && - MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, SQNBIT_CompInt8, has_zp_input_, &mlas_backend_kernel_selector_config_)) { - const Tensor* zp_tensor = nullptr; - OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); - if (zp_tensor != nullptr) { - auto zptr = zp_tensor->Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, SQNBIT_CompInt8, nullptr, packed_b_.get(), - scales_fp32_.get(), has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + // Fold the scales and (constant) zero points into packed_b_ now, during the B PrePack, instead + // of deferring them to the later scales/zero_points PrePack calls. Pre-packed weight sharing + // content-hashes this buffer immediately after the B PrePack returns; the CompInt8 block sum + // (and the KleidiAI BZpCorr) is a function of the zero points, so they must already be folded + // in for the hash to reflect them. Otherwise two initializers with identical B and scales but + // different zero points would hash equal and the second would wrongly adopt the first's buffer + // and silently compute wrong results. scales and zero_points are constant initializers, so they + // are available here. The B pack above only partially populates the buffer (on x64 the block sum + // is deferred; on ARM64 8-bit the scales are ignored during B packing), so issue one more pack + // call with QuantBData == nullptr to finalize it. This is byte-identical to the staged + // scales + zero_points packing it replaces. + bool finalize_scale_zp_into_packed_b = effective_compute_type == SQNBIT_CompInt8 && scale_ptr != nullptr; +#if !defined(MLAS_TARGET_AMD64_IX86) + // On ARM64 the scales/zero points are folded into B only for 8-bit, or for 4-bit when MLAS bakes + // them in (KleidiAI). For 4-bit non-KleidiAI they are applied at compute time and must not be + // passed to the packing routine, which would dereference the null QuantBData buffer. + finalize_scale_zp_into_packed_b = + finalize_scale_zp_into_packed_b && + (nbits_ == 8 || MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, effective_compute_type, + has_zp_input_, &mlas_backend_kernel_selector_config_)); +#endif + if (finalize_scale_zp_into_packed_b) { + const uint8_t* zp_ptr = nullptr; + if (has_zp_input_) { + const Tensor* zp_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); + if (zp_tensor != nullptr) { + zp_ptr = zp_tensor->Data(); + } } + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, effective_compute_type, nullptr /*QuantBData*/, + packed_b_.get(), scale_ptr, has_zp_input_, zp_ptr, nullptr, + &mlas_backend_kernel_selector_config_); + packed_b_finalized_ = true; } -#endif // MLAS_TARGET_ARM64 } is_packed = true; + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } } else if (compute_type_ == SQNBIT_CompInt8 && !prefer_lut_gemm_) { // Packing scales and zero points // Guard: for LUT-eligible nodes, scales/ZP are already packed inside @@ -376,7 +425,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All }(); if (should_pack_scale_and_zp_inputs) { - if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + // packed_b_ is already finalized during the B PrePack (scales and zero points folded in there so + // the sharing content hash captures them), so skip packing here. The CompInt8 packing is + // single-shot and packed_b_ may now be a buffer shared from another session. + if (input_idx == InputIndex::scales && packed_b_ != nullptr && !packed_b_finalized_) { auto sptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -384,7 +436,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } // Packing zero_point - if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && !packed_b_finalized_) { auto zptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -410,13 +462,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } scales_are_packed_ = true; - is_packed = true; - // For KleidiAI asymmetric 4-bit path: compute BZpCorr now while scales are still accessible. - // After this PrePack returns is_packed=true, ORT may erase scales from the constant - // input table (use count drops to 0), making them unavailable in later PrePack calls. - // Zero points haven't been PrePacked yet so they are still accessible. - if (has_zp_input_ && nbits_ == 4) { + // The scales were folded into packed_b_ during the B PrePack, so there is no separate packed + // scales buffer to cache or share. Report is_packed = false (as the x64 path already does for + // the scales input) so the framework does not engage pre-packed weight sharing for scales. + // Engaging it would require pushing a placeholder buffer, but the real scales live inside + // packed_b_ so the placeholder would be null - and PrePackedWeights::GetHash() skips null + // buffers, making the scales container key identical for every MatMulNBits node. That would + // falsely increment the shared-weights counter for unrelated nodes without sharing any real + // buffer. The quantized weight B (which carries the folded-in scales) is shared on its own. + is_packed = false; + + // BZpCorr was already folded into packed_b_ during the B PrePack (so the sharing content hash + // captures the zero points), so re-folding it here must be skipped: the packing is single-shot + // and packed_b_ may now be a buffer shared from another session. + if (has_zp_input_ && nbits_ == 4 && !packed_b_finalized_) { const Tensor* zp_tensor = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); if (zp_tensor != nullptr) { @@ -457,7 +517,14 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // BZpCorr was already computed during B packing in Step 1 (if applicable). scales_are_packed_ = true; - is_packed = true; + + // The scales were folded into the packed B buffer during the B PrePack, so there is no + // separate packed scales buffer to cache or share. Report is_packed = false (mirroring the + // x64 path and the SQNBIT_CompInt8 path above) so the framework does not engage sharing for + // the scales input; engaging it would push a null placeholder whose content hash is identical + // for every node, falsely incrementing the shared-weights counter without sharing any real + // buffer. + is_packed = false; } else #endif // MLAS_TARGET_ARM64 { @@ -471,7 +538,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // Pack scales separately only for 8-bit. For 4-bit on ARM64, scales are already packed // during B packing or used as a raw pointer at compute time (matching standard // SQNBIT_CompInt8 behavior where should_pack_scale_and_zp_inputs = (nbits_ == 8) on ARM64). - if (nbits_ == 8) { + // Skip when packed_b_ was already finalized during the B PrePack (scales/zero points folded + // in there for the sharing content hash); it may now be a buffer shared from another session. + if (nbits_ == 8 && !packed_b_finalized_) { MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, SQNBIT_CompInt8, nullptr, packed_b_.get(), scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -482,7 +551,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // Pack zero_points separately only for 8-bit (matching standard SQNBIT_CompInt8 behavior). // For 4-bit, zero_points are passed directly in data params or handled via KleidiAI BZpCorr. - if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && nbits_ == 8) { + if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && nbits_ == 8 && !packed_b_finalized_) { auto zptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, SQNBIT_CompInt8, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -540,8 +609,6 @@ template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { - ORT_UNUSED_PARAMETER(prepacked_weights); - if (input_idx == InputIndex::scales || input_idx == InputIndex::bias) { auto sptr = tensor.Data(); auto tensor_size = static_cast(tensor.Shape().Size()); @@ -565,8 +632,12 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou if (input_idx == InputIndex::B) { const Tensor* scales = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); - if (scales && MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, - has_zp_input_, &mlas_backend_kernel_selector_config_)) { + // Convert the constant fp16 scales to fp32 up front so they (and the zero points) can be folded + // into packed_b_ during this B PrePack, mirroring the primary float PrePack above. Pre-packed + // weight sharing content-hashes the buffer right after this B PrePack returns, so for CompInt8 + // everything that affects the packed bytes (the scales, and the block sum / KleidiAI BZpCorr that + // depend on the zero points) must be folded in by now. + if (scales && compute_type_ == SQNBIT_CompInt8) { auto sptr = scales->Data(); auto scales_size = static_cast(scales->Shape().Size()); auto ptr = IAllocator::MakeUniquePtr(alloc, scales_size, true); @@ -581,25 +652,55 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + // See the primary PrePack() above: SessionState::PrepackConstantInitializedTensors passes a + // non-null prepacked_weights on both the container and the default single-session paths, so this + // zero-fill runs on essentially every prepack at load (the guard only skips a caller that asks for + // no cacheable buffer). It keeps the dedup content hash reproducible regardless of bytes the pack + // leaves uninitialized (alignment padding), for any compute type. One-time O(packed_b_size_) load + // cost; inference is unaffected. + if (prepacked_weights != nullptr) { + std::memset(packed_b_.get(), 0, packed_b_size_); + } MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); -#if defined(MLAS_TARGET_ARM64) - // For KleidiAI asymmetric 4-bit path: compute BZpCorr during B packing. - // The fp16 specialization packs B here (with scales already converted to fp32), - // so we also compute BZpCorr now while both scales and zero_points are accessible. - if (has_zp_input_ && nbits_ == 4 && scales_fp32_ != nullptr) { - const Tensor* zp_tensor = nullptr; - OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); - if (zp_tensor != nullptr) { - auto zptr = zp_tensor->Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - scales_fp32_.get(), has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + // Fold the scales and (constant) zero points into packed_b_ now (see the primary PrePack above): + // the CompInt8 block sum and the KleidiAI BZpCorr depend on the zero points, so they must be + // folded in before the sharing content hash is taken. Otherwise two initializers with identical B + // and scales but different zero points would hash equal and the second would wrongly adopt the + // first's buffer. The B pack above only partially populates the buffer, so issue one more pack + // call with QuantBData == nullptr to finalize it. This is byte-identical to the staged + // scales + zero_points packing it replaces. + bool finalize_scale_zp_into_packed_b = compute_type_ == SQNBIT_CompInt8 && scales_fp32_ != nullptr; +#if !defined(MLAS_TARGET_AMD64_IX86) + // On ARM64 the scales/zero points are folded into B only for 8-bit, or for 4-bit when MLAS bakes + // them in (KleidiAI). For 4-bit non-KleidiAI they are applied at compute time and must not be + // passed to the packing routine, which would dereference the null QuantBData buffer. + finalize_scale_zp_into_packed_b = + finalize_scale_zp_into_packed_b && + (nbits_ == 8 || MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, + has_zp_input_, &mlas_backend_kernel_selector_config_)); +#endif + if (finalize_scale_zp_into_packed_b) { + const uint8_t* zp_ptr = nullptr; + if (has_zp_input_) { + const Tensor* zp_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); + if (zp_tensor != nullptr) { + zp_ptr = zp_tensor->Data(); + } } + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr /*QuantBData*/, + packed_b_.get(), scales_fp32_.get(), has_zp_input_, zp_ptr, nullptr, + &mlas_backend_kernel_selector_config_); + packed_b_finalized_ = true; } -#endif // MLAS_TARGET_ARM64 - is_packed = true; + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } } else if (compute_type_ == SQNBIT_CompInt8) { bool should_pack_scale_and_zp = [&]() { #if defined(MLAS_TARGET_AMD64_IX86) @@ -610,11 +711,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou }(); if (should_pack_scale_and_zp) { - if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + if (input_idx == InputIndex::scales && packed_b_ != nullptr && !packed_b_finalized_) { MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); is_packed = false; - } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && !packed_b_finalized_) { auto zptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -635,6 +736,11 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& used_shared_buffers = false; if (input_idx == InputIndex::B && !prepacked_buffers.empty()) { + // The buffer handed back is fully finalized: the producing session folded the scales and zero + // points (block sums / KleidiAI BZpCorr) into it during its PrePack(B), which is also when this + // kernel set packed_b_finalized_ on its own (identical) B PrePack. The later scale/zero-point + // PrePack calls already skip the staged packing whenever packed_b_finalized_ is set, so simply + // adopt the shared buffer here - no extra bookkeeping is needed to avoid re-folding into it. packed_b_ = std::move(prepacked_buffers[0]); used_shared_buffers = true; @@ -643,6 +749,9 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_); } } + // Only the quantized weight B yields a separately cached pre-packed buffer. The scales (and zero + // points) are folded into packed_b_ during the B PrePack and reported with is_packed = false, so + // the framework never asks this kernel to adopt a shared buffer for them. return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc index bab2dfd13e046..b027772971540 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc @@ -36,6 +36,74 @@ namespace contrib { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +namespace { + +// Validate CSR row-pointer monotonicity and column-index range. +// Must be called after CheckInputs has populated the parameters struct. +Status ValidateCSRIndices(const SparseAttentionParameters& parameters, + const Tensor& block_row_indices, + const Tensor& block_col_indices) { + const int num_layout = parameters.num_sparse_layout; + const int max_blocks = parameters.stride_row_indices - 1; + const int col_count = parameters.stride_col_indices; + + const int32_t* row_data = block_row_indices.Data(); + const int32_t* col_data = block_col_indices.Data(); + for (int l = 0; l < num_layout; ++l) { + const int32_t* r = row_data + l * (max_blocks + 1); + if (r[0] != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_row_indices[", l, "][0] must be 0, got ", r[0]); + } + for (int i = 0; i < max_blocks; ++i) { + if (r[i] < 0 || r[i] > r[i + 1] || r[i + 1] > col_count) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_row_indices values are not monotonically non-decreasing or exceed " + "block_col_indices columns at layout ", + l, " row ", i, + ": r[", i, "]=", r[i], ", r[", i + 1, "]=", r[i + 1], + ", col_count=", col_count); + } + } + const int32_t* c = col_data + l * col_count; + const int nnz = r[max_blocks]; + for (int k = 0; k < nnz; ++k) { + if (c[k] < 0 || c[k] >= max_blocks) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_col_indices[", l, "][", k, "]=", c[k], + " is out of valid range [0, ", max_blocks, ")"); + } + } + } + + return Status::OK(); +} + +// Validate total_key_lengths element values. +Status ValidateKeyLengths(const SparseAttentionParameters& parameters, + const Tensor& total_key_lengths) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + + const auto* key_len_data = total_key_lengths.Data(); + const bool is_prompt = (sequence_length == total_sequence_length); + const int min_key_length = is_prompt ? 1 : sequence_length; + for (int i = 0; i < batch_size; ++i) { + const int key_length = key_len_data[i]; + if (key_length < min_key_length || key_length > total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "key_total_sequence_lengths value ", key_length, + " at batch index ", i, + " is out of range [", min_key_length, ", ", total_sequence_length, "]."); + } + } + + return Status::OK(); +} + +} // namespace + template SparseAttention::SparseAttention(const OpKernelInfo& info) : OpKernel(info), SparseAttentionBase(info) { } @@ -75,6 +143,11 @@ Status SparseAttention::Compute(OpKernelContext* context) const { block_col_indices, total_key_lengths, total_seq_len)); + ORT_RETURN_IF_ERROR(ValidateCSRIndices(parameters, + *block_row_indices, + *block_col_indices)); + ORT_RETURN_IF_ERROR(ValidateKeyLengths(parameters, + *total_key_lengths)); const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h index 2804f30a9611d..af320b250abdb 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h @@ -198,26 +198,13 @@ Status CheckInputs(void* params, past_key_dims[3]); } - // Check the shape and values of total_key_sequence_lengths. + // Check the shape of total_key_sequence_lengths. const auto& k_len_dim = total_key_lengths->Shape().GetDims(); if (k_len_dim.size() != 1 || k_len_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_total_sequence_lengths must have shape (batch_size)."); } - const auto* key_len_data = total_key_lengths->Data(); - const bool is_prompt = (sequence_length == total_sequence_length); - const int min_key_length = is_prompt ? 1 : sequence_length; - for (int i = 0; i < batch_size; ++i) { - const int key_length = key_len_data[i]; - if (key_length < min_key_length || key_length > total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "key_total_sequence_lengths value ", key_length, - " at batch index ", i, - " is out of range [", min_key_length, ", ", total_sequence_length, "]."); - } - } - int rotary_dim = 0; int max_rotary_sequence_length = 0; if (do_rotary) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index ad778fb7ef907..1fe4ee65e5f31 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -168,9 +168,12 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = CPUAllocator::DefaultInstance(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); - int* seq_copy_ptr = seq_copy.get(); + // total_size is an element count: it sizes the int32 staging buffer and the copy spans. + // sequence_bytes is the byte count for copying a single beam's sequence per iteration. + size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); + size_t sequence_bytes = static_cast(cur_len) * sizeof(int32_t); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); + int32_t* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { ORT_RETURN_IF_ERROR(device_copy_int32_func( @@ -183,10 +186,10 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( gsl::span sequence = sequences.GetSequence(i); const int32_t* sequence_data = sequence.data(); long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + memcpy(seq_copy_ptr + seq_index, sequence_data, sequence_bytes); } - gsl::span temp_input(input_ids_data, total_size); - gsl::span temp_sequence(seq_copy_ptr, total_size); + gsl::span temp_input(input_ids_data, total_size); + gsl::span temp_sequence(seq_copy_ptr, total_size); ORT_RETURN_IF_ERROR(device_copy_int32_func( temp_input, temp_sequence, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 6a2d95f089f2b..aa84ee05c2bd1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -641,13 +641,18 @@ __global__ void GetSequenceLengths(const int* total_seq_lens_minus_one, const bool is_first_prompt) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i < batch_size) { - const int total_len = total_seq_lens_minus_one[i] + 1; + // total_seq_lens_minus_one is the seqlens_k input and is not range-checked on the device. + // Clamp the negative case at the source so the derived lengths below stay non-negative and + // cannot flow as negative offsets into KV-cache or attention index computations. + const int seqlens_k = total_seq_lens_minus_one[i]; + const int total_len = (seqlens_k > 0 ? seqlens_k : 0) + 1; total_seq_lens[i] = total_len; if (is_first_prompt) { past_seq_lens[i] = 0; padded_seq_lens[i] = sequence_length; } else { - past_seq_lens[i] = total_len - sequence_length; + const int past_len = total_len - sequence_length; + past_seq_lens[i] = past_len > 0 ? past_len : 0; padded_seq_lens[i] = 0; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 0c62aef11d53a..1a5dc6e0fca6b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -231,7 +231,9 @@ __global__ void UnpackRoPEAppend( } else { // Store K or V into the KV cache at index (past_seqlen + s) const int cache_s = past_seq_lens[b] + s; - if (cache_s < max_seqlen) { + // Two-sided bound: the lower check mirrors the position guard above and prevents a + // negative offset from being sign-extended into the cache index arithmetic below. + if (cache_s >= 0 && cache_s < max_seqlen) { void* cache_ptr = (head_type == KEY) ? k_cache : v_cache; if (cache_ptr != nullptr) { int64_t cache_idx; diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index 3185a9a86b231..9db229555ecbc 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -200,204 +200,6 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( } } -/** - * Takes the input maps and prepares the expanded maps for min latency - * @param num_active_experts_per_node: Number of active experts on current node - * @param experts_to_token_scores: The score of each token for each activated expert. 0 if the expert is not chosen by - * the token. Only the first num_active_experts_per_ rows are valid - * @param active_expert_global_ids: The global expert id for each activated expert - * Only the first num_active_experts_per_ values are valid - * @param expert_first_token_offset: Store the first token offset for each expert - */ -template -__device__ __forceinline__ void initTensor(T* value, int const tid, int const total_num, T const init_value) { - for (int i = tid; i < total_num; i += BLOCK_SIZE) { - value[i] = init_value; - } -} - -template -__device__ __forceinline__ void setLocalExperts(int* s_local_experts, T const* token_selected_experts, - int const total_num_experts, int const tid, int const start_expert, int const end_expert) { - for (int i = tid; i < total_num_experts; i += BLOCK_SIZE) { - int const expert = token_selected_experts[i]; - - // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node) - bool is_valid_expert = expert >= start_expert && expert < end_expert; - if (is_valid_expert) { - int local_expert_id = expert - start_expert; - if (s_local_experts[local_expert_id] == 0) { - s_local_experts[local_expert_id] = 1; // @TODO: Make sure that we allow duplicated write here - } - } - } - __syncthreads(); -} - -template -__device__ __forceinline__ void prefixSum(T* out, T* in, int const num, int const tid) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage tempStorage; - - T threadData = 0; - if (tid < num) { - threadData = in[tid]; - } - - BlockScan(tempStorage).InclusiveSum(threadData, threadData); - __syncthreads(); - - if (tid < num) { - out[tid] = threadData; - } - __syncthreads(); -} - -__device__ __forceinline__ void setActiveNum(int& num_active, int& num_active_offset_start, int& num_active_offset_end, - int const cluster_size, int const cluster_rank) { - int num_remainder = num_active % cluster_size; - int num_active_per_node = max(0, num_active - 1) / cluster_size; // num_active_per_node shouldn't be neg - if (cluster_rank < num_remainder) { - num_active = num_active_per_node + 1; - num_active_offset_start = cluster_rank * num_active; - } else { - num_active = num_active_per_node; - num_active_offset_start = cluster_rank * num_active_per_node + num_remainder; - } - num_active_offset_end = num_active_offset_start + num_active; -} - -template -__global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_per_node, float* experts_to_token_scores, - int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts, - float const* token_final_scales, int64_t const num_tokens, int const num_experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node, bool const smart_routing, int const cluster_rank, - int const cluster_size, int const num_experts_smem) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - // Use one block to process the min latency case - int tid = threadIdx.x; - // 0. init the global memory experts_to_token_scores [num_experts_per_node, num_token] - int const total_local_scales = num_experts_per_node * num_tokens; - initTensor(experts_to_token_scores, tid, total_local_scales, 0.0f); - initTensor(active_expert_global_ids, tid, num_experts_per_node, -1); - - __threadfence(); //@Todo: check do I need this fence for previous zero setting - - // 1. mask for the active expert: 1 stands for active - extern __shared__ int s_local_experts[]; - int* s_store_experts = s_local_experts + num_experts_smem; - initTensor(s_local_experts, tid, num_experts_smem, 0); - __syncthreads(); - - // 2. set the shared array s_local_experts[] - int const total_num_experts = num_tokens * num_experts_per_token; - setLocalExperts( - s_local_experts, token_selected_experts, total_num_experts, tid, start_expert, end_expert); - - // 3. perform prefix sum to acquire the store position and total active experts - //@TODO: Use cub first, might need to change it to self-defined api - prefixSum(s_store_experts, s_local_experts, num_experts_smem, tid); - - // 4. store the num of active experts - int num_active = s_store_experts[num_experts_smem - 1]; - int num_active_offset_start = 0; - int num_active_offset_end = 0; - - if (smart_routing) { - setActiveNum(num_active, num_active_offset_start, num_active_offset_end, cluster_size, cluster_rank); - } - - if (tid == 0) { - *num_active_experts_per_node = num_active; - } - - // 5. store the global expert id for each expert - if (smart_routing) { - for (int i = tid; i < num_experts_smem; i += BLOCK_SIZE) { - if (s_local_experts[i]) { - int offset = s_store_experts[i] - 1; - if (offset >= num_active_offset_start && offset < num_active_offset_end) { - active_expert_global_ids[offset - num_active_offset_start] = i; - } else { - s_local_experts[i] = 0; - } - } - } - __syncthreads(); // Need sync to update the s_local_experts - } else { - for (int i = tid; i < num_experts_smem; i += BLOCK_SIZE) { - if (s_local_experts[i]) { - int offset = s_store_experts[i] - 1; - active_expert_global_ids[offset] = i + start_expert; - } - } - } - - // 6. store the scale values - __threadfence(); //@Todo: check do I need this fence for previous zero setting - for (int i = tid; i < total_num_experts; i += BLOCK_SIZE) { - int const expert = token_selected_experts[i]; - - // If expert is not in the current node, set it to num_experts_per_node - // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node) - bool is_valid_expert = smart_routing ? s_local_experts[expert] : (expert >= start_expert && expert < end_expert); - - if (is_valid_expert) { - int token = i / num_experts_per_token; - float const scale = token_final_scales[i]; - int offset = s_store_experts[expert - start_expert] - 1 - num_active_offset_start; - experts_to_token_scores[offset * num_tokens + token] = scale; - } - } - // 7. set default value for redundant memory - for (int i_exp = num_active + tid; i_exp < num_experts_per_node; i_exp += BLOCK_SIZE) { - active_expert_global_ids[i_exp] = -1; - } - // 8. set expert_first_token_offset - for (int i_exp = tid; i_exp < num_experts_per_node + 1; i_exp += BLOCK_SIZE) { - if (i_exp < num_active) { - expert_first_token_offset[i_exp] = i_exp * num_tokens; - } else { - expert_first_token_offset[i_exp] = num_active * num_tokens; - } - } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} - -void buildMinLatencyActiveExpertMaps(int* num_active_experts_per_node, float* experts_to_token_scores, - int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts, - float const* token_final_scales, int64_t const num_tokens, int const experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node, int const cluster_rank, int const cluster_size, - int const num_experts_smem, cudaStream_t const stream) { - ORT_ENFORCE(num_experts_per_node == (end_expert - start_expert), - "num_experts_per_node must be equal to end_expert - start_expert"); - - ORT_ENFORCE(num_experts_per_node <= 256, "don't support num_experts_per_node > 256 cases"); - - int const threads = 256; - int const blocks = 1; - bool const smart_routing = cluster_size > 1; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = num_experts_smem * sizeof(int) * 2; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = onnxruntime::llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, buildMinLatencyActiveExpertMapsKernel, num_active_experts_per_node, - experts_to_token_scores, active_expert_global_ids, expert_first_token_offset, token_selected_experts, - token_final_scales, num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node, - smart_routing, cluster_rank, cluster_size, num_experts_smem); -} - template __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_selected_experts, int* const permuted_row_to_unpermuted_row, int* const unpermuted_row_to_permuted_row, @@ -1168,123 +970,6 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir #endif } -template -__global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int64_t const num_experts_per_node, T const* in1, T const* in2, - WeightType const* weights1, WeightType const* weights2, float const* alpha_scale_flat1, - float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* output1, OutputType* output2, - int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert) { - // First, compute the global tid. We only need 1 thread per expert. - int const expert = blockIdx.x * blockDim.x + threadIdx.x; - - if (expert >= num_experts_per_node) { - return; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - - // Note: expert is used to calculate the offset of the input and output - // local_expert is used to calculate the offset of the weight - auto const num_tokens_before_expert = expert * num_tokens; - bool const is_active_expert = expert < *num_active_experts_per; - int const local_expert = is_active_expert ? active_expert_global_ids[expert] - start_expert : -1; - auto const gemm_m = is_active_expert ? num_tokens : 0; - - // M and N transposed since we are using the #tokens as the N dimension - layout_info1.shape_info.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm1_n, gemm_m, gemm1_k); - layout_info2.shape_info.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm2_n, gemm_m, gemm2_k); - - if (alpha_scale_flat1) { - assert(alpha_scale_flat2); - if (is_active_expert) { - layout_info1.alpha_scale_ptr_array[expert] = alpha_scale_flat1 + local_expert; - layout_info2.alpha_scale_ptr_array[expert] = alpha_scale_flat2 + local_expert; - } else { - layout_info1.alpha_scale_ptr_array[expert] = nullptr; - layout_info2.alpha_scale_ptr_array[expert] = nullptr; - } - } - - if (quant_params.fp4.fc1.weight_block_scale) { - setupFP4BlockScalingFactors(layout_info1, expert, - gemm_m, gemm1_n, gemm1_k, fp4_act_flat1, quant_params.fp4.fc1.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc1 uses the same A input for all experts and the scaling factor B offsets from - // the local expert index - if (is_active_expert) { - layout_info1.fpX_block_scaling_factors_A[expert] = fp4_act_flat1; - layout_info1.fpX_block_scaling_factors_B[expert] = quant_params.fp4.fc1.weight_block_scale + getOffsetWeightSF( - local_expert, gemm1_n, gemm1_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info1.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info1.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - - if (quant_params.fp4.fc2.weight_block_scale) { - setupFP4BlockScalingFactors(layout_info2, expert, - gemm_m, gemm2_n, gemm2_k, fp4_act_flat2, quant_params.fp4.fc2.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc2 scaling factor B offsets by the local expert index - if (is_active_expert) { - layout_info2.fpX_block_scaling_factors_B[expert] = quant_params.fp4.fc2.weight_block_scale + getOffsetWeightSF( - local_expert, gemm2_n, gemm2_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info2.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info2.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif - - assert(gemm_m <= INT32_MAX); - assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); - assert(gemm1_k > 0 && gemm1_k <= INT32_MAX); - assert(gemm2_n > 0 && gemm2_n <= INT32_MAX); - assert(gemm2_k > 0 && gemm2_k <= INT32_MAX); - computeTmaWarpSpecializedInputStrides(layout_info1, gemm_m, gemm1_n, gemm1_k, expert, - cutlass::gemm::collective::detail::int4_group_size); - computeTmaWarpSpecializedInputStrides(layout_info2, gemm_m, gemm2_n, gemm2_k, expert, - cutlass::gemm::collective::detail::int4_group_size); - - if (is_active_expert) { - // Note: under low latency mode, we use the same input for all experts - // so for gemm1, the inputs are the same, - // for gemm2, we use the input generated by gemm1 - layout_info1.ptr_a[expert] = in1; - layout_info2.ptr_a[expert] = safe_inc_ptr(in2, expert * num_tokens * gemm2_k); - - // Each expert's weight matrix is a constant size NxK, get the matrix at index `expert` - layout_info1.ptr_b[expert] = safe_inc_ptr(weights1, local_expert * (gemm1_n * gemm2_k)); - layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k)); - - assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); - layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n); - - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - // The output prior to this contains N elements per token, with `num_tokens` tokens - layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n); - } - } else { - layout_info1.ptr_a[expert] = nullptr; - layout_info2.ptr_a[expert] = nullptr; - layout_info1.ptr_b[expert] = nullptr; - layout_info2.ptr_b[expert] = nullptr; - - layout_info1.default_epilogue.ptr_d[expert] = nullptr; - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info2.default_epilogue.ptr_d[expert] = nullptr; - } - } -} - // ========================== Permutation things ======================================= // Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. @@ -2886,68 +2571,6 @@ CutlassMoeFCRunner: return std::make_pair(layout_info1, layout_info2); } -template -std::pair -CutlassMoeFCRunner::computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, T const* input1, T const* input2, - WeightType const* weights1, WeightType const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* output1, - UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert, cudaStream_t stream) { - ORT_ENFORCE(!use_w4afp8, "W4AFP8 is not supported in low latency mode"); - - // Always nullptr - layout_info1.ptr_c = nullptr; - layout_info1.stride_c = nullptr; - layout_info2.ptr_c = nullptr; - layout_info2.stride_c = nullptr; - - auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale - : use_wfp4afp8 ? (quant_params.fp8_mxfp4.fc1.global_scale - ? quant_params.fp8_mxfp4.fc1.global_scale - : quant_params.mxfp8_mxfp4.fc1.global_scale) - : fp8_dequant1; - auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale - : use_wfp4afp8 ? (quant_params.fp8_mxfp4.fc2.global_scale - ? quant_params.fp8_mxfp4.fc2.global_scale - : quant_params.mxfp8_mxfp4.fc2.global_scale) - : fp8_dequant2; - if (!alpha_scale_flat1) { - layout_info1.alpha_scale_ptr_array = nullptr; - } - if (!alpha_scale_flat2) { - layout_info2.alpha_scale_ptr_array = nullptr; - } - - layout_info1.int4_groupwise_params.enabled = false; - layout_info2.int4_groupwise_params.enabled = false; - - int const threads = std::min(1024, num_experts); - int const blocks = (num_experts + threads - 1) / threads; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = onnxruntime::llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, - computeStridesTmaWarpSpecializedLowLatencyKernel, layout_info1, - layout_info2, num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts, input1, input2, weights1, weights2, - alpha_scale_flat1, alpha_scale_flat2, fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, bias1, bias2, output1, - output2, num_active_experts_per, active_expert_global_ids, start_expert); - - return std::make_pair(layout_info1, layout_info2); -} - template std::pair CutlassMoeFCRunner::setupTmaWarpSpecializedInputs( @@ -3083,78 +2706,6 @@ __global__ void populateRandomBufferKernel(void* buffer_void, size_t size) { buffer[tid * elem_per_thread + i] = curand4(&state); } -template -__global__ void prepareMinLatencyBuffer(int* num_active_experts_per_node, int* active_expert_global_ids, - int64_t* expert_first_token_offset, int const num_tokens, int const num_experts_per_token, - int const num_experts_per_node) { - int tid = threadIdx.x; - int bid = blockIdx.x; - - // 0. set offset - num_active_experts_per_node += bid; - active_expert_global_ids += bid * num_experts_per_node; - expert_first_token_offset += bid * (num_experts_per_node + 1); - - // 1. set the num_active_experts_per_node - int num_active = max(1, (int)(bid * ((float)num_experts_per_node / NUM_ROUTING_SAMPLES))); - *num_active_experts_per_node = num_active; - - // 2. generate random active experts - extern __shared__ float s_buf[]; - float* expert_refs = s_buf; - int* expert_refs_idx = reinterpret_cast(expert_refs + num_experts_per_node); - - curandState_t local_state; - curand_init(bid, tid, 0, &local_state); - for (int i = tid; i < num_experts_per_node; i += BLOCK_SIZE) { - expert_refs[i] = (float)curand_uniform(&local_state); - expert_refs_idx[i] = (int)i; - } - __syncthreads(); - - float thread_key[1]; - int thread_value[1]; - thread_key[0] = std::numeric_limits::max(); - thread_value[0] = num_experts_per_node; - - if (tid < num_experts_per_node) { - thread_key[0] = expert_refs[tid]; - thread_value[0] = expert_refs_idx[tid]; - } - - using BlockRadixSort = cub::BlockRadixSort; - using BlockRadixSortValue = cub::BlockRadixSort; - - union TempStorage { - typename BlockRadixSort::TempStorage key_value; - typename BlockRadixSortValue::TempStorage value; - }; - __shared__ union TempStorage temp_storage; - - BlockRadixSort(temp_storage.key_value).Sort(thread_key, thread_value); - __syncthreads(); - - if (tid > num_active) { - thread_value[0] = std::numeric_limits::max(); - } - BlockRadixSortValue(temp_storage.value).Sort(thread_value); - __syncthreads(); - - // 3. set the active_expert_global_ids and expert_first_token_offset - for (int i = tid; i < num_experts_per_node; i += BLOCK_SIZE) { - if (i < num_active) { - active_expert_global_ids[i] = thread_value[0]; - expert_first_token_offset[i] = i * num_tokens; - } else { - active_expert_global_ids[i] = -1; - expert_first_token_offset[i] = num_active * num_tokens; - } - } - if (tid == 0) { - expert_first_token_offset[num_experts_per_node] = num_active * num_tokens; - } -} - void populateRandomBuffer(void* buffer_void, size_t size, cudaStream_t stream) { // Each thread initialises 128 bytes ORT_ENFORCE(size % 128 == 0, "Unexpected size alignment"); @@ -3292,10 +2843,6 @@ std::map> GemmProfilerBackend::getProfile size_t const blocked_expert_counts_cumsum_size = blocked_expert_counts_size; size_t const blocked_row_to_unpermuted_row_size = num_experts_per_node * maxM * sizeof(int); - // The follow buffers are used in min_latency_mode - size_t num_active_experts_per_node_size = 0; - size_t active_expert_global_ids_size = 0; - size_t map_offset = 0; std::map> out_map; @@ -3316,8 +2863,6 @@ std::map> GemmProfilerBackend::getProfile ADD(blocked_expert_counts_cumsum); ADD(blocked_row_to_unpermuted_row); ADD(token_topk_unpermuted_scales); - ADD(num_active_experts_per_node); - ADD(active_expert_global_ids); ADD(input); ADD(output); ADD(intermediate); @@ -3358,8 +2903,6 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha GET_WS_PTR(int*, blocked_expert_counts); GET_WS_PTR(int*, blocked_expert_counts_cumsum); GET_WS_PTR(int*, blocked_row_to_unpermuted_row); - GET_WS_PTR(int*, num_active_experts_per_node); - GET_WS_PTR(int*, active_expert_global_ids); #undef GET_WS_PTR_BASE #undef GET_WS_PTR @@ -3473,8 +3016,6 @@ void GemmProfilerBackend::prepareTmaWsInputs( GET_WS_PTR(void*, gemm_workspace); GET_WS_PTR(float*, alpha_scale_ptr_array); GET_WS_PTR(TmaWarpSpecializedGroupedGemmInput::ElementSF*, fp4_act_scale_flat); - GET_WS_PTR(int*, num_active_experts_per_node); - GET_WS_PTR(int*, active_expert_global_ids); #undef GET_WS_PTR @@ -3573,8 +3114,6 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac GET_WS_PTR(float const*, token_topk_unpermuted_scales); auto const* token_topk_permuted_scales = token_topk_unpermuted_scales; - GET_WS_PTR_OFFSET(int*, num_active_experts_per_node, mSampleIndex); - GET_WS_PTR_OFFSET(int*, active_expert_global_ids, (mSampleIndex * mNumExpertsPerNode)); GET_WS_PTR(void const*, input); GET_WS_PTR(void*, output); GET_WS_PTR(void*, intermediate); diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h index 9f780c313e1fd..bf9ab2e684c5a 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h @@ -305,16 +305,6 @@ class CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) = 0; - virtual std::pair - computeStridesTmaWarpSpecializedLowLatencyDispatch(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, void const* input1, void const* input2, - void const* weights1, void const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, - int const* active_expert_global_ids, int start_expert, cudaStream_t stream) = 0; - virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; @@ -521,25 +511,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { reinterpret_cast(gemm2_output), stream); } - std::pair - computeStridesTmaWarpSpecializedLowLatencyDispatch(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, void const* input1, void const* input2, - void const* weights1, void const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, - int const* active_expert_global_ids, int start_expert, cudaStream_t stream) override { - return Self::computeStridesTmaWarpSpecializedLowLatency(layout_info1, layout_info2, num_tokens, gemm1_n, - gemm1_k, gemm2_n, gemm2_k, num_experts, reinterpret_cast(input1), - reinterpret_cast(input2), reinterpret_cast(weights1), - reinterpret_cast(weights2), fp8_dequant1, fp8_dequant2, fc1_fp4_act_flat, - fc2_fp4_act_flat, quant_params, reinterpret_cast(bias1), - reinterpret_cast(bias2), reinterpret_cast(output1), - reinterpret_cast(output2), num_active_experts_per, active_expert_global_ids, - start_expert, stream); - } - private: std::pair setupTmaWarpSpecializedInputs( int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, bool use_fused_gated_activation, @@ -560,16 +531,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); - static std::pair - computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, T const* input1, T const* input2, - WeightType const* weights1, WeightType const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* output1, - UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert, cudaStream_t stream); std::map> getWorkspaceDeviceBufferSizes(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 656fde2f46ab8..355319e84e534 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -56,6 +56,8 @@ SparseAttention::SparseAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); disable_v1_kernel_ = ParseEnvironmentVariableWithDefault(sparse_attention::kDisableSparseAttentionV1, false); + disable_input_validation_ = ParseEnvironmentVariableWithDefault( + sparse_attention::kDisableInputValidation, false); } template @@ -105,6 +107,26 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { block_col_indices, seqlens_k_total, total_seq_len)); + + // Validate CSR indices and key lengths on device to prevent out-of-bounds access. + // This must run before the shared-buffer check so OpTester-based tests can exercise it. + cudaStream_t cuda_stream = Stream(context); + if (!disable_input_validation_) { + auto csr_error_buffer = GetScratchBuffer(1, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(ValidateCSRIndicesOnDevice( + cuda_stream, + block_row_indices->Data(), + block_col_indices->Data(), + seqlens_k_total->Data(), + parameters.num_sparse_layout, + parameters.stride_row_indices - 1, // max_blocks + parameters.stride_col_indices, // col_count + parameters.batch_size, + parameters.sequence_length, + parameters.total_sequence_length, + csr_error_buffer.get())); + } + // Some limitations of CUDA kernels // The v1 and v2 kernels have same coverage, so only check one of them to see whether it is supported. int sm = device_prop.major * 10 + device_prop.minor; @@ -137,7 +159,6 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { int32_t* total_k_seq_len_pinned = nullptr; AutoDestoryCudaEvent new_event; cudaEvent_t& isCopyDone = new_event.Get(); - cudaStream_t cuda_stream = Stream(context); if (use_v2_kernel) { pinned_buffer = AllocateBufferOnCPUPinned(parameters.batch_size); diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h index 1df3affe17ea3..06fc07f088c60 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h @@ -18,14 +18,15 @@ class SparseAttention final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; protected: - int num_heads_; // number of attention heads for q - int kv_num_heads_; // number of attention heads for k and v - float scale_; // Scaling factor applied prior to softmax. - bool is_causal_; // unidirectional attention or not - int sparse_block_size_; // block size for sparsity - bool do_rotary_; // Has rotary positional embedding - bool rotary_interleaved_; // Interleaved rotary positional embedding - bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. + int num_heads_; // number of attention heads for q + int kv_num_heads_; // number of attention heads for k and v + float scale_; // Scaling factor applied prior to softmax. + bool is_causal_; // unidirectional attention or not + int sparse_block_size_; // block size for sparsity + bool do_rotary_; // Has rotary positional embedding + bool rotary_interleaved_; // Interleaved rotary positional embedding + bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. + bool disable_input_validation_; // Whether to skip device-side CSR and key-length validation. }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu index b2a6eb89d4d23..1fecb91b4f578 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu @@ -333,6 +333,120 @@ template Status QkvToContext( contrib::SparseAttentionParameters& parameters, SparseAttentionData& data); +// Validation kernel for CSR sparse layout indices and key sequence lengths. +// Each block handles one layout (blocks [0, num_layout)) or key lengths (block num_layout). +// All threads in a warp cooperate via strided iteration over elements. +// Writes a CSRValidationError code to *error_flag if any check fails. +__global__ void ValidateCSRIndicesKernel( + const int32_t* csr_row_indices, + const int32_t* csr_col_indices, + const int32_t* seqlens_k_total, + int max_blocks, + int col_count, + int num_layout, + int batch_size, + int sequence_length, + int total_sequence_length, + int32_t* error_flag) { + int block_id = blockIdx.x; + int tid = threadIdx.x; + int num_threads = blockDim.x; + + if (block_id < num_layout) { + // Validate CSR indices for this layout. + const int stride_row = max_blocks + 1; + const int32_t* r = csr_row_indices + block_id * stride_row; + + // Phase 1: thread 0 validates all row pointers sequentially. + // Row arrays are small (max_blocks+1 elements), so single-thread scan is sufficient. + // All threads must reach __syncthreads before proceeding to col validation. + __shared__ int row_valid; + if (tid == 0) { + row_valid = 1; + if (r[0] != 0) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationRowFirstNotZero); + row_valid = 0; + } else { + for (int i = 0; i < max_blocks; ++i) { + if (r[i] < 0 || r[i] > r[i + 1] || r[i + 1] > col_count) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationRowNonMonotonic); + row_valid = 0; + break; + } + } + } + } + __syncthreads(); + if (!row_valid) return; + + // Phase 2: row pointers are validated, r[max_blocks] is safe to use as NNZ bound. + // All threads cooperate on the potentially larger col-index array. + const int nnz = r[max_blocks]; + const int32_t* c = csr_col_indices + block_id * col_count; + for (int i = tid; i < nnz; i += num_threads) { + if (c[i] < 0 || c[i] >= max_blocks) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationColOutOfRange); + return; + } + } + } else if (block_id == num_layout) { + // Validate key lengths. All threads cooperate in strided fashion. + bool is_prompt = (sequence_length == total_sequence_length); + int min_key_length = is_prompt ? 1 : sequence_length; + for (int i = tid; i < batch_size; i += num_threads) { + int key_length = seqlens_k_total[i]; + if (key_length < min_key_length || key_length > total_sequence_length) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationKeyLengthOutOfRange); + return; + } + } + } +} + +Status ValidateCSRIndicesOnDevice( + cudaStream_t stream, + const int32_t* csr_row_indices, + const int32_t* csr_col_indices, + const int32_t* seqlens_k_total, + int num_layout, + int max_blocks, + int col_count, + int batch_size, + int sequence_length, + int total_sequence_length, + int32_t* d_error_flag) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(d_error_flag, 0, sizeof(int32_t), stream)); + + // Launch num_layout blocks for CSR validation + 1 block for key-length validation. + // Each block uses a full warp (32 threads) with strided iteration over elements. + ValidateCSRIndicesKernel<<>>( + csr_row_indices, csr_col_indices, seqlens_k_total, + max_blocks, col_count, num_layout, + batch_size, sequence_length, total_sequence_length, + d_error_flag); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + // Copy error flag back to host. + int32_t h_error_flag = 0; + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(&h_error_flag, d_error_flag, sizeof(int32_t), + cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + if (h_error_flag != kCSRValidationOk) { + const char* msg = (h_error_flag == kCSRValidationRowFirstNotZero) + ? "block_row_indices first element must be 0 for all layouts" + : (h_error_flag == kCSRValidationRowNonMonotonic) + ? "block_row_indices values are not monotonically non-decreasing or exceed " + "block_col_indices columns" + : (h_error_flag == kCSRValidationColOutOfRange) + ? "block_col_indices value is out of valid range" + : "key_total_sequence_lengths value is out of valid range"; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, msg); + } + + return Status::OK(); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h index d4f686afe5db0..f5b24d99b10a9 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h @@ -68,6 +68,31 @@ Status QkvToContext( contrib::SparseAttentionParameters& parameters, SparseAttentionData& data); +// Error codes returned by the CSR validation kernel via the device error flag. +enum CSRValidationError : int32_t { + kCSRValidationOk = 0, + kCSRValidationRowFirstNotZero = 1, + kCSRValidationRowNonMonotonic = 2, + kCSRValidationColOutOfRange = 3, + kCSRValidationKeyLengthOutOfRange = 4, +}; + +// Validate CSR row-pointer monotonicity, column-index range, and key lengths on device. +// Returns Status::OK() if valid, or INVALID_ARGUMENT with a description of the failure. +// d_error_flag must point to a device-allocated int32_t scratch buffer (1 element). +Status ValidateCSRIndicesOnDevice( + cudaStream_t stream, + const int32_t* csr_row_indices, // device pointer, shape [num_layout, max_blocks + 1] + const int32_t* csr_col_indices, // device pointer, shape [num_layout, col_count] + const int32_t* seqlens_k_total, // device pointer, shape [batch_size] + int num_layout, + int max_blocks, + int col_count, + int batch_size, + int sequence_length, + int total_sequence_length, + int32_t* d_error_flag); // device scratch buffer (1 int32_t) + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index d9d299d4fd5d9..4e926c7efa597 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -55,6 +55,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + if (prepare_indirect_dispatch_) { + sh.AddInput("total_sequence_length_input", ShaderUsage::None); + } const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); @@ -97,8 +100,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (use_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::None); } - // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output + // If prepare_indirect_dispatch is enabled, add total_sequence_length_input + // and indirect_buffer output. total_sequence_length_input is the global max + // total sequence length across the batch (from GQA input #6); using it for + // dispatch sizing covers right-padded batches where batch 0 is not the max. if (prepare_indirect_dispatch_) { + shader.AddInput("total_sequence_length_input", ShaderUsage::None); shader.AddOutput("indirect_buffer", ShaderUsage::None); } @@ -109,11 +116,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { " let num_head_id = output_indices[1];\n" " let batch = output_indices[0];\n"; if (use_seqlen_k_) { - shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n"; + shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[batch]) + 1u;\n"; } else { shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; } - shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; + // Right-padded batches with prompt shorter than kv_sequence_length would underflow u32; clamp to 0. + shader.MainFunctionBody() << " let past_sequence_length = select(total_seq_length - uniforms.kv_sequence_length, 0u, total_seq_length <= uniforms.kv_sequence_length);\n"; if (past_present_share_buffer_) { shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"; } else { @@ -124,7 +132,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (prepare_indirect_dispatch_) { shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; shader.MainFunctionBody() << " if (global_idx == 0u) {\n" - << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" + << " let global_total_seq_length = u32(total_sequence_length_input[0]);\n" + << " let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -152,7 +161,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles, + const Tensor* total_seqlen) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. @@ -188,6 +198,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_past) { program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, @@ -262,9 +275,15 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } + if (use_indirect_dispatch_) { + // Global max total sequence length across batches (from GQA input #6). + // Used in indirect-dispatch mode for the workgroup_idx slicing so that + // batch 0's per-batch length cannot undersize the dispatch grid. + shader.AddInput("total_sequence_length_input", ShaderUsage::None); + } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } @@ -282,6 +301,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), @@ -293,7 +313,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k, const Tensor* total_seqlen) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -303,13 +323,16 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool is_unidirectional = parameters.is_unidirectional_; - FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (use_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -320,10 +343,12 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } if (use_indirect_dispatch) { @@ -332,7 +357,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -343,6 +368,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, {attn_bias_dim1}, + {attn_bias_dim3}, {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); @@ -351,7 +377,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { @@ -364,7 +390,7 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_VARIABLE(input, input), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(output, output)); @@ -379,17 +405,17 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch, const Tensor* head_sink, - uint32_t m_tile) { + uint32_t m_tile, + bool use_seqlen_k) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; bool has_head_sink = head_sink != nullptr; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k}; program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_head_sink) { @@ -399,7 +425,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) + .CacheHint(tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, @@ -415,7 +441,8 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, - const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink) { + const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink, + const Tensor* total_seqlen) { constexpr uint32_t tile_size = 64; // Create present_key and present_value tensors if they are nullptr. @@ -437,7 +464,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co present_value = &internal_present_value; } - const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Read seqlens_k per batch_idx in the shader whenever seqlens_k is supplied. + // This covers both graph-capture (total_sequence_length_ is 0 on the host) and + // right-padded batches (batch_size > 1 with distinct per-batch totals), and lets + // batch=1 share the same path. When seqlens_k is null, kernels fall back to + // uniforms.total_sequence_length. + const bool use_seqlen_k = seqlen_k != nullptr; // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; @@ -453,8 +485,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Prepare indirect dispatch buffer for split-reduce path with static KV cache. // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && + // Static KV cache (past_present_share_buffer_) is guaranteed by GQA's + // ORT_ENFORCE when graph capture is enabled. + const bool use_indirect_dispatch = seqlen_k != nullptr && + total_seqlen != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions @@ -492,10 +526,11 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size, num_q_tiles)); + indirect_buffer_ptr, tile_size, num_q_tiles, + total_seqlen)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles, total_seqlen)); } // Extract present_sequence_length directly from present_key tensor shape @@ -555,10 +590,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) @@ -572,7 +609,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {alpha}, {num_seq_tile}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {attn_bias_dim3}}); return context.RunProgram(program); } @@ -596,27 +634,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co &metadata, seqlen_k, parameters, indirect_buffer_ptr, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length, m_tile)); + present_sequence_length, m_tile, use_seqlen_k, total_seqlen)); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - head_sink, m_tile)); + num_present_sequence_length_tile, tile_size, + head_sink, m_tile, use_seqlen_k)); return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { - const bool kv_empty = parameters.kv_sequence_length_ == 0; - // FlashAttention here does not implement right-padded per-batch prefill, so the - // first disjunction restricts it to inputs where padding cannot occur: - // - batch_size_ == 1: single sequence, no padding possible. - // - seqlen_k == nullptr: no per-batch lengths, padding inexpressible. - // - kv_empty (shared-KV layer): FA is mandatory; that path takes a different shader. - // The remaining conjuncts exclude packed-QKV (handled by a separate rotary kernel), - // mismatched head/value sizes, and head_size alignments unsupported by the kernel. - return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && - !parameters.is_packed_qkv_ && +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } @@ -631,7 +660,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size, uint32_t num_q_tiles) { + uint32_t tile_size, uint32_t num_q_tiles, + const Tensor* total_seqlen) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -669,6 +699,9 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {cos_cache, ProgramTensorMetadataDependency::Rank, components}, {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }); + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, {present_key, ProgramTensorMetadataDependency::None, components}, {present_value, ProgramTensorMetadataDependency::None, components}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 218baf926173f..85ba61c1d20b5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -125,7 +125,8 @@ class FlashAttentionProgram final : public Program { {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}, {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32}, - {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}); + {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}, + {"attn_bias_dim3", ProgramUniformVariableDataType::Uint32}); private: bool has_attention_bias_; @@ -148,8 +149,9 @@ class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -195,17 +199,18 @@ class FlashAttentionDecodeVxReduceProgram final : public Program u32 { - return u32(seqlens_k[0]) + 1u; +// When seqlens_k is provided, total_sequence_length is read per batch from the GPU buffer. +fn get_total_sequence_length(batch_idx: u32) -> u32 { + return u32(seqlens_k[batch_idx]) + 1u; } #else -// When graph capture is disabled, total_sequence_length comes from uniforms -fn get_total_sequence_length() -> u32 { +// Without seqlens_k, total_sequence_length comes from uniforms (max across batches). +fn get_total_sequence_length(batch_idx: u32) -> u32 { return uniforms.total_sequence_length; } #endif @@ -65,20 +65,18 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ var qk_scores : array; -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); v_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); @@ -95,15 +93,19 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { } #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> q_element_t { - if (k_idx_global >= get_total_sequence_length()) { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> q_element_t { + if (k_idx_global >= total_seq) { return q_element_t(0); } let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); - return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + get_total_sequence_length())]); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; + return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + stride_total_seq - 1u)]); } #endif @@ -111,24 +113,24 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he // For max performance max_k_step should be the same as sg_size, however we might run out of registers // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); k_tile[slot][idx % head_size_vec] = val; } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); v_tile[slot][idx % head_size_vec] = val; } } @@ -160,18 +162,22 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { #endif #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (k_idx_global >= get_total_sequence_length()) { + if (k_idx_global >= total_seq) { return vec4(0); } // Handle broadcasting: if dimension size is 1, use index 0 let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; let offset = offset_base + k_idx_global; - let offset_max = offset_base + get_total_sequence_length(); + let offset_max = offset_base + stride_total_seq - 1u; let c1 = q_element_t(attention_bias[min(offset, offset_max)]); let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]); let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]); @@ -179,7 +185,7 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he return vec4(c1, c2, c3, c4); } #else -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { return vec4(0); } #endif @@ -226,11 +232,14 @@ $MAIN { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; #endif - let total_sequence_length = get_total_sequence_length(); + let total_sequence_length = get_total_sequence_length(batch_idx); #if is_unidirectional // If attention is unidirectional, set the loop bound to enforce causal masking. - let past_sequence_length = total_sequence_length - uniforms.new_sequence_length; + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); let max_causal_len_for_workgroup = past_sequence_length + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup); @@ -244,8 +253,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx); - loadv(k_start, batch_head_idx, local_idx); + loadk(k_start, batch_head_idx, local_idx, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, total_sequence_length); workgroupBarrier(); for (var k = 0u; k < max_k_step; k++) { @@ -254,7 +263,7 @@ $MAIN { score += dot(q_tile[i], k_tile[k][i]); } #if has_attention_bias - score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx); + score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx, total_sequence_length); #endif qk_scores[k] = select(min_value, score, k_start + k < seq_causal_length); } @@ -302,8 +311,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx, capped_sg_size); - loadv(k_start, batch_head_idx, local_idx, capped_sg_size); + loadk(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); workgroupBarrier(); // Compute QKt @@ -361,11 +370,11 @@ $MAIN { qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx); - qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx); + qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx, total_sequence_length); + qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx, total_sequence_length); if (sg_size > 8) { - qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx); - qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx); + qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx, total_sequence_length); + qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx, total_sequence_length); } // Neuter qk values where K is out of bounds. diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template index 524a18ca43245..778e07fbf63ff 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -10,6 +10,7 @@ #param tile_size #param tile_size_k_vec #param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -34,18 +35,22 @@ var tile_max: array; var tile_sum: array; #if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + - bias_head_idx * uniforms.new_sequence_length * total_seq_length + - q_idx * total_seq_length + + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + + q_idx * stride_total_seq + k_idx; return attention_bias[offset]; } #else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { return q_element_t(0); } @@ -54,12 +59,14 @@ var tile_sum: array; $MAIN { let local_row = u32(local_idx / tile_size_k_vec); let local_col = local_idx % tile_size_k_vec; + // total_sequence_length used for workgroup_idx slicing must match the host-side dispatch + // grid, i.e. the global maximum across batches. Per-batch total is derived separately below. #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + let global_total_sequence_length = u32(total_sequence_length_input[0]); #else - let total_sequence_length = uniforms.total_sequence_length; + let global_total_sequence_length = uniforms.total_sequence_length; #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_total_seq_length_tile = (global_total_sequence_length + tile_size - 1) / tile_size; let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; @@ -71,9 +78,28 @@ $MAIN { if (batch_idx >= uniforms.batch_size) { return; } + // Per-batch total_sequence_length used for K/V bounds, causal mask, and softmax range. + #if use_seqlen_k + let total_sequence_length = u32(seqlens_k[batch_idx]) + 1u; + #else + let total_sequence_length = global_total_sequence_length; + #endif let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + // If this workgroup's tile lies entirely beyond this batch's per-batch total_sequence_length, + // write neutral metadata so VxReduce contributes nothing for these tiles, then exit early. + if (total_seq_offset >= total_sequence_length) { + if (local_idx == 0u) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx_local = q_base + m; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx_local) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(-3.4028234663852886e+38f, 0.0f)); + } + } + return; + } + // ============================================================ // Phase 1: QK^T computation // ============================================================ @@ -109,6 +135,12 @@ $MAIN { } // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask +#if is_unidirectional + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); +#endif for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { let q_idx = q_base + m; if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { @@ -117,9 +149,9 @@ $MAIN { sum += inner_qk_values[m][local_idx][i]; } - sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx); #if is_unidirectional - if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + if (total_seq_offset + local_idx > past_sequence_length + q_idx) { sum = q_element_t(-65504.0f); } #endif diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index a3ce0b68cb659..628ad835a9d4c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -5,7 +5,7 @@ #param m_tile #param seq_tile_size #param tile_size -#param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -32,8 +32,12 @@ $MAIN { } let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + // Per-batch total_sequence_length: short batches contributed neutral metadata + // (-inf, 0) for tiles beyond their per-batch total, so reading only this batch's + // tiles ensures softmax rescaling is not skewed by garbage tiles. + #if use_seqlen_k + let batch_idx_for_seqlen = batch_head_idx / uniforms.num_heads; + let total_sequence_length = u32(seqlens_k[batch_idx_for_seqlen]) + 1u; let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; #else let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 36d688c9723fd..24ace3487a4c5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -327,6 +327,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& past_value->DataRaw() == present_value->DataRaw(); ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length."); + ORT_ENFORCE(!context.IsGraphCaptureEnabled() || parameters.past_present_share_buffer_, + "Graph capture requires past/present KV cache to share the same buffer (static KV cache)."); Tensor qSplit; Tensor kSplit; @@ -350,7 +352,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context); } if (kv_empty) { @@ -381,7 +383,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled // query points to packed QKV, K and V are nullptr since they're not needed return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink); + present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink, + total_seqlen_tensor); } // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); @@ -472,7 +475,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink); + present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink, + total_seqlen_tensor); } // Non-flash attention path does not support kv_sequence_length==0 (shared KV layers). diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 97c610fb90024..e3d92c036d2c1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -43,7 +43,8 @@ $MAIN { #if prepare_indirect_dispatch if (global_idx == 0u) { - let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + let global_total_seq_length = u32(total_sequence_length_input[0]); + let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index ebf3cc9f50be6..ec5c1386e8336 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -405,4 +405,13 @@ CPUIDInfo::CPUIDInfo() { #endif #endif // defined(CPUIDINFO_ARCH_RISCV64) } + +CPUIDInfo::~CPUIDInfo() { +#if defined(CPUINFO_SUPPORTED) + if (pytorch_cpuinfo_init_) { + cpuinfo_deinitialize(); + pytorch_cpuinfo_init_ = false; + } +#endif +} } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index bf502c645c9eb..6eed234332f46 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -110,6 +110,7 @@ class CPUIDInfo { static void LogEarlyWarning(std::string_view message); CPUIDInfo(); + ~CPUIDInfo(); void VendorInfoInit(); diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc index 99fa21b1e4be8..b53a99084152f 100644 --- a/onnxruntime/core/framework/ep_context_options.cc +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -56,6 +56,10 @@ const BufferWriteFuncHolder* ModelGenOptions::TryGetOutputModelWriteFunc() const return std::get_if(&output_model_location); } +const EpContextDataWriteFuncHolder* ModelGenOptions::TryGetEpContextDataWriteFunc() const { + return ep_context_data_write_func.write_func != nullptr ? &ep_context_data_write_func : nullptr; +} + bool ModelGenOptions::AreInitializersEmbeddedInOutputModel() const { return std::holds_alternative(initializers_location); } diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index 6643516bfb4c3..344ec5f7f6b58 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -7,6 +7,9 @@ #include #include "core/framework/allocator.h" #include "core/framework/config_options.h" +// Needed for OrtWriteNamedBufferFunc (used by EpContextDataWriteFuncHolder below). This include can be removed +// once the experimental EPContext data callback APIs are promoted to the stable C API. +#include "core/session/onnxruntime_experimental_c_api.h" namespace onnxruntime { namespace epctx { @@ -27,6 +30,14 @@ struct BufferWriteFuncHolder { void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. }; +/// +/// Holds the opaque state and write function that EPs use to write EPContext binary data. +/// +struct EpContextDataWriteFuncHolder { + OrtWriteNamedBufferFunc write_func = nullptr; + void* state = nullptr; +}; + /// /// Holds path and size threshold used to write out initializers to an external file. /// @@ -84,10 +95,13 @@ struct ModelGenOptions { InitializerHandler> // Custom function called for every initializer to determine location. initializers_location = std::monostate{}; + EpContextDataWriteFuncHolder ep_context_data_write_func = {}; + bool HasOutputModelLocation() const; const std::filesystem::path* TryGetOutputModelPath() const; const BufferHolder* TryGetOutputModelBuffer() const; const BufferWriteFuncHolder* TryGetOutputModelWriteFunc() const; + const EpContextDataWriteFuncHolder* TryGetEpContextDataWriteFunc() const; bool AreInitializersEmbeddedInOutputModel() const; const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index b328fc916f885..ddd21074afe8a 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -16,6 +16,9 @@ #include "core/framework/ep_context_options.h" #include "core/framework/ort_value.h" #include "core/session/onnxruntime_c_api.h" +// Needed for OrtReadNamedBufferFunc, the type of the EPContext data read callback stored in this struct. This include +// can be removed once the experimental EPContext data callback APIs are promoted to the stable C API. +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/optimizer/graph_transformer_level.h" #include "core/util/thread_utils.h" @@ -226,6 +229,9 @@ struct SessionOptions { bool has_explicit_ep_context_gen_options = false; epctx::ModelGenOptions ep_context_gen_options = {}; epctx::ModelGenOptions GetEpContextGenerationOptions() const; + + OrtReadNamedBufferFunc ep_context_data_read_func = nullptr; + void* ep_context_data_read_state = nullptr; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 6ef2319c1d3f4..241eb8362ddfa 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -498,8 +498,15 @@ Status SessionState::PrepackConstantInitializedTensors( auto iter = initializers_to_share_map.find(input_name); bool is_shared_initializer = (iter != initializers_to_share_map.end()); - // Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now - if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && + // CPU EP only. An initializer joins the shared pre-packed container either when it was + // registered via OrtApi::AddInitializer (is_shared_initializer) or when a graph transformer + // tagged this synthesized initializer with a sharing identity. Only the tag's *presence* + // matters here: it is the enrollment signal. The container key below is the packed-bytes + // hash, never the tag value (see the rationale at the key computation). + const bool enroll_tagged_initializer = + (st->graph_.GetSharedPrepackInitializerId(input_name) != nullptr); + if ((is_shared_initializer || enroll_tagged_initializer) && + should_cache_prepacked_weights_for_shared_initializers && node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON @@ -530,12 +537,18 @@ Status SessionState::PrepackConstantInitializedTensors( // TODO: Check if some version of the ONNX IR allows op_type to be empty ORT_ENFORCE(!op_type.empty(), "The op type of a node cannot be empty"); - // The key for the pre-packed weights container lookup is the op_type + hash of the prepacked-weight - // that we just got by invoking PrePack() on this kernel. - + // Key by the packed-bytes hash (op_type + a hash of the packed buffer), exactly as the + // AddInitializer path does, so only byte-identical packed buffers are ever shared. The + // tag is solely the enrollment signal that opted this fusion-generated initializer into + // the container; it must NOT be used as the key, because it is derived from the + // *unpacked* initializer content and so cannot distinguish packings that differ by node + // options/attributes that change the packed layout (e.g. mlas.use_lut_gemm or a CPU + // backend-selector difference). Two sessions that share a container but differ in such an + // option compute the same tag yet produce different packed bytes; keying by the packed + // bytes gives them distinct keys and prevents reusing an incompatible buffer + // (wrong results/crash). const std::string prepacked_weights_container_key = - GenerateKeyForPrepackedWeightsMap(op_type, - weights_to_be_filled_in); + GenerateKeyForPrepackedWeightsMap(op_type, weights_to_be_filled_in); bool container_contains_packed_weight = prepacked_weights_container_->HasWeight( prepacked_weights_container_key); @@ -603,7 +616,16 @@ Status SessionState::PrepackConstantInitializedTensors( // within this session. Or if the weight is not present on disk, // we store the newly minted pre-packed data. - AllocatorPtr session_initializer_alloc = GetInitializerAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + AllocatorPtr session_initializer_alloc = GetInitializerAllocator( + kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + // A plugin EP registered as a separate library may not have an initializer + // allocator registered under the kernel's device key, so the lookup above can + // return null. Fall back to the kernel's own default-memory allocator (resolved + // through the EP), which is always valid. This keeps PrePack implementations from + // each having to special-case a null allocator at the library boundary. + if (!session_initializer_alloc) { + session_initializer_alloc = kernel->Info().GetAllocator(OrtMemType::OrtMemTypeDefault); + } PrePackedWeights weights_to_be_filled_in; // The reason we invoke PrePack() before looking into the container for any pre-packed weight // cached by another instance of the same op_type (for the same constant initializer) is because @@ -615,11 +637,9 @@ Status SessionState::PrepackConstantInitializedTensors( is_packed, &weights_to_be_filled_in)); - // Some kernels (matmul_nbits and non-CPU related kernels) do not share their pre-packed results + // Some kernels (non-CPU related kernels) do not share their pre-packed results // even though they set is_packed = true so we leave it up to them. // We can change their behavior if we wish do so in a separate PR - // XXX: Interestingly enough, matmul_nbits does accept shared pre-packs, but does not - // produce them. if (is_packed && !weights_to_be_filled_in.buffers_.empty()) { const auto& op_type = node.OpType(); const std::string prepacked_weights_container_key = GenerateKeyForPrepackedWeightsMap( diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 80e7829a88e62..28477db5ec172 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -136,6 +136,20 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st deserialized_value, &prepacked_for_graph)); + const Tensor& cpu_staging_tensor = deserialized_value.Get(); + // Bool external initializers are copied verbatim and may carry bytes outside the canonical + // {0, 1} set. The CPU staging tensor above can be backed by a read-only mmap, so normalize into + // a writable CPU copy before copying to the device (see utils::NormalizeBoolTensorIfNeeded). + if (cpu_staging_tensor.IsDataType()) { + Tensor normalized_cpu_tensor; + ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(/* use_device_allocator_for_initializers =*/true, + tensor_shape, type, + default_cpu_alloc, normalized_cpu_tensor)); + utils::MakeCpuTensorCopy(cpu_staging_tensor, normalized_cpu_tensor); + utils::NormalizeBoolTensorIfNeeded(normalized_cpu_tensor); + return CopyTensorFromCPUToDevice(data_transfer_mgr, normalized_cpu_tensor, std::move(tensor), ort_value); + } + return CopyTensorFromCPUToDevice(data_transfer_mgr, deserialized_value.Get(), std::move(tensor), ort_value); } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 8685655420a38..7b8f78136be3d 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -725,6 +725,17 @@ void ConvertRawDataInTensorProto(TensorProto& tensor) { SwapByteOrderInplace(element_size, span); } +// Bool tensors must hold canonical {0, 1} byte values. Data sourced from raw_data or external +// files is copied verbatim and may contain other non-zero bytes; normalize any non-zero byte to 1 +// so every consumer observes a single, consistent value. Operate on the byte representation to +// avoid loading a bool object that does not yet hold a valid value. +static void NormalizeBoolBytes(uint8_t* bool_bytes, size_t num_elements) { + static_assert(sizeof(bool) == 1, "Normalization assumes 1 byte per bool element"); + for (size_t i = 0; i < num_elements; ++i) { + bool_bytes[i] = bool_bytes[i] != 0 ? 1 : 0; + } +} + #if !defined(ORT_MINIMAL_BUILD) static Status UnpackTensorWithExternalDataImpl(const ONNX_NAMESPACE::TensorProto& tensor, @@ -752,6 +763,19 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, reinterpret_cast(p_data)); } +// UnpackTensorWithExternalData +// External data is copied verbatim and may contain bytes outside the canonical {0, 1} set, so +// normalize them (see NormalizeBoolBytes). +template <> +Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, + const std::filesystem::path& tensor_proto_dir, size_t expected_num_elements, + /*out*/ bool* p_data) { + ORT_RETURN_IF_ERROR(UnpackTensorWithExternalDataImpl(tensor, tensor_proto_dir, expected_num_elements, sizeof(bool), + reinterpret_cast(p_data))); + NormalizeBoolBytes(reinterpret_cast(p_data), expected_num_elements); + return Status::OK(); +} + #define DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(FOUR_BIT_TYPE, CalcPairFun) \ template <> \ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ @@ -800,7 +824,9 @@ INSTANTIATE_UNPACK_EXTERNAL_TENSOR(int32_t) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(int64_t) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(uint64_t) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(uint32_t) -INSTANTIATE_UNPACK_EXTERNAL_TENSOR(bool) +// bool is intentionally omitted: UnpackTensorWithExternalData is explicitly specialized +// above (to normalize bytes to {0, 1}), so an explicit instantiation here would have no effect +// and triggers -Werror,-Winstantiation-after-specialization. INSTANTIATE_UNPACK_EXTERNAL_TENSOR(MLFloat16) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(BFloat16) @@ -907,7 +933,11 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d } if (raw_data != nullptr) { - return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); + ORT_RETURN_IF_ERROR(UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data)); + // raw_data is copied verbatim and may contain bytes outside the canonical {0, 1} set (see + // NormalizeBoolBytes). + NormalizeBoolBytes(reinterpret_cast(p_data), expected_size); + return Status::OK(); } if (static_cast(tensor.int32_data_size()) != expected_size) @@ -1881,6 +1911,9 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, ort_value)); const auto& ext_tensor = ort_value.Get(); MakeCpuTensorCopy(ext_tensor, tensor); + // MakeCpuTensorCopy memcpy's external bytes verbatim. Bool external initializers may carry + // bytes outside the canonical {0, 1} set, so normalize them here as well (see NormalizeBoolBytes). + NormalizeBoolTensorIfNeeded(tensor); return Status::OK(); } @@ -2191,6 +2224,13 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { } } +void NormalizeBoolTensorIfNeeded(Tensor& tensor) { + if (tensor.IsDataType()) { + NormalizeBoolBytes(reinterpret_cast(tensor.MutableDataRaw()), + narrow(tensor.Shape().Size())); + } +} + #if !defined(DISABLE_SPARSE_TENSORS) // Validates the external data declaration on a sub-tensor of a SparseTensorProto (values or diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 63f7b5e78e478..c07e9703ad384 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -283,6 +283,15 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n /// void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor); +/// +/// Normalizes the bytes of a CPU bool tensor to the canonical {0, 1} set (any non-zero byte -> 1). +/// Bool data sourced from raw_data or external files is copied verbatim and may contain other +/// non-zero bytes; normalizing ensures every consumer observes a single, consistent value. +/// No-op for non-bool tensors. The tensor must reside in writable CPU memory. +/// +/// The CPU tensor to normalize in place. +void NormalizeBoolTensorIfNeeded(Tensor& tensor); + #if !defined(DISABLE_SPARSE_TENSORS) /// // The function supports only COO format with 1D or 2D indices. Values shape is expected to be 1D. diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 811bad15ebbab..cbca2d85a97a4 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2297,6 +2297,63 @@ MlasFlashAttention( MLAS_THREADPOOL* ThreadPool ); +// +// Flash Attention for non-quantized (FP32) GroupQueryAttention KV cache. +// +// Adapts the online-softmax tiled algorithm to operate on an FP32 present +// K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). +// Supports GQA head grouping, causal masking, local window attention, +// additive attention bias, and an optional flash-decoding split over the KV +// sequence dimension for the single-token decode case. +// +struct MlasFlashAttentionGQAArgs { + int batch_size; + int num_heads; // number of query heads + int kv_num_heads; // number of key/value heads (num_heads % kv_num_heads == 0) + int sequence_length; // number of new query tokens (S) + int total_seqlen; // total tokens (past + new) for this invocation (T) + int head_size; // per-head size (H) + int past_seqlen; // causal offset (number of cached tokens before the new ones) + int local_window_size; // -1 disables local window masking + int seqlen_present_kv; // sequence dimension of the present K/V buffer + int q_block_size; // query tile size (Br) + int kv_block_size; // key/value tile size (Bc) + float scale; // QK scaling factor + int thread_count; // number of partitions / threads + float* buffer; // per-thread scratch (+ optional flash-decoding partials) + size_t buffer_size_per_thread; + + const float* query; // [batch, num_heads, sequence_length, head_size] BNSH + size_t q_batch_stride; // element stride between consecutive batches in `query` + // (num_heads*S*H for unpacked, (num_heads+2*kv_num_heads)*S*H for packed QKV) + const float* k_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + const float* v_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + float* output; // [batch, sequence_length, num_heads, head_size] BSNH + + const float* attention_bias; // [batch|1, num_heads|1, S, T] additive bias, or nullptr + int attention_bias_seqlen_stride; + bool attention_bias_broadcast_batch; + bool attention_bias_broadcast_head; + + // Flash decoding (sequence_length == 1): partition KV across threads. + // Set flash_decoding_partials != nullptr to enable; otherwise the standard + // per-(batch, head, q_block) partitioning is used. + float* flash_decoding_partials; + int kv_chunk_count; +}; + +/** + * @brief FP32 Flash Attention for GroupQueryAttention with an FP32 KV cache. + * @param args Arguments + * @param ThreadPool Thread pool + */ +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +); + /** * @brief Enumeration of supported GELU algorithm variants. * diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp new file mode 100644 index 0000000000000..0f1210ca1e3c5 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -0,0 +1,818 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + flashattn_gqa.cpp + +Abstract: + + Flash Attention kernel for the non-quantized (FP32) GroupQueryAttention + KV cache. + + Adapts the online-softmax tiled algorithm from flashattn.cpp to operate on + an FP32 present K/V cache laid out as BNSH + ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head + grouping (num_heads % kv_num_heads == 0), causal masking, local window + attention, additive attention bias, and an optional flash-decoding split + over the KV sequence dimension for single-token decode. + + For multi-token prefill (sequence_length > 1) QK^T and S*V use the + single-threaded SGEMM primitive MlasSgemmOperation. For single-token decode + (sequence_length == 1, including the flash-decoding KV split) the M == 1 + GEMVs use the local MlasGQADecodeQK / MlasGQADecodeSV helpers to avoid SGEMM + packing overhead. The outer parallelism is provided by MlasExecuteThreaded. + +--*/ + +#include +#include +#include +#include + +#include "mlasi.h" + +// +// Decode (sequence_length == 1) GEMV helpers. +// +// With a single query token the QK^T and S*V products degenerate into +// matrix-vector products. Computing them directly streams the K and V cache +// exactly once and avoids the SGEMM B-packing overhead that otherwise dominates +// the tiny M = 1 GEMMs. These helpers live in the baseline-ISA MLAS translation +// unit, so the inner loops are written with independent accumulator lanes and a +// map-style update so the compiler can vectorize them without -ffast-math +// (which would be required to reassociate a plain scalar float reduction). +// + +// QK^T GEMV: scores[t] = scale * dot(q[0..H), K[t*H .. t*H+H)) for t in [0, n_kv). +static void +MlasGQADecodeQK( + const float* q, + const float* k_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float scale, + float* scores +) +{ + constexpr std::ptrdiff_t kLanes = 8; + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float* krow = k_cache + t * head_size; + float acc[kLanes] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + std::ptrdiff_t h = 0; + for (; h + kLanes <= head_size; h += kLanes) { + for (std::ptrdiff_t j = 0; j < kLanes; ++j) { + acc[j] += q[h + j] * krow[h + j]; + } + } + float sum = ((acc[0] + acc[1]) + (acc[2] + acc[3])) + + ((acc[4] + acc[5]) + (acc[6] + acc[7])); + for (; h < head_size; ++h) { + sum += q[h] * krow[h]; + } + scores[t] = sum * scale; + } +} + +// S*V GEMV (accumulate): out[h] = sum_t probs[t] * V[t*H + h] for h in [0, head_size). +// `out` is overwritten (initialized to zero) before accumulation. +static void +MlasGQADecodeSV( + const float* probs, + const float* v_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float* out +) +{ + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] = 0.0f; + } + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float p = probs[t]; + const float* vrow = v_cache + t * head_size; + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] += p * vrow[h]; + } + } +} + +void +MlasFlashAttentionGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t q_block_size = static_cast(args->q_block_size); + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t sequence_length = static_cast(args->sequence_length); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: one per (batch, head, q_block) + const ptrdiff_t q_chunk_count = (sequence_length + q_block_size - 1) / q_block_size; + const ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + // Per-thread buffer layout: + // l[q_block_size] - running sum for online softmax + // m[q_block_size] - running max for online softmax + // scores[q_block_size * kv_block_size] - QK scores (S) + // temp_output[q_block_size * head_size] - accumulated output + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_ptr); + float* m = l + q_block_size; + float* scores = m + q_block_size; + float* temp_output = scores + q_block_size * kv_block_size; + + // Initialize running state + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + l[t] = 0.0f; + } + memset(temp_output, 0, static_cast(q_block_size * head_size) * sizeof(float)); + + const size_t row_size_q = static_cast(std::min(q_block_size, sequence_length - q_idx)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers. Layout: [batch, kv_num_heads, seqlen_present, head_size] + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, seq, head_size]. The batch stride is + // supplied separately (args->q_batch_stride) so the kernel works with both the + // standard BNSH layout and packed-QKV input where Q/K/V are interleaved per batch. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(sequence_length) * static_cast(head_size) + + static_cast(q_idx) * static_cast(head_size); + + // Causal early-termination bound: the largest global query position in this + // q_block is (past_seqlen + q_idx + row_size_q - 1), so it can attend to KV + // positions up to that index inclusive. Any KV block starting at or beyond + // (past_seqlen + q_idx + row_size_q) is fully causally masked for every row in + // the block, so it contributes nothing and can be skipped. This avoids the + // wasted QK/SV GEMMs over the causal upper triangle during prefill. + const ptrdiff_t kv_causal_limit = + past_seqlen + q_idx + static_cast(row_size_q); + + // Iterate over KV blocks + for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + if (ir >= kv_causal_limit) { + break; + } + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Step 1: QK^T GEMM with FP32 K block + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasTrans, + row_size_q, // M + row_size_kv, // N + static_cast(head_size), // K + scale, // alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (FP32 K block) + static_cast(head_size), // ldb + 0.0f, // beta + scores, // C (output scores) + row_size_kv // ldc + ); + + // Step 1b: Apply attention bias (additive) if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = + static_cast(sequence_length) * bias_seqlen_stride; + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch + // stride uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + // Add bias tile: bias[q_idx + irow, ir + jcol] + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + const float* bias_row = args->attention_bias + bias_offset + + (q_idx + irow) * bias_seqlen_stride + ir; + float* s_row = scores + irow * static_cast(row_size_kv); + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + s_row[jcol] += bias_row[jcol]; + } + } + } + + // Step 2: Apply causal mask and Step 3: Online softmax update + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float* p = scores + irow * static_cast(row_size_kv); + const ptrdiff_t global_q_pos = past_seqlen + q_idx + irow; + const ptrdiff_t causal_limit = global_q_pos + 1; // can attend to positions [0, causal_limit) + + // Apply causal masking + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + p[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + p[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Online softmax: find row max, update running max +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv); +#endif + + // If the entire row is masked (all scores are -inf), zero the scores + // so the S*V GEMM contributes nothing and skip the softmax state update. + if (rowmax == std::numeric_limits::lowest()) { + memset(p, 0, row_size_kv * sizeof(float)); + continue; + } + + float m_old = m[irow]; + m[irow] = std::max(m[irow], rowmax); + float m_diff = m_old - m[irow]; // <= 0 + + // Compute exp(score - m_new) for each element + float negmax = -m[irow]; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#endif + + // Rescale previous state + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + // Rescale accumulated output + float* out_row = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + out_row[icol] *= exp_diff; + } + } else { + l[irow] = rowsum; + } + } + + // Step 4: Accumulate O += S_exp * V_block + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasNoTrans, + row_size_q, // M + static_cast(head_size), // N + row_size_kv, // K + 1.0f, // alpha + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (FP32 V block) + static_cast(head_size), // ldb + ir == 0 ? 0.0f : 1.0f, // beta (accumulate after first block) + temp_output, // C (accumulated output) + static_cast(head_size) // ldc + ); + } + + // Final: normalize output by l (softmax denominator) + // Output layout: [batch, sequence_length, num_heads, head_size] + float* output_row = args->output + + (static_cast(batch_idx) * static_cast(sequence_length) + + static_cast(q_idx)) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + const ptrdiff_t output_row_stride = num_heads * head_size; + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float inv_l = (l[irow] > 0.0f) ? (1.0f / l[irow]) : 0.0f; + float* src = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + output_row[icol] = src[icol] * inv_l; + } + output_row += output_row_stride; + } + } +} + +// +// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). +// Each task computes attention for one KV chunk and stores (m, l, partial_output) +// into the partials buffer. +// +void +MlasFlashDecodingGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + // Partials layout per entry: [m, l, output[head_size]] + const ptrdiff_t partial_stride = 2 + head_size; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: (batch, head, kv_chunk) + const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) + ptrdiff_t tmp = task_index; + ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; + tmp /= kv_chunk_count; + ptrdiff_t head_idx = tmp % num_heads; + ptrdiff_t batch_idx = tmp / num_heads; + + // Per-thread scratch buffer: just scores[kv_block_size] + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + + // KV block range for this chunk + const ptrdiff_t ir = kv_chunk_idx * kv_block_size; + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1). + // The batch stride is supplied separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMV for this KV chunk (M = 1) + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasGQADecodeQK(q_ptr, k_block, static_cast(row_size_kv), head_size, scale, scores); + + // Step 1b: Apply attention bias if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch stride + // uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset + ir; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + scores[jcol] += bias_row[jcol]; + } + } + + // Step 2: Apply causal mask + const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 + const ptrdiff_t causal_limit = global_q_pos + 1; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Step 3: Compute local softmax statistics (m, l) and exp scores +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); +#endif + + // Pointer to this task's partial in the partials buffer + const ptrdiff_t partial_index = + (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; + float* partial = args->flash_decoding_partials + partial_index * partial_stride; + float* partial_m = partial; + float* partial_l = partial + 1; + float* partial_output = partial + 2; + + if (rowmax == std::numeric_limits::lowest()) { + // Entire chunk is masked: store sentinel + *partial_m = std::numeric_limits::lowest(); + *partial_l = 0.0f; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + *partial_m = rowmax; + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#endif + *partial_l = rowsum; + + // Step 4: S_exp * V_block -> partial_output (M = 1) + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasGQADecodeSV(scores, v_block, static_cast(row_size_kv), head_size, partial_output); + } +} + +// +// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. +// +void +MlasFlashDecodingGQAReduceThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + const ptrdiff_t thread_count = static_cast(args->thread_count); + const ptrdiff_t partial_stride = 2 + head_size; + + // Total reduction tasks: one per (batch, head) + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t head_idx = task_index % num_heads; + ptrdiff_t batch_idx = task_index / num_heads; + + // Pointer to this (batch, head)'s partials: kv_chunk_count entries + const float* partials_base = args->flash_decoding_partials + + task_index * kv_chunk_count * partial_stride; + + // Find global max across all chunks + float global_m = std::numeric_limits::lowest(); + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + float chunk_m = partials_base[c * partial_stride]; + global_m = std::max(global_m, chunk_m); + } + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + + // If all chunks are masked, output zeros + if (global_m == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + // Accumulate rescaled outputs and l values + float global_l = 0.0f; + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + const float* partial = partials_base + c * partial_stride; + float chunk_m = partial[0]; + float chunk_l = partial[1]; + const float* chunk_output = partial + 2; + + if (chunk_l <= 0.0f) { + continue; // masked chunk contributes nothing + } + + float rescale = std::exp(chunk_m - global_m); + global_l += rescale * chunk_l; + + // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). + // Rescale by exp(m_c - global_m) to align all chunks to the same max. + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] += rescale * chunk_output[i]; + } + } + + // output = sum_c(rescale_c * partial_output_c) / global_l + float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] *= inv_l; + } + } +} + +// +// Decode kernel for sequence_length == 1 without KV-split (batch * heads >= +// thread_count). Parallelizes over (batch, head); each task attends the single +// query token to the whole KV cache with a pair of GEMVs and a two-pass softmax. +// Decode needs no causal masking (the single new token is the most recent +// position and attends to every cached token); only optional local-window +// masking and additive bias are applied. +// +void +MlasGQADecodeGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // One task per (batch, head). + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + // Local-window low bound: decode can attend to KV positions [window_start, total_seqlen). + // causal_limit == past_seqlen + 1 == total_seqlen for the single new token. + const ptrdiff_t window_start = + (local_window_size >= 0 && total_seqlen > local_window_size) ? (total_seqlen - local_window_size) : 0; + + // Per-thread scratch: scores[total_seqlen] followed by temp_output[head_size]. + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + float* temp_output = scores + total_seqlen; + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + const ptrdiff_t head_idx = task_index % num_heads; + const ptrdiff_t batch_idx = task_index / num_heads; + + // KV head index for GQA head sharing. + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size]; batch stride supplied + // separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMV -> scores[0..T) + MlasGQADecodeQK(q_ptr, k_cache_head, total_seqlen, head_size, scale, scores); + + // Step 1b: additive attention bias (shape [batch|1, num_heads|1, S=1, T]). + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_matrix_size = + static_cast(args->attention_bias_seqlen_stride); // S == 1 + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset; + for (ptrdiff_t t = 0; t < total_seqlen; ++t) { + scores[t] += bias_row[t]; + } + } + + // Step 2: local-window masking (no causal mask needed for decode). + if (window_start > 0) { + for (ptrdiff_t t = 0; t < window_start; ++t) { + scores[t] = std::numeric_limits::lowest(); + } + } + + // Step 3: softmax over scores[0..T). +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, total_seqlen); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, total_seqlen); +#endif + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(head_size); + + if (rowmax == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#endif + + // Step 4: S_exp * V GEMV -> temp_output, then normalize by 1/l. + MlasGQADecodeSV(scores, v_cache_head, total_seqlen, head_size, temp_output); + + const float inv_l = (rowsum > 0.0f) ? (1.0f / rowsum) : 0.0f; + for (ptrdiff_t h = 0; h < head_size; ++h) { + output_ptr[h] = temp_output[h] * inv_l; + } + } +} + +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + if (args->sequence_length == 1) { + // Decode: M = 1, use the GEMV kernels (no SGEMM packing overhead). + if (args->flash_decoding_partials != nullptr) { + // Flash decoding: two-phase approach when KV is partitioned across threads. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingGQAReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + // Single-pass decode parallelized over (batch, head). + MlasExecuteThreaded( + MlasGQADecodeGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } + } else { + // Prefill (sequence_length > 1): tiled online-softmax SGEMM kernel. + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } +} + diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc index f3956d5e9e0f3..07fccef64fee1 100644 --- a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc @@ -12,6 +12,7 @@ #include "core/graph/graph_utils.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/matmul_nbits_sharing_identity.h" #include "core/optimizer/utils.h" #include @@ -447,7 +448,6 @@ std::vector CollectDirectDQMatches( return direct_matches; } -// --------------------------------------------------------------------------- // Pattern 1 rewriting: DQ+Reshape+Transpose+[Cast]+MatMul/Gemm -> MatMulNBits // --------------------------------------------------------------------------- @@ -569,6 +569,10 @@ void ApplyReshapeTransposeFusions( zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } + // Cross-session sharing identity for the generated weight group; computed before the tensors move. + const std::string share_id = + ComputeMatMulNBitsSharingId(weight_dst, scale_dst, zp_dst, N, K, block_size, /*bits*/ 4, accuracy_level); + NodeAttributes mnb_attrs; utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); @@ -578,7 +582,10 @@ void ApplyReshapeTransposeFusions( std::vector mnb_inputs; mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); - mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + NodeArg& b_weight_arg = graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst)); + // Tag the generated B weight for cross-session pre-pack sharing. + graph.SetSharedPrepackInitializerId(b_weight_arg.Name(), share_id); + mnb_inputs.push_back(&b_weight_arg); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); @@ -749,6 +756,10 @@ void ApplyDirectDQFusions( zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } + // Cross-session sharing identity for the generated weight group; computed before the tensors move. + const std::string share_id = + ComputeMatMulNBitsSharingId(weight_dst, scale_dst, zp_dst, N, K, block_size, /*bits*/ 4, accuracy_level); + NodeAttributes mnb_attrs; utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); @@ -758,7 +769,10 @@ void ApplyDirectDQFusions( std::vector mnb_inputs; mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); - mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + NodeArg& b_weight_arg = graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst)); + // Tag the generated B weight for cross-session pre-pack sharing. + graph.SetSharedPrepackInitializerId(b_weight_arg.Name(), share_id); + mnb_inputs.push_back(&b_weight_arg); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); diff --git a/onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h b/onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h new file mode 100644 index 0000000000000..829a78d3ebcf1 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/framework/murmurhash3.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { + +// Stable, content-derived identity for a fusion-generated MatMulNBits weight group, used to share its +// pre-packed buffer across sessions. The id is identical for the same model in any session and differs +// whenever a semantic input differs. accuracy_level is hashed so buffers packed for different compute +// types never collide. Pass zero_point only when it is an actual kernel input. +inline std::string ComputeMatMulNBitsSharingId(const Tensor& weight, const Tensor& scale, + const std::optional& zero_point, + int64_t N, int64_t K, int64_t block_size, + int64_t bits, int64_t accuracy_level) { + // MurmurHash3 fmix64 finalizer: a bijection that avalanches a 64-bit value so each input bit affects + // every output bit. + auto fmix64 = [](uint64_t x) { + x ^= x >> 33; + x *= 0xff51afd7ed558ccdULL; + x ^= x >> 33; + x *= 0xc4ceb9fe1a85ec53ULL; + x ^= x >> 33; + return x; + }; + // Fold each segment's full 128-bit hash into the 64-bit accumulator and carry the whole accumulator + // forward, not just a 32-bit seed. Every bit of weight/scale/zero_point/params therefore reaches the + // id, so collision resistance tracks the 64-bit id width instead of the ~2^32 a chain forwarding only + // hash[0] would give. A collision would let one weight group adopt another's already-packed buffer and + // silently compute a wrong result, so the wider margin is worth the few extra mixing ops. + uint64_t acc = 0; + auto mix = [&acc, &fmix64](const void* data, size_t len) { + uint32_t h[4]; + MurmurHash3::x86_128(data, len, static_cast(acc), h); + acc = fmix64(acc ^ ((static_cast(h[1]) << 32) | h[0])); + acc = fmix64(acc ^ ((static_cast(h[3]) << 32) | h[2])); + }; + mix(weight.DataRaw(), weight.SizeInBytes()); + mix(scale.DataRaw(), scale.SizeInBytes()); + if (zero_point) { + mix(zero_point->DataRaw(), zero_point->SizeInBytes()); + } + const int64_t params[] = {N, K, block_size, bits, accuracy_level}; + mix(params, sizeof(params)); + return "MatMulNBits.DQ:" + std::to_string(acc); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index b9d7e898157bd..6bd5e157d8b65 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -7,6 +7,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/matmul_nbits_sharing_identity.h" #include "core/graph/node_attr_utils.h" #include "core/graph/graph_utils.h" #include "core/framework/tensorprotoutils.h" @@ -646,8 +647,23 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, effective_bs, transposed)); + // Cross-session sharing identity for the generated B weight; computed before it is moved. + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* weight_shape = weight_arg->Shape(); + ORT_RETURN_IF_NOT(weight_shape != nullptr && weight_shape->dim_size() >= 2, + "Weight shape unavailable for DQ node ", dq_node->Name()); + const int64_t bits = DQWeightBits(weight_arg->TypeAsProto()->tensor_type().elem_type()); + const std::string share_id = ComputeMatMulNBitsSharingId( + transposed.weight, transposed.scale, transposed.zero_point, + weight_shape->dim(1).dim_value(), weight_shape->dim(0).dim_value(), + effective_bs, bits, accuracy_level_); + auto& input_defs = replacement_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); + NodeArg& b_weight_arg = + graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight)); + // Tag the generated B weight for cross-session pre-pack sharing. + graph.SetSharedPrepackInitializerId(b_weight_arg.Name(), share_id); + input_defs.push_back(&b_weight_arg); replacement_node.MutableInputArgsCount().push_back(1); input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale))); diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 0270bf9d4d79c..c34d8b3dbf696 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -657,6 +657,11 @@ class PosixEnv : public Env { } } } + ~PosixEnv() { + if (cpuinfo_available_) { + cpuinfo_deinitialize(); + } + } bool cpuinfo_available_{false}; #endif // ORT_USE_CPUINFO }; diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 1b0f37feab9fb..150fe7b478c1e 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -347,7 +347,15 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, T* mask_data = nullptr; bool delete_mask_data = false; - bool causal = parameters.is_causal && parameters.q_sequence_length > 1; + // In the nonpad_kv_seqlen path, q_len=1 is external KV-cache decode with + // bottom-right alignment. The single query's causal frontier is the valid + // length, so nonpad masking alone leaves exactly all valid keys visible. + // Keep causal=false for that case to avoid applying the batch-shared + // upper-left overlay used by the no-nonpad path. + bool causal = parameters.is_causal && + (parameters.has_nonpad_kv_seqlen + ? parameters.q_sequence_length > 1 + : !(parameters.q_sequence_length == 1 && parameters.past_sequence_length > 0)); // When nonpad_kv_seqlen is present the causal frontier is offset-aware // (bottom-right) and per-batch, so it cannot be baked into the batch-shared mask // buffer here; it is applied per-batch in the main loop below. Skip the top-left diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index edafbfb3ede65..dc53b02141207 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1361,9 +1361,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // cross-attention case; MEA handles it via the causal_from_top_left flag and Unified // Unfused uses past_kv_length=0. (When an external cache is present — nonpad_kv_seqlen — // the required frontier IS bottom-right, so Flash is eligible; see below.) - const bool causal_cross_no_past = parameters.is_causal && - parameters.q_sequence_length != parameters.total_sequence_length && - parameters.past_sequence_length == 0; + [[maybe_unused]] const bool causal_cross_no_past = + parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; // is_causal=1 + nonpad_kv_seqlen (external KV cache) without past_key defines a // bottom-right causal frontier per onnx/onnx#8068: query in-block index i attends key j diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index d611366c8ad7e..73fa92d19cd1f 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -83,20 +83,33 @@ void DestroyCudaStreamForDevice(cudaStream_t stream, int device_id) { } // namespace struct CudaEp::PerThreadContext { - explicit PerThreadContext(int device_id) + // When use_external_stream is true (user_compute_stream combined with CUDA graph), capture and + // replay happen on that user-owned stream so they see the same stream as the kernels; the + // context neither creates nor destroys it. Ownership is derived from the caller's intent rather + // than from external_stream being non-null, because a user may legitimately select the CUDA + // default stream (cudaStream_t(0), i.e. nullptr) as the compute stream — that is still an + // external, user-owned stream and must not be destroyed by the context. When use_external_stream + // is false the context creates and owns a dedicated graph stream. + explicit PerThreadContext(int device_id, bool use_external_stream = false, + cudaStream_t external_stream = nullptr) : device_id(device_id), - graph_stream(CreateCudaStreamForDevice(device_id)), + owns_graph_stream(!use_external_stream), + graph_stream(use_external_stream ? external_stream + : CreateCudaStreamForDevice(device_id)), cuda_graph(graph_stream) { } ~PerThreadContext() { // Destroy captured graph execs before destroying the stream they replay on. cuda_graph.Reset(); - DestroyCudaStreamForDevice(graph_stream, device_id); + if (owns_graph_stream) { + DestroyCudaStreamForDevice(graph_stream, device_id); + } graph_stream = nullptr; } int device_id; + bool owns_graph_stream; cudaStream_t graph_stream = nullptr; CudaGraphManager cuda_graph; size_t pre_capture_free_mem = 0; @@ -391,8 +404,14 @@ OrtStatus* ORT_API_CALL CudaEp::CreateSyncStreamForDeviceImpl( auto cuda_stream = std::make_unique(ep->factory_, device_id, this_ptr); - if (ep->config_.has_user_compute_stream && ep->config_.user_compute_stream != nullptr) { - // Wrap the user-provided external CUDA stream with full cuBLAS/cuDNN handles. + if (ep->config_.has_user_compute_stream) { + // A user-provided compute stream is honored for kernels regardless of whether CUDA graph + // capture is enabled - this branch is taken in both graph and non-graph runs. Use the caller's + // intent flag rather than checking the handle for non-null: cudaStream_t(0) / nullptr is the + // valid CUDA default stream and can be selected explicitly by the user. Wrap the external CUDA + // stream with full cuBLAS/cuDNN handles. When CUDA graph capture is also enabled, + // capture/replay run on this same user stream (see GetPerThreadContext), so kernels and graph + // capture share one stream. RETURN_IF_ERROR(cuda_stream->InitHandlesWithUserStream( static_cast(ep->config_.user_compute_stream))); } else if (ep->config_.enable_cuda_graph) { @@ -439,15 +458,18 @@ OrtStatus* ORT_API_CALL CudaEp::SyncImpl(OrtEp* this_ptr) noexcept { /*static*/ OrtStatus* ORT_API_CALL CudaEp::IsConcurrentRunSupportedImpl( OrtEp* this_ptr, bool* is_supported) noexcept { + ORT_UNUSED_PARAMETER(this_ptr); + if (is_supported == nullptr) { return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "is_supported must not be null."); } auto* ep = static_cast(this_ptr); - // When a unified stream is in use (either from user_compute_stream, external - // allocator, or explicit use_ep_level_unified_stream), all operations share a - // single stream so concurrent runs are not safe. - *is_supported = !ep->config_.use_ep_level_unified_stream; + // Concurrent runs require stream-tagged scratch allocations. The plugin kernel adapter can tag + // scratch chunks only when the hosting ORT runtime exposes KernelContext_GetSyncStream. + static constexpr uint32_t kOrtKernelContextGetSyncStreamMinVersion = 28; + *is_supported = !ep->config_.use_ep_level_unified_stream && + ::onnxruntime::ep::CurrentOrtApiVersion() >= kOrtKernelContextGetSyncStreamMinVersion; return nullptr; } @@ -467,7 +489,25 @@ CudaEp::PerThreadContext& CudaEp::GetPerThreadContext() const { return *cached_context_it->second; } - auto context = std::make_shared(config_.device_id); + // NOTE: `enable_cuda_graph` in this condition does NOT restrict using a user compute stream to + // the graph case. A user compute stream is honored for kernels in BOTH graph and non-graph runs + // — that happens in CreateSyncStreamForDeviceImpl(), which wraps config_.user_compute_stream + // independently of enable_cuda_graph. This flag only governs the PerThreadContext's *graph + // stream*, and PerThreadContext is a graph-capture-only object: GetPerThreadContext() is reached + // exclusively from the graph path (CreateSyncStreamForDeviceImpl's enable_cuda_graph branch, + // OnRunStart/OnRunEnd, IsGraphCaptured, ReplayGraph). With graph disabled, no PerThreadContext is + // ever constructed, so its stream ownership is irrelevant. + // + // When a user compute stream IS combined with CUDA graph capture, capture/replay must run on the + // user's stream (the same stream the kernels are issued to) rather than a separate EP-owned + // stream. The user owns the stream's lifetime, so the context must not destroy it. Derive this + // from the caller's intent (has_user_compute_stream && enable_cuda_graph), not from whether the + // handle is null: a user may explicitly choose the CUDA default stream (nullptr), which is still + // an external stream that the context must not own/destroy. + const bool use_external_stream = config_.has_user_compute_stream && config_.enable_cuda_graph; + cudaStream_t external_stream = + use_external_stream ? static_cast(config_.user_compute_stream) : nullptr; + auto context = std::make_shared(config_.device_id, use_external_stream, external_stream); PerThreadContext& context_ref = *context; { std::lock_guard lock(per_thread_contexts_mutex_); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index cb021034662e8..d445d8bab033c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -609,12 +609,9 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( "CUDA plugin EP does not support using both user_compute_stream and external allocator simultaneously."); } - // Validate: user_compute_stream and cuda graph cannot both be active. - if (config.has_user_compute_stream && config.enable_cuda_graph) { - return factory->ort_api_.CreateStatus( - ORT_INVALID_ARGUMENT, - "CUDA plugin EP does not support using both user_compute_stream and enable_cuda_graph simultaneously."); - } + // user_compute_stream and enable_cuda_graph CAN be combined: when both are set, CUDA graph + // capture/replay runs on the user-provided stream (the same stream kernels are issued to), + // matching the bundled CUDA EP behavior. See CudaEp::GetPerThreadContext. // When user_compute_stream is set, force unified stream mode (matches bundled EP behavior). if (config.has_user_compute_stream) { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 021acaf142435..06fe635e35716 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -16,6 +16,7 @@ #pragma once #include +#include #include "core/common/status.h" #include "core/common/narrow.h" @@ -53,6 +54,81 @@ namespace onnxruntime { struct CudaStream; +namespace cuda_plugin { +namespace detail { +inline thread_local std::unordered_map stream_to_framework_stream; +inline thread_local void* current_cuda_stream = nullptr; +inline thread_local onnxruntime::Stream* current_framework_stream = nullptr; + +inline void RegisterFrameworkStreamForCudaStream(void* cuda_stream, OrtSyncStream* framework_stream) { + current_cuda_stream = cuda_stream; + current_framework_stream = reinterpret_cast(framework_stream); + + if (current_framework_stream == nullptr) { + return; + } + + // Map only from the raw cudaStream_t handle to the current framework stream. The framework + // stream is already handled directly by GetFrameworkStreamForStreamArg, so we deliberately do + // not insert a framework_stream -> framework_stream entry: it would be unused and would grow the + // thread-local map without bound while retaining framework stream pointers past the + // Session::Run() teardown lifetime documented for KernelContext_GetSyncStream. + if (cuda_stream != nullptr) { + stream_to_framework_stream[cuda_stream] = current_framework_stream; + } +} + +inline onnxruntime::Stream* GetFrameworkStreamForStreamArg(void* stream) { + // A null stream argument means "the compute stream of the current Compute call". This is the + // form used by GetTransientScratchBuffer and legacy GetScratchBuffer(..., nullptr). Map it to + // the framework stream registered for this call so scratch chunks are still stream-tagged even + // when the kernel runs on a non-default CUDA stream (where current_cuda_stream is non-null and a + // nullptr arg would otherwise miss the map lookup and fall back to a null stream tag). + // + // current_framework_stream is scoped to a single CudaKernel::Compute invocation by + // ComputeStreamScope (see below). Outside any Compute call it is nullptr, so allocations made + // from kernel constructors (which also call GetScratchBuffer(..., nullptr)) fall back to the + // non-stream-tagged path instead of inheriting a stale framework stream pointer whose lifetime + // ended with a previous Session::Run(). + if (stream == nullptr || stream == current_cuda_stream || stream == current_framework_stream) { + return current_framework_stream; + } + + auto it = stream_to_framework_stream.find(stream); + return it == stream_to_framework_stream.end() ? nullptr : it->second; +} + +// RAII guard that scopes the thread-local "current Compute call" framework stream to the lifetime +// of a single CudaKernel::Compute invocation on a worker thread. +// +// On entry it clears current_cuda_stream/current_framework_stream so that scratch allocated before +// the kernel registers its stream (via Stream(ctx)/GetComputeStream(ctx)/GetOrtStream(ctx)), or via +// a nullptr stream argument, does not inherit a stale framework stream left over from a previous +// Compute call on this worker thread. On exit it restores the previous values, which keeps nested +// Compute calls (a kernel that invokes another kernel's Compute) correct and leaves the per-thread +// "current" stream cleared once the outermost Compute returns. The borrowed framework stream is +// only valid until its owning Session::Run() completes teardown, so it must not outlive the call. +struct ComputeStreamScope { + ComputeStreamScope() + : saved_cuda_stream_(current_cuda_stream), + saved_framework_stream_(current_framework_stream) { + current_cuda_stream = nullptr; + current_framework_stream = nullptr; + } + ~ComputeStreamScope() { + current_cuda_stream = saved_cuda_stream_; + current_framework_stream = saved_framework_stream_; + } + + private: + void* saved_cuda_stream_; + onnxruntime::Stream* saved_framework_stream_; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ComputeStreamScope); +}; +} // namespace detail +} // namespace cuda_plugin + // Lightweight Stream shim for plugin build: wraps a raw cudaStream_t as a // framework-compatible Stream* that can be passed to _impl.cu functions which // call stream->GetHandle(). Stack-allocated; does NOT own the stream. @@ -70,6 +146,11 @@ class OrtStreamAdapter { explicit OrtStreamAdapter(void* cuda_stream_handle) : plugin_stream_shim_(cuda_stream_handle), stream_(&plugin_stream_shim_) {} + OrtStreamAdapter(void* cuda_stream_handle, OrtSyncStream* framework_stream) + : plugin_stream_shim_(cuda_stream_handle), + stream_(framework_stream == nullptr ? static_cast(&plugin_stream_shim_) + : reinterpret_cast(framework_stream)) {} + onnxruntime::Stream* get() const { return stream_; } operator onnxruntime::Stream*() const { return stream_; } @@ -83,6 +164,10 @@ class OrtStreamAdapter { explicit OrtStreamAdapter(void* cuda_stream_handle) : stream_(static_cast(cuda_stream_handle)) {} + OrtStreamAdapter(void* cuda_stream_handle, OrtSyncStream* framework_stream) + : stream_(framework_stream == nullptr ? static_cast(cuda_stream_handle) + : reinterpret_cast(framework_stream)) {} + onnxruntime::Stream* get() const { return stream_; } operator onnxruntime::Stream*() const { return stream_; } @@ -868,6 +953,11 @@ class CudaKernel : public OpKernel { } virtual ~CudaKernel() = default; Status Compute(OpKernelContext* ctx) const { + // Scope the thread-local "current Compute call" framework stream to this invocation so that + // scratch tagged via a nullptr stream argument never inherits a stale framework stream from a + // previous Compute call (or leaks one to a later kernel constructor) on this worker thread. + cuda_plugin::detail::ComputeStreamScope compute_stream_scope; + // Ensure the correct CUDA device is active for this kernel. // Worker threads default to device 0; sessions on device > 0 need an // explicit cudaSetDevice. Skip during CUDA graph capture because @@ -903,17 +993,27 @@ class CudaKernel : public OpKernel { cudaStream_t Stream(OpKernelContext* ctx) const { if (!ctx) return nullptr; - return static_cast(ctx->GetGPUComputeStream()); + // Register the framework sync stream for this Compute call so that scratch allocated via + // GetTransientScratchBuffer()/GetScratchBuffer(..., nullptr) is still stream-tagged for kernels + // that call Stream(ctx) before GetComputeStream()/GetOrtStream() (e.g. conv algo search). + void* cuda_stream = ctx->GetGPUComputeStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()); + return static_cast(cuda_stream); } // Returns an opaque stream pointer for passing to GetScratchBuffer/AddDeferredReleaseCPUPtr/CopyToGpu. // Returns void* for dual-build compatibility: framework wraps Stream*, plugin wraps cudaStream_t. inline void* GetComputeStream(OpKernelContext* ctx) const { - return ctx->GetGPUComputeStream(); + void* cuda_stream = ctx->GetGPUComputeStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()); + return cuda_stream; } inline onnxruntime::OrtStreamAdapter GetOrtStream(OpKernelContext* ctx) const { - return onnxruntime::OrtStreamAdapter(GetComputeStream(ctx)); + void* cuda_stream = ctx->GetGPUComputeStream(); + OrtSyncStream* framework_stream = ctx->GetSyncStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, framework_stream); + return onnxruntime::OrtStreamAdapter(cuda_stream, framework_stream); } static cudnnHandle_t GetCudnnHandle(cudaStream_t s) { @@ -1023,56 +1123,42 @@ class CudaKernel : public OpKernel { template using IAllocatorUniquePtr = std::unique_ptr>; template - inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* s) const { + inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* stream) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); - size_t sz = 0; - if (!detail::TryBytesForCount(cnt, detail::SizeOf::value, sz)) { - ORT_THROW("CUDA scratch buffer allocation size overflow for ", cnt, " elements"); - } - void* p = nullptr; - cudaError_t alloc_result = cudaSuccess; - bool used_async_alloc = false; - if (s) { - // Note: stream-ordered allocations (cudaMallocAsync/cudaFreeAsync) rely on CUDA Memory Pools, - // which are not supported on NVIDIA GPUs with Multi-Instance GPU (MIG) enabled. - // On such instances, this will return cudaErrorNotSupported. - alloc_result = cudaMallocAsync(&p, sz, static_cast(s)); - used_async_alloc = (alloc_result == cudaSuccess); - if (!used_async_alloc && (alloc_result == cudaErrorNotSupported || alloc_result == cudaErrorInvalidValue)) { - cudaGetLastError(); // Clear the thread-local error state - alloc_result = cudaMalloc(&p, sz); - } - } else { - alloc_result = cudaMalloc(&p, sz); - } - - if (alloc_result != cudaSuccess) { - ORT_THROW("CUDA scratch buffer allocation failed for ", sz, " bytes: ", cudaGetErrorString(alloc_result)); - } - - return IAllocatorUniquePtr(static_cast(p), [s, used_async_alloc](T* ptr) { - if (ptr) { - // Guard: only attempt async free if the stream is still registered. - // CudaSyncStream::~CudaSyncStream guarantees UnregisterStream() is - // called before cudaStreamDestroy(), so a non-null lookup here means - // the raw cudaStream_t handle is still valid. - if (used_async_alloc && s && - cuda_plugin::CudaSyncStream::FromCudaStream(static_cast(s)) != nullptr) { - // As noted above, cudaFreeAsync may also return cudaErrorNotSupported on MIG-enabled instances. - cudaError_t free_result = cudaFreeAsync(ptr, static_cast(s)); - if (free_result == cudaSuccess) { - return; - } - cudaGetLastError(); // Clear any error set by cudaFreeAsync - } - - // Fall back to synchronous free if async free is unsupported or if the - // stream is no longer registered. cudaFree is valid for allocations - // returned by cudaMallocAsync and avoids using a stale stream handle. - cudaFree(ptr); - } - }); + // Route kernel scratch/workspace allocations through the EP allocator + // (a BFC arena by default) instead of raw cudaMallocAsync/cudaMalloc. + // + // The arena pre-reserves device memory and reuses freed chunks across runs. + // Once the model has executed `min_num_runs_before_cuda_graph_capture` + // warmup runs, the arena has grown to its steady-state working set, so the + // capture run serves every scratch allocation from an already-reserved chunk + // without issuing a fresh cudaMalloc. This keeps the device free-memory + // footprint stable across the capture window, which is required for correct + // CUDA graph capture/replay. + // + // The previous behavior (cudaMallocAsync/cudaMalloc allocated-and-freed per + // call) allocated new device memory on every run, including the capture run, + // so no amount of warmup could stabilize it and the + // "GPU memory was allocated during CUDA graph capture" detector would trip. + // This now matches the built-in (non-plugin) CUDA EP, which also obtains + // scratch from Info().GetAllocator() (see core/providers/cuda/cuda_kernel.h). + // The overflow check that the previous hand-rolled path performed is still + // enforced inside MakeUniquePtr via ValidatedCalcMemSizeForArray (it throws + // on cnt * sizeof(T) overflow). + // + // The `stream` argument is the raw cudaStream_t used by migrated CUDA kernels, or a Stream* + // from OrtStreamAdapter in code paths that need stream->GetHandle(). Stream-aware arena + // allocation needs the stable framework Stream* wrapper instead, because the arena stores it + // in each chunk and later queries sync ids through the EP stream API. Stream(ctx), + // GetComputeStream(ctx) and GetOrtStream(ctx) record the mapping from both argument forms to + // the framework stream for the current Compute call. + // If the negotiated ORT API version does not include KernelContext_GetSyncStream, the lookup + // returns null and allocation falls back to the non-stream-tagged path. + auto* framework_stream = cuda_plugin::detail::GetFrameworkStreamForStreamArg(stream); + return ::onnxruntime::IAllocator::MakeUniquePtr( + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), cnt, /*use_reserve*/ false, + framework_stream); } template inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t cnt) const { diff --git a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu index b06a640fb72a1..9abbb93d228f4 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu @@ -22,7 +22,12 @@ namespace cuda { // see https://github.com/NVIDIA/cub/issues/384 struct CastToInt32 { __host__ __device__ int32_t operator()(int8_t v) const { - return static_cast(v); + // Normalize to {0, 1} so the prefix-sum sizing path agrees with the truthiness predicate + // (condition_data[div]) used in _CompressKernel. A bool byte may hold any non-zero value; + // sign-extending it here would size the output differently from how elements are selected. + // bool initializers are normalized to {0, 1} when unpacked (see tensorprotoutils.cc), so the + // remaining source of non-canonical bytes is runtime-produced bool condition tensors. + return v != 0 ? 1 : 0; } }; diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index 230b0b495bfbf..0202789a8777d 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -2,12 +2,87 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/tensor/shape_op.h" +#endif #include "core/providers/cuda/cuda_fwd.h" +#ifdef BUILD_CUDA_EP_AS_PLUGIN +#include +#endif + namespace onnxruntime { namespace cuda { +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// The bundled CUDA EP registers the CPU `onnxruntime::Shape` kernel (which +// derives from the framework `OpKernel`) and only marks its output as CPU +// memory. That class cannot be registered through the plugin EP's adapter +// kernel machinery, so the plugin build provides an adapter-based Shape kernel +// with identical semantics. Shape only reads the input's shape metadata (never +// its data) and writes the dims to a CPU output, so registering it on the CUDA +// EP keeps the node inside the device partition and avoids the device->host +// Memcpy node that the framework would otherwise insert to feed an isolated CPU +// Shape node -- a Memcpy that would prevent CUDA Graph capture. The output is +// still CPU memory, so a downstream device consumer may still need a copy; this +// removes the graph-breaking input-side Memcpy, it does not eliminate all copies. +class Shape final : public CudaKernel { + public: + explicit Shape(const OpKernelInfo& info) : CudaKernel(info) { + info.GetAttrOrDefault("start", &start_index_, 0); + + if (start_index_ != 0) { + // "start" is provided and is non-default (default is 0) + needs_slicing_ = true; + } + + if (info.GetAttr("end", &end_index_).IsOK()) { + needs_slicing_ = true; + } + } + + // Takes a tensor as input and outputs a 1D int64 tensor (on CPU memory) + // containing the shape of the input tensor. + Status ComputeInternal(OpKernelContext* context) const override { + const auto* input = context->Input(0); + const TensorShape& input_shape = input->Shape(); + + int64_t rank = static_cast(input_shape.NumDimensions()); + + if (!needs_slicing_) { // vanilla use of Shape (no slicing) + Tensor* output = context->Output(0, {rank}); + input_shape.CopyDims(output->MutableData(), static_cast(rank)); + } else { // slicing is needed + int64_t true_start = start_index_; + int64_t true_end = end_index_; + + // Deal with negative(s) and clamp + true_start = true_start < 0 ? true_start + rank : true_start; + true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start); + + true_end = true_end < 0 ? true_end + rank : true_end; + true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end); + + auto slice_length = true_end - true_start; + Tensor* output = context->Output(0, {slice_length < 0 ? 0 : slice_length}); + + if (slice_length > 0) { + input_shape.CopyDims(output->MutableData(), + static_cast(true_start), + static_cast(slice_length)); + } + } + + return Status::OK(); + } + + private: + bool needs_slicing_ = false; + int64_t start_index_ = 0; + int64_t end_index_ = std::numeric_limits::max(); +}; +#endif // BUILD_CUDA_EP_AS_PLUGIN + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Shape, kOnnxDomain, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5758ff3ad2847..f586fc8e117a6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1871,13 +1871,13 @@ Status QnnBackendManager::ExtractBackendProfilingInfo(qnn::profile::ProfilingInf // ETW disabled previously, but enabled now if (ProfilingLevel::INVALID == profiling_level_etw_ && tracelogging_provider_ep_enabled) { - LOGS(*logger_, ERROR) << "ETW disabled previously, but enabled now. Can't do the switch! Won't output any profiling."; + LOGS(*logger_, WARNING) << "ETW disabled previously, but enabled now. Can't do the switch! Won't output any profiling."; return Status::OK(); } // ETW enabled previously, but disabled now if (ProfilingLevel::INVALID != profiling_level_etw_ && !tracelogging_provider_ep_enabled) { - LOGS(*logger_, ERROR) << "ETW enabled previously, but disabled now. Can't do the switch! Won't output any profiling."; + LOGS(*logger_, WARNING) << "ETW enabled previously, but disabled now. Can't do the switch! Won't output any profiling."; return Status::OK(); } diff --git a/onnxruntime/core/providers/rknpu/onnx_converter.cc b/onnxruntime/core/providers/rknpu/onnx_converter.cc index 30de1af8fa7bd..bab76de49519e 100644 --- a/onnxruntime/core/providers/rknpu/onnx_converter.cc +++ b/onnxruntime/core/providers/rknpu/onnx_converter.cc @@ -1,6 +1,7 @@ // Copyright 2020 rock-chips.com Inc. #include +#include #include #include #include @@ -10,7 +11,9 @@ #include #include #include +#include "core/common/common.h" #include "core/common/logging/logging.h" +#include "core/common/safeint.h" #include "onnx_converter.h" #include "node_attr_helper.h" @@ -119,12 +122,28 @@ OnnxConverter::CreateRknnTensor(const std::string& name, return graph_->CreateTensor(attr, (void*)data); } +static uint32_t ToRknpuDim(int64_t dim, const std::string& name) { + ORT_ENFORCE(dim >= 0 && dim <= static_cast(std::numeric_limits::max()), + "RKNPU: tensor dimension out of uint32_t range (name=", name, ", dim=", dim, ")"); + + return static_cast(dim); +} + +// Allocates a zero-initialized bias buffer for `count` elements of `element_size` +// bytes, used when a Conv/Gemm node omits its bias input. SafeInt provides +// overflow-checked size arithmetic (throws on size_t overflow); std::make_unique +// zero-initializes and owns the buffer. +static std::unique_ptr AllocZeroedBias(size_t element_size, uint32_t count) { + const size_t num_bytes = SafeInt(element_size) * count; + return std::make_unique(num_bytes); +} + void OnnxConverter::HandleInitializer() { for (const auto& tensor : model_proto_.graph().initializer()) { const std::string name = tensor.name(); std::vector dims; for (const auto dim : tensor.dims()) { - dims.push_back(static_cast(dim)); + dims.push_back(ToRknpuDim(dim, name)); } if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { const char* ptr = tensor.float_data().empty() @@ -186,7 +205,7 @@ std::vector> OnnxConverter::GetInputOfOnnxModel( for (const auto& dim : input.type().tensor_type().shape().dim()) { if (dim.value_case() == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue) { - shape.push_back(static_cast(dim.dim_value())); + shape.push_back(ToRknpuDim(dim.dim_value(), input.name())); } else { throw std::invalid_argument( "The input of graph doesn't have dim_value"); @@ -267,7 +286,7 @@ Shaper::Shape GetShape(const ONNX_NAMESPACE::ModelProto& model_proto, for (const auto& dim : value_info.type().tensor_type().shape().dim()) { if (dim.has_dim_value()) { - shape.push_back(dim.dim_value()); + shape.push_back(ToRknpuDim(dim.dim_value(), value_info.name())); } else { break; } @@ -548,7 +567,7 @@ std::vector> OnnxConverter::GetSupportedNodes( const std::string name = tensor.name(); std::vector dims; for (const auto dim : tensor.dims()) { - dims.push_back(static_cast(dim)); + dims.push_back(ToRknpuDim(dim, name)); } tensor_dims_[name] = dims; } @@ -814,9 +833,6 @@ void OnnxConverter::Clear() { rk_tensors_.clear(); shaper_.Clear(); - for (const auto p : free_list_) { - if (p) free(p); - } free_list_.clear(); } @@ -944,9 +960,8 @@ void OnnxConverter::AddLayerConvImpl(const std::string& input, } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(float) * dim); - memset(ptr, 0, sizeof(float) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(float), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST); @@ -1053,9 +1068,8 @@ void OnnxConverter::AddLayerQLinearConvImpl(const string& input, } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(int32_t) * dim); - memset(ptr, 0, sizeof(int32_t) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(int32_t), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST, @@ -1142,9 +1156,8 @@ void OnnxConverter::AddLayerDepthwiseConvImpl( } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(float) * dim); - memset(ptr, 0, sizeof(float) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(float), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST); @@ -1376,9 +1389,8 @@ void OnnxConverter::AddLayerFC(const std::string& input, } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(float) * dim); - memset(ptr, 0, sizeof(float) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(float), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST); diff --git a/onnxruntime/core/providers/rknpu/onnx_converter.h b/onnxruntime/core/providers/rknpu/onnx_converter.h index 10cc09a9dba92..41d50d6b9401f 100644 --- a/onnxruntime/core/providers/rknpu/onnx_converter.h +++ b/onnxruntime/core/providers/rknpu/onnx_converter.h @@ -63,7 +63,7 @@ class OnnxConverter { // for GetSupportedNodes std::map> tensor_dims_; - std::vector free_list_; // remember free + std::vector> free_list_; // owns implicit-bias buffers std::pair, FuseCode> FindActivation(const ONNX_NAMESPACE::ModelProto& model_proto, diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 2c8b81e4ffefe..89969172c1bdc 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -206,6 +206,15 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKe }); }; +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, + _Outptr_result_maybenull_ OrtSyncStream** out) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { + auto* stream = reinterpret_cast(context)->GetComputeStream(); + *out = reinterpret_cast(stream); + return nullptr; + }); +}; + ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { diff --git a/onnxruntime/core/session/experimental_c_api.cc b/onnxruntime/core/session/experimental_c_api.cc index 458c47bcb58cd..5b1a8460d1e1f 100644 --- a/onnxruntime/core/session/experimental_c_api.cc +++ b/onnxruntime/core/session/experimental_c_api.cc @@ -6,12 +6,30 @@ #include #include +#include +#include "core/common/common.h" #include "core/framework/error_code_helper.h" +#include "core/framework/ep_context_options.h" +#include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_experimental_c_api.h" #include "core/session/ort_apis.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/model_compilation_options.h" +#endif // !defined(ORT_MINIMAL_BUILD) + +// Backing definition of the OrtEpContextConfig handle used by the experimental OrtEpApi_* EPContext data functions. +// Holds copies of the application's EPContext read/write callbacks and opaque state extracted from an +// OrtSessionOptions instance. +struct OrtEpContextConfig { + OrtWriteNamedBufferFunc write_func = nullptr; + void* write_state = nullptr; + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; +}; + // --------------------------------------------------------------------------- // Experimental function implementations // --------------------------------------------------------------------------- @@ -40,6 +58,94 @@ ORT_API_STATUS_IMPL(OrtApi_ExperimentalApiTest_SinceV28, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28, _Inout_ OrtSessionOptions* options, + _In_opt_ OrtReadNamedBufferFunc read_func, _In_opt_ void* state) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(options == nullptr, ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL"); + + // Passing a null read_func clears any previously set callback. Clear the state too so a stale state pointer is + // never paired with a missing callback. + options->value.ep_context_data_read_func = read_func; + options->value.ep_context_data_read_state = read_func != nullptr ? state : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_opt_ OrtWriteNamedBufferFunc write_func, _In_opt_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF(ort_model_compile_options == nullptr, ORT_INVALID_ARGUMENT, "OrtModelCompilationOptions is NULL"); + + // A null write_func clears any previously set callback (symmetric with OrtApi_SessionOptions_SetEpContextDataReadFunc + // and consistent with calling this multiple times to overwrite the callback). + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + model_compile_options->SetEpContextDataWriteFunc(write_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtEpApi_SessionOptions_GetEpContextConfig_SinceV28, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtEpContextConfig** config) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(session_options == nullptr, ORT_INVALID_ARGUMENT, "OrtSessionOptions is NULL"); + ORT_API_RETURN_IF(config == nullptr, ORT_INVALID_ARGUMENT, "Output OrtEpContextConfig is NULL"); + + auto ep_context_config = std::make_unique(); + if (const auto* write_config = session_options->value.ep_context_gen_options.TryGetEpContextDataWriteFunc()) { + ep_context_config->write_func = write_config->write_func; + ep_context_config->write_state = write_config->state; + } + ep_context_config->read_func = session_options->value.ep_context_data_read_func; + ep_context_config->read_state = session_options->value.ep_context_data_read_state; + + *config = ep_context_config.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtEpApi_ReleaseEpContextConfig_SinceV28, _Frees_ptr_opt_ OrtEpContextConfig* config) { + delete config; +} + +ORT_API_STATUS_IMPL(OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28, + _In_ const OrtEpContextConfig* config, + _Out_ OrtReadNamedBufferFunc* read_func, + _Out_ void** state) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(config == nullptr, ORT_INVALID_ARGUMENT, "OrtEpContextConfig is NULL"); + ORT_API_RETURN_IF(read_func == nullptr, ORT_INVALID_ARGUMENT, "Output read_func is NULL"); + ORT_API_RETURN_IF(state == nullptr, ORT_INVALID_ARGUMENT, "Output state is NULL"); + + *read_func = config->read_func; + *state = config->read_func != nullptr ? config->read_state : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28, + _In_ const OrtEpContextConfig* config, + _Out_ OrtWriteNamedBufferFunc* write_func, + _Out_ void** state) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(config == nullptr, ORT_INVALID_ARGUMENT, "OrtEpContextConfig is NULL"); + ORT_API_RETURN_IF(write_func == nullptr, ORT_INVALID_ARGUMENT, "Output write_func is NULL"); + ORT_API_RETURN_IF(state == nullptr, ORT_INVALID_ARGUMENT, "Output state is NULL"); + + *write_func = config->write_func; + *state = config->write_func != nullptr ? config->write_state : nullptr; + return nullptr; + API_IMPL_END +} + } // namespace OrtExperimentalApis // --------------------------------------------------------------------------- diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index a393bb42fe2cb..9f6d1f9f1a9bc 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -133,6 +133,15 @@ void ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( }; } +void ModelCompilationOptions::SetEpContextDataWriteFunc(OrtWriteNamedBufferFunc write_func, void* state) { + // A null write_func clears any previously set callback. Clear the state too so a stale state pointer is never + // paired with a missing callback. + session_options_.value.ep_context_gen_options.ep_context_data_write_func = epctx::EpContextDataWriteFuncHolder{ + write_func, + write_func != nullptr ? state : nullptr, + }; +} + Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::filesystem::path& output_directory, const std::filesystem::path& model_name) { if (output_directory.empty() || model_name.empty()) { @@ -283,9 +292,12 @@ Status ModelCompilationOptions::Check() const { "OrtModel has no graph. Call AddGraphToModel before compilation."); } - if (input_model_->graph->GetNumInputs() == 0 || input_model_->graph->GetNumOutputs() == 0) { + // A model with zero graph inputs is legal (e.g., a graph composed of zero-input + // generator ops like RandomNormal that produces output without external input). + // We still require at least one graph output for the compiled model to be meaningful. + if (input_model_->graph->GetNumOutputs() == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "OrtModel graph must have at least one input and one output defined."); + "OrtModel graph must have at least one output defined."); } if (input_model_->domain_to_version.empty()) { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 47529e794677e..a15af565c4d54 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -13,6 +13,7 @@ #include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { @@ -97,6 +98,13 @@ class ModelCompilationOptions { void SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, void* state); + /// + /// Sets a user-provided function to handle EPContext binary data writes. + /// + /// The user-provided OrtWriteNamedBufferFunc callback used to write EPContext data. + /// The user's state. + void SetEpContextDataWriteFunc(OrtWriteNamedBufferFunc write_func, void* state); + /// /// Sets information relate to EP context binary file. /// EP use this information to decide the location and context binary file name. diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a663d209cfa53..22df898ca3227 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4916,6 +4916,8 @@ static constexpr OrtApi ort_api_1_to_28 = { // End of Version 27 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::GetExperimentalFunction, + + &OrtApis::KernelContext_GetSyncStream, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 61ece2dd9a682..e747d0d0ab2d8 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -196,6 +196,7 @@ ORT_API_STATUS_IMPL(KernelContext_GetInputCount, _In_ const OrtKernelContext* co ORT_API_STATUS_IMPL(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); ORT_API_STATUS_IMPL(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out); +ORT_API_STATUS_IMPL(KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, _Outptr_result_maybenull_ OrtSyncStream** out); // OrtTypeInfo methods ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len); diff --git a/onnxruntime/test/autoep/ep_context_data_callbacks.h b/onnxruntime/test/autoep/ep_context_data_callbacks.h new file mode 100644 index 0000000000000..d25c7a5c571ee --- /dev/null +++ b/onnxruntime/test/autoep/ep_context_data_callbacks.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { +namespace test { + +// Shared EPContext read/write callback test doubles, used by both the EpContextDataUtils unit tests +// (ep_context_data_utils_test.cc) and the PluginEp end-to-end EPContext tests (test_execution.cc). +struct EpContextDataCallbackState { + bool write_called = false; + bool read_called = false; + std::string write_file_name; + std::string read_file_name; + std::vector payload; +}; + +inline OrtStatus* ORT_API_CALL StoreEpContextDataCallback(void* state, const char* file_name, const void* buffer, + size_t buffer_size) { + auto* callback_state = static_cast(state); + callback_state->write_called = true; + callback_state->write_file_name = file_name; + callback_state->payload.clear(); + if (buffer_size != 0) { + if (buffer == nullptr) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, + "StoreEpContextDataCallback received a null buffer for non-empty data"); + } + callback_state->payload.assign(static_cast(buffer), static_cast(buffer) + buffer_size); + } + return nullptr; +} + +inline OrtStatus* ORT_API_CALL LoadEpContextDataCallback(void* state, const char* file_name, OrtAllocator* allocator, + void** buffer, size_t* data_size) { + auto* callback_state = static_cast(state); + callback_state->read_called = true; + callback_state->read_file_name = file_name; + + *buffer = nullptr; + *data_size = callback_state->payload.size(); + if (callback_state->payload.empty()) { + return nullptr; + } + + OrtStatus* status = Ort::GetApi().AllocatorAlloc(allocator, callback_state->payload.size(), buffer); + if (status != nullptr) { + return status; + } + + std::copy(callback_state->payload.begin(), callback_state->payload.end(), static_cast(*buffer)); + return nullptr; +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/autoep/ep_context_data_utils_test.cc b/onnxruntime/test/autoep/ep_context_data_utils_test.cc new file mode 100644 index 0000000000000..8425280e4845b --- /dev/null +++ b/onnxruntime/test/autoep/ep_context_data_utils_test.cc @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Unit tests for the sample-only EPContext data helpers in +// onnxruntime/test/autoep/library/ep_context_data_utils.h. + +#include +#include +#include +#include + +#include +#include + +#include "core/graph/model_editor_api_types.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_experimental_cxx_api.h" + +#include "test/autoep/ep_context_data_callbacks.h" +#include "test/autoep/library/ep_context_data_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { + +namespace { + +OrtStatus* ORT_API_CALL LoadInvalidEpContextDataCallback(void* state, const char* file_name, + OrtAllocator* /*allocator*/, void** buffer, + size_t* data_size) { + auto* callback_state = static_cast(state); + callback_state->read_called = true; + callback_state->read_file_name = file_name; + + *buffer = nullptr; + *data_size = 1; + return nullptr; +} + +void ExpectOrtStatusError(OrtStatus* status_ptr, OrtErrorCode expected_code, std::string_view expected_message) { + Ort::Status status{status_ptr}; + ASSERT_NE(status_ptr, nullptr) << "Expected a failure status, but the API returned nullptr (OK)."; + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), expected_code); + EXPECT_THAT(std::string{status.GetErrorMessage()}, ::testing::HasSubstr(std::string{expected_message})); +} + +std::filesystem::path PrepareTempTestDir(std::string_view name) { + std::filesystem::path test_dir = std::string{name}; + std::filesystem::remove_all(test_dir); + std::filesystem::create_directories(test_dir); + return test_dir; +} + +} // namespace + +TEST(OrtEpLibrary, EpContextDataUtils_PathHelpersRoundTrip) { + const auto& api = Ort::GetApi(); + const std::string file_name = "context_data.bin"; + +#ifdef _WIN32 + std::wstring wide_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8ToWideString(api, file_name, wide_file_name)); + ASSERT_FALSE(wide_file_name.empty()); + std::string round_tripped_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WideToUtf8String(api, wide_file_name, round_tripped_file_name)); + EXPECT_EQ(round_tripped_file_name, file_name); + + const std::string invalid_utf8(1, static_cast(0xff)); + std::wstring invalid_wide; + ExpectOrtStatusError(ep_context_data_utils::Utf8ToWideString(api, invalid_utf8, invalid_wide), + ORT_INVALID_ARGUMENT, "not valid UTF-8"); + EXPECT_TRUE(invalid_wide.empty()); +#endif + + std::filesystem::path file_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8Path(api, file_name.c_str(), file_path)); + ASSERT_FALSE(file_path.empty()); + std::string round_tripped_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, file_path, round_tripped_path)); + EXPECT_EQ(round_tripped_path, file_name); + + std::filesystem::path empty_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8Path(api, nullptr, empty_path)); + EXPECT_TRUE(empty_path.empty()); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8Path(api, "", empty_path)); + EXPECT_TRUE(empty_path.empty()); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ResolvePathAndInvalidArguments) { + const auto& api = Ort::GetApi(); + std::filesystem::path data_path; + + data_path = "stale.ctx"; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, nullptr, nullptr, data_path), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + EXPECT_TRUE(data_path.empty()); + + data_path = "stale.ctx"; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "", nullptr, data_path), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + EXPECT_TRUE(data_path.empty()); + + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ResolveEpContextDataPath(api, "relative.ctx", nullptr, data_path)); + std::string resolved_data_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, data_path, resolved_data_path)); + EXPECT_EQ(resolved_data_path, "relative.ctx"); + + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataToFile(api, "unused.ctx", nullptr, nullptr, 1), + ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback(api, nullptr, "unused.ctx", nullptr, + nullptr, 1), + ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + + std::vector data; + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataWithFileFallback(api, nullptr, "", nullptr, data), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback(api, nullptr, "", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "logical_context_data.bin", "", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data fallback file name must not be empty"); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ResolvePathRejectsUnsafeNames) { + const auto& api = Ort::GetApi(); + std::filesystem::path data_path; + + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "../escape.ctx", nullptr, data_path), + ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + EXPECT_TRUE(data_path.empty()); + +#ifdef _WIN32 + const char* absolute_file_name = "C:\\temp\\escape.ctx"; + const char* drive_relative_file_name = "C:escape.ctx"; + const char* root_relative_file_name = "\\escape.ctx"; +#else + const char* absolute_file_name = "/tmp/escape.ctx"; +#endif + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ResolveEpContextDataPath(api, absolute_file_name, nullptr, data_path)); + EXPECT_TRUE(data_path.is_absolute()); + +#ifdef _WIN32 + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, drive_relative_file_name, "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, root_relative_file_name, "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); +#endif + + std::vector data; + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataFromFile(api, "../escape.ctx", nullptr, data), + ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, absolute_file_name, "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + + ModelEditorGraph empty_model_path_graph; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "../escape.ctx", + empty_model_path_graph.ToExternal(), data_path), + ORT_INVALID_ARGUMENT, "requires a model path"); + + // A model-derived name that designates a directory ("." or a trailing separator with an empty filename) is + // rejected up front, rather than resolving to a directory and failing later with a confusing I/O error. + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, ".", empty_model_path_graph.ToExternal(), + data_path), + ORT_INVALID_ARGUMENT, "must refer to a file"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, ".", "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "must refer to a file"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "sub/", "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "must refer to a file"); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ResolvePathRejectsSymlinkEscape) { + const auto& api = Ort::GetApi(); + const std::filesystem::path test_dir = PrepareTempTestDir("ort_ep_context_data_utils_symlink_escape_test"); + auto cleanup = gsl::finally([&]() { std::filesystem::remove_all(test_dir); }); + + const std::filesystem::path model_dir = test_dir / "model_dir"; + const std::filesystem::path outside_dir = test_dir / "outside_dir"; + ASSERT_TRUE(std::filesystem::create_directories(model_dir)); + ASSERT_TRUE(std::filesystem::create_directories(outside_dir)); + + const std::filesystem::path symlink_path = model_dir / "linked_outside"; + // Relative symlink targets are resolved by the OS relative to the link's own directory, not the test's working + // directory. Point to the sibling outside_dir using a link-relative target; using the test_dir-relative + // `outside_dir` path here would create a dangling link under model_dir, and weakly_canonical() would not traverse it. + const std::filesystem::path symlink_target = std::filesystem::path{".."} / outside_dir.filename(); + std::error_code symlink_error; + std::filesystem::create_directory_symlink(symlink_target, symlink_path, symlink_error); + if (symlink_error) { + GTEST_SKIP() << "Unable to create directory symlink for containment test: " << symlink_error.message(); + } + + ModelEditorGraph graph; + graph.model_path = model_dir / "model.onnx"; + + std::filesystem::path data_path; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "linked_outside/escape.ctx", + graph.ToExternal(), data_path), + ORT_INVALID_ARGUMENT, "resolve to a path within the model directory"); + EXPECT_TRUE(data_path.empty()); +} + +TEST(OrtEpLibrary, EpContextDataUtils_FileFallbackReadsAndWrites) { + const auto& api = Ort::GetApi(); + const std::filesystem::path test_dir = PrepareTempTestDir("ort_ep_context_data_utils_file_fallback_test"); + auto cleanup = gsl::finally([&]() { std::filesystem::remove_all(test_dir); }); + + const std::string payload = "file fallback payload"; + const std::filesystem::path data_path = test_dir / "context_data.bin"; + std::string data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, data_path, data_file_name)); + + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataToFile(api, data_file_name.c_str(), nullptr, + payload.data(), payload.size())); + + std::vector data; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataFromFile(api, data_file_name.c_str(), nullptr, data)); + EXPECT_EQ(std::string(data.begin(), data.end()), payload); + + const std::filesystem::path wrapper_data_path = test_dir / "wrapper_context_data.bin"; + std::string wrapper_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, wrapper_data_path, wrapper_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, wrapper_data_file_name.c_str(), nullptr, payload.data(), payload.size())); + + data.clear(); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, nullptr, wrapper_data_file_name.c_str(), nullptr, data)); + EXPECT_EQ(std::string(data.begin(), data.end()), payload); + + const std::filesystem::path fallback_data_path = test_dir / "fallback_context_data.bin"; + std::string fallback_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, fallback_data_path, fallback_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "logical_context_data.bin", fallback_data_file_name.c_str(), nullptr, payload.data(), + payload.size())); + + const std::filesystem::path unsafe_logical_fallback_path = test_dir / "unsafe_logical_context_data.bin"; + std::string unsafe_logical_fallback_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, unsafe_logical_fallback_path, + unsafe_logical_fallback_file_name)); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "../logical_context_data.bin", + unsafe_logical_fallback_file_name.c_str(), nullptr, payload.data(), payload.size()), + ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + + data.clear(); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataFromFile(api, fallback_data_file_name.c_str(), nullptr, + data)); + EXPECT_EQ(std::string(data.begin(), data.end()), payload); + + const std::filesystem::path empty_data_path = test_dir / "empty_context_data.bin"; + std::string empty_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, empty_data_path, empty_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, empty_data_file_name.c_str(), nullptr, nullptr, 0)); + + data.assign({'s', 't', 'a', 'l', 'e'}); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, nullptr, empty_data_file_name.c_str(), nullptr, data)); + EXPECT_TRUE(data.empty()); + + const std::filesystem::path missing_data_path = test_dir / "missing_context_data.bin"; + std::string missing_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, missing_data_path, missing_data_file_name)); + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataFromFile(api, missing_data_file_name.c_str(), nullptr, + data), + ORT_FAIL, "Failed to open EPContext data file for read"); +} + +TEST(OrtEpLibrary, EpContextDataUtils_CallbackFallbackUsesCallbacks) { + const auto& api = Ort::GetApi(); + + EpContextDataCallbackState read_callback_state; + read_callback_state.payload = {'c', 'a', 'l', 'l', 'b', 'a', 'c', 'k'}; + EpContextDataCallbackState write_callback_state; + + std::vector data; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, LoadEpContextDataCallback, &read_callback_state, "callback_context.bin", nullptr, data)); + ASSERT_TRUE(read_callback_state.read_called); + EXPECT_EQ(read_callback_state.read_file_name, "callback_context.bin"); + EXPECT_EQ(data, read_callback_state.payload); + + const std::string payload = "callback write payload"; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, StoreEpContextDataCallback, &write_callback_state, "callback_write_context.bin", + "callback_write_context.bin", nullptr, payload.data(), payload.size())); + ASSERT_TRUE(write_callback_state.write_called); + EXPECT_EQ(write_callback_state.write_file_name, "callback_write_context.bin"); + EXPECT_EQ(std::string(write_callback_state.payload.begin(), write_callback_state.payload.end()), payload); + + write_callback_state = {}; + const std::string payload_with_unused_fallback = "callback write payload with unused fallback"; + // With a callback present the file fallback is never used, so the empty fallback name is accepted (not validated). + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, StoreEpContextDataCallback, &write_callback_state, + "callback_write_context_unused_fallback.bin", "", nullptr, + payload_with_unused_fallback.data(), payload_with_unused_fallback.size())); + ASSERT_TRUE(write_callback_state.write_called); + EXPECT_EQ(write_callback_state.write_file_name, "callback_write_context_unused_fallback.bin"); + EXPECT_EQ(std::string(write_callback_state.payload.begin(), write_callback_state.payload.end()), + payload_with_unused_fallback); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ReadCallbackRejectsNullBufferForNonEmptyPayload) { + const auto& api = Ort::GetApi(); + + EpContextDataCallbackState read_callback_state; + + std::vector data; + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, LoadInvalidEpContextDataCallback, &read_callback_state, + "invalid_callback_context.bin", nullptr, data), + ORT_FAIL, "OrtReadNamedBufferFunc returned a null buffer for non-empty EPContext data"); + ASSERT_TRUE(read_callback_state.read_called); + EXPECT_EQ(read_callback_state.read_file_name, "invalid_callback_context.bin"); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/autoep/library/ep_context_data_utils.h b/onnxruntime/test/autoep/library/ep_context_data_utils.h new file mode 100644 index 0000000000000..a3f9a377ee92f --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_context_data_utils.h @@ -0,0 +1,501 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +// Define NOMINMAX (and WIN32_LEAN_AND_MEAN) before so the min/max macros it would otherwise pull in do +// not clobber std::numeric_limits<...>::max() and std::min/std::max used in this header. +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +#include "plugin_ep_utils.h" +#include "onnxruntime_experimental_cxx_api.h" + +// Sample-only EPContext data helpers shared by the example plugin EP and its tests. These are intentionally outside +// the ORT C and EP ABI and are provided as a reference for EP authors that need to handle external (non-embedded) +// EPContext binary data. +// +// The intended entry points for EP implementers are the ReadEpContextDataWithFileFallback / +// WriteEpContextDataWithFileFallback overloads: they prefer an application-supplied OrtReadNamedBufferFunc / +// OrtWriteNamedBufferFunc (carried by OrtEpContextConfig) and fall back to file I/O when no callback is configured. +// The other functions are lower-level building blocks. Production EPs should additionally apply their own sandboxing, +// size limits, and path policies; see the per-function notes on how untrusted, model-derived names are treated. +namespace ep_context_data_utils { + +#ifdef _WIN32 +inline std::string WindowsLastErrorMessage(std::string_view message, DWORD error_code) { + return std::string{message} + " GetLastError=" + std::to_string(error_code); +} + +// Converts a UTF-8 string to a wide string. Reports conversion failures (e.g., invalid UTF-8) via OrtStatus* instead +// of silently returning an empty string. An empty input yields an empty output and a success status. +inline OrtStatus* Utf8ToWideString(const OrtApi& api, std::string_view value, std::wstring& wide_value) { + wide_value.clear(); + if (value.empty()) { + return nullptr; + } + if (value.size() > static_cast(std::numeric_limits::max())) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is too long to convert"); + } + + const int wide_length = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, value.data(), + static_cast(value.size()), nullptr, 0); + if (wide_length <= 0) { + const std::string message = WindowsLastErrorMessage( + "EPContext data file name is not valid UTF-8 or could not be converted to a wide string.", GetLastError()); + return api.CreateStatus(ORT_INVALID_ARGUMENT, message.c_str()); + } + + wide_value.resize(static_cast(wide_length)); + const int converted = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, value.data(), + static_cast(value.size()), wide_value.data(), wide_length); + if (converted != wide_length) { + wide_value.clear(); + const std::string message = WindowsLastErrorMessage("Failed to convert EPContext data file name to a wide string.", + GetLastError()); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + return nullptr; +} + +// Converts a wide string to UTF-8. Reports conversion failures via OrtStatus* instead of silently returning an empty +// string. An empty input yields an empty output and a success status. +inline OrtStatus* WideToUtf8String(const OrtApi& api, std::wstring_view value, std::string& utf8_value) { + utf8_value.clear(); + if (value.empty()) { + return nullptr; + } + if (value.size() > static_cast(std::numeric_limits::max())) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is too long to convert"); + } + + const int utf8_length = WideCharToMultiByte(CP_UTF8, 0, value.data(), static_cast(value.size()), + nullptr, 0, nullptr, nullptr); + if (utf8_length <= 0) { + const std::string message = WindowsLastErrorMessage( + "EPContext data file name could not be converted to UTF-8.", GetLastError()); + return api.CreateStatus(ORT_INVALID_ARGUMENT, message.c_str()); + } + + utf8_value.resize(static_cast(utf8_length)); + const int converted = WideCharToMultiByte(CP_UTF8, 0, value.data(), static_cast(value.size()), + utf8_value.data(), utf8_length, nullptr, nullptr); + if (converted != utf8_length) { + utf8_value.clear(); + const std::string message = WindowsLastErrorMessage("Failed to convert EPContext data file name to UTF-8.", + GetLastError()); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + return nullptr; +} +#endif + +// Converts a UTF-8 path to a std::filesystem::path. A null or empty input yields an empty path and a success status; +// conversion failures are reported via OrtStatus*. +inline OrtStatus* Utf8Path(const OrtApi& api, const char* path, std::filesystem::path& out_path) { + out_path.clear(); + if (path == nullptr || path[0] == '\0') { + return nullptr; + } + +#ifdef _WIN32 + std::wstring wide_path; + RETURN_IF_ERROR(Utf8ToWideString(api, path, wide_path)); + out_path = std::filesystem::path{wide_path}; +#else + (void)api; + out_path = std::filesystem::path{path}; +#endif + return nullptr; +} + +inline OrtStatus* PathToUtf8String(const OrtApi& api, const std::filesystem::path& path, std::string& utf8_path) { + utf8_path.clear(); +#ifdef _WIN32 + RETURN_IF_ERROR(WideToUtf8String(api, path.wstring(), utf8_path)); +#else + (void)api; + utf8_path = path.string(); +#endif + return nullptr; +} + +inline std::string PathToUtf8StringForMessage(const std::filesystem::path& path) { + std::string utf8_path; + Ort::Status status{PathToUtf8String(Ort::GetApi(), path, utf8_path)}; + return status.IsOK() ? utf8_path : std::string{""}; +} + +// Lexical check for a ".." component. This is a coarse guard used when there is no filesystem base directory to +// contain against: logical callback-namespace names and trusted (graph == nullptr) physical paths. It is NOT a +// containment mechanism: it does not resolve symlinks and it rejects benign cases such as "a/b/c/../file.txt". +// Filesystem containment against a model directory is done by IsResolvedPathWithinBase() below, which the untrusted +// (model-relative) resolution path uses. +inline bool ContainsPathTraversal(const std::filesystem::path& path) { + const std::filesystem::path parent_dir{".."}; + for (const auto& component : path) { + if (component == parent_dir) { + return true; + } + } + return false; +} + +inline bool HasAbsoluteOrRootedPath(const std::filesystem::path& path) { + return path.is_absolute() || path.has_root_name() || path.has_root_directory(); +} + +// Returns true if the final component of `path` is empty (e.g., a trailing separator like "sub/") or is the +// current-directory entry ".", i.e. the name designates a directory rather than a file (".." is handled separately by +// ContainsPathTraversal()). Such a name resolves to a directory and would only surface later as a confusing file I/O +// failure, so model-derived names like these are rejected up front. +inline bool IsDirectoryOrEmptyName(const std::filesystem::path& path) { + const std::filesystem::path leaf = path.filename(); + return leaf.empty() || leaf == std::filesystem::path{"."}; +} + +// Returns true if `candidate_full` (a base-relative name already combined with `base`) resolves to a location inside +// `base`. Both are normalized with std::filesystem::weakly_canonical, which resolves "." / ".." and any symlinks in +// the existing portion of the path, so a name that escapes `base` directly or through a symlink is rejected. On +// success the canonical resolved path is written to `resolved`. +inline bool IsResolvedPathWithinBase(const std::filesystem::path& base, const std::filesystem::path& candidate_full, + std::filesystem::path& resolved) { + std::error_code ec; + const std::filesystem::path base_for_canon = base.empty() ? std::filesystem::path{"."} : base; + const std::filesystem::path canonical_base = std::filesystem::weakly_canonical(base_for_canon, ec); + if (ec) { + return false; + } + std::filesystem::path candidate_resolved = std::filesystem::weakly_canonical(candidate_full, ec); + if (ec) { + return false; + } + const std::filesystem::path relative = candidate_resolved.lexically_relative(canonical_base); + if (relative.empty() || *relative.begin() == std::filesystem::path{".."}) { + return false; + } + + resolved = std::move(candidate_resolved); + return true; +} + +inline OrtStatus* ValidateEpContextDataName(const OrtApi& api, const char* file_name, + std::filesystem::path& data_name) { + data_name.clear(); + + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + std::filesystem::path candidate_path; + RETURN_IF_ERROR(Utf8Path(api, file_name, candidate_path)); + if (candidate_path.empty()) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is not a valid path"); + } + + if (HasAbsoluteOrRootedPath(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + } + + if (ContainsPathTraversal(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + } + + if (IsDirectoryOrEmptyName(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must refer to a file, not a directory"); + } + + data_name = candidate_path; + return nullptr; +} + +// Resolves `file_name` to a filesystem path for reading or writing EPContext data (used by both the read path and +// the write-fallback path). +// +// When `graph` is null the caller is trusted and owns the path: `file_name` is returned as-is and may be absolute (a +// lexical ".." is still rejected as a coarse guard). When `graph` is non-null, `file_name` originates from the +// untrusted EPContext model "ep_cache_context" attribute: the graph must have a model path, the name must be +// relative, and after combining it with the model's directory the result must stay within that directory. Symlinks and +// ".." are resolved (via weakly_canonical), so a name that escapes the model directory - including through a symlink - +// is rejected. +// This helper only decides whether a model-derived file name resolves inside the model directory. Production EPs +// should still choose an application-approved storage root (sandbox), reject special files/locations as appropriate, +// and cap the number of bytes they will read or write for a single EPContext payload. +inline OrtStatus* ResolveEpContextDataPath(const OrtApi& api, const char* file_name, const OrtGraph* graph, + std::filesystem::path& data_path) { + data_path.clear(); + + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + std::filesystem::path candidate_path; + RETURN_IF_ERROR(Utf8Path(api, file_name, candidate_path)); + if (candidate_path.empty()) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is not a valid path"); + } + + // Trusted direct callers (graph == nullptr) own the path and may pass an absolute physical path. + if (graph == nullptr) { + if (ContainsPathTraversal(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + } + data_path = candidate_path; + return nullptr; + } + + // Untrusted (model-derived) name: must be relative and must resolve within the model directory. + if (HasAbsoluteOrRootedPath(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + } + + if (IsDirectoryOrEmptyName(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must refer to a file, not a directory"); + } + + const ORTCHAR_T* model_path = nullptr; + RETURN_IF_ERROR(api.Graph_GetModelPath(graph, &model_path)); + if (model_path == nullptr || model_path[0] == 0) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, + "EPContext data file fallback requires a model path to resolve relative names"); + } + + const std::filesystem::path base_dir = std::filesystem::path{model_path}.parent_path(); + std::filesystem::path resolved; + if (!IsResolvedPathWithinBase(base_dir, base_dir / candidate_path, resolved)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, + "EPContext data file name must resolve to a path within the model directory"); + } + + data_path = resolved; + return nullptr; +} + +inline OrtStatus* WriteEpContextDataToResolvedFile(const OrtApi& api, const std::filesystem::path& data_path, + const void* buffer, size_t buffer_size) { + std::ofstream output_stream(data_path, std::ios::binary); + if (!output_stream) { + const std::string message = "Failed to open EPContext data file for write: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + + if (buffer_size != 0) { + if (buffer_size > static_cast(std::numeric_limits::max())) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data buffer is too large to write"); + } + + output_stream.write(static_cast(buffer), static_cast(buffer_size)); + if (!output_stream) { + const std::string message = "Failed to write EPContext data file: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + } + + return nullptr; +} + +inline OrtStatus* ReadEpContextDataFromFile(const OrtApi& api, const char* file_name, const OrtGraph* graph, + std::vector& data) { + data.clear(); + + std::filesystem::path data_path; + RETURN_IF_ERROR(ResolveEpContextDataPath(api, file_name, graph, data_path)); + + std::ifstream input_stream(data_path, std::ios::binary); + if (!input_stream) { + const std::string message = "Failed to open EPContext data file for read: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + + data.assign(std::istreambuf_iterator{input_stream}, std::istreambuf_iterator{}); + if (!input_stream) { + const std::string message = "Failed to read EPContext data file: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + + return nullptr; +} + +inline OrtStatus* WriteEpContextDataToFile(const OrtApi& api, const char* file_name, const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + if (buffer == nullptr && buffer_size != 0) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + } + + std::filesystem::path data_path; + RETURN_IF_ERROR(ResolveEpContextDataPath(api, file_name, graph, data_path)); + return WriteEpContextDataToResolvedFile(api, data_path, buffer, buffer_size); +} + +// Low-level overload that takes the read callback and its opaque state directly. Production EPs should use the +// overload below that takes an OrtEpContextConfig; this overload exists so unit tests can inject a callback without +// constructing an OrtEpContextConfig. When `read_func` is null the data is read from a file. +inline OrtStatus* ReadEpContextDataWithFileFallback( + const OrtApi& api, + OrtReadNamedBufferFunc read_func, void* read_state, + const char* file_name, const OrtGraph* graph, + std::vector& data) { + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + if (read_func == nullptr) { + return ReadEpContextDataFromFile(api, file_name, graph, data); + } + + // Use the C allocator API (not Ort::AllocatorWithDefaultOptions, whose constructor throws) so this OrtStatus*-based + // helper stays exception-free. The default allocator is owned by ORT and must not be released here. + OrtAllocator* allocator = nullptr; + RETURN_IF_ERROR(api.GetAllocatorWithDefaultOptions(&allocator)); + void* ep_context_data = nullptr; + size_t ep_context_data_size = 0; + OrtStatus* status = read_func(read_state, file_name, allocator, &ep_context_data, &ep_context_data_size); + auto buffer_deleter = [&api, allocator](void* buffer_to_free) { + if (buffer_to_free != nullptr) { + // Best-effort free during cleanup; release any returned status without throwing. + Ort::Status free_status{api.AllocatorFree(allocator, buffer_to_free)}; + static_cast(free_status); + } + }; + std::unique_ptr ep_context_data_guard(ep_context_data, buffer_deleter); + + if (status != nullptr) { + return status; + } + + if (ep_context_data_size != 0 && ep_context_data == nullptr) { + return api.CreateStatus( + ORT_FAIL, "OrtReadNamedBufferFunc returned a null buffer for non-empty EPContext data"); + } + + data.clear(); + if (ep_context_data != nullptr) { + const char* ep_context_data_begin = static_cast(ep_context_data); + data.assign(ep_context_data_begin, ep_context_data_begin + ep_context_data_size); + } + + return nullptr; +} + +// Reads EPContext binary data named `file_name`. If the session configured an OrtReadNamedBufferFunc (carried by +// `ep_context_config`), it is used; otherwise the data is read from a file. When `graph` is non-null it is the +// EPContext model graph: untrusted absolute/rooted/traversal names are rejected and relative names are resolved +// against the model directory. Pass `graph == nullptr` only for trusted callers supplying a physical path. `data` is +// cleared first and receives the bytes on success. +inline OrtStatus* ReadEpContextDataWithFileFallback( + const OrtApi& api, + const OrtEpContextConfig* ep_context_config, + const char* file_name, const OrtGraph* graph, + std::vector& data) { + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; + if (ep_context_config != nullptr) { + auto get_read_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28_Fn(&api); + if (get_read_func == nullptr) { + return api.CreateStatus(ORT_NOT_IMPLEMENTED, + "OrtEpApi_EpContextConfig_GetEpContextDataReadFunc is not available"); + } + RETURN_IF_ERROR(get_read_func(ep_context_config, &read_func, &read_state)); + } + return ReadEpContextDataWithFileFallback(api, read_func, read_state, file_name, graph, data); +} + +// Low-level overload that takes the write callback and its opaque state directly. Production EPs should use the +// overloads below that take an OrtEpContextConfig; this overload exists so unit tests can inject a callback without +// constructing an OrtEpContextConfig. When `write_func` is null the data is written to the file fallback. +inline OrtStatus* WriteEpContextDataWithFileFallback( + const OrtApi& api, + OrtWriteNamedBufferFunc write_func, void* write_state, + const char* file_name, const char* fallback_file_name, + const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + if (buffer == nullptr && buffer_size != 0) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + } + + // The app-supplied write callback owns its own logical namespace, so file_name is passed through unmodified. + // Only the file-fallback path below maps a name onto the filesystem, so it validates the logical name there. + if (write_func != nullptr) { + return write_func(write_state, file_name, buffer, buffer_size); + } + + // Even when the physical fallback path is supplied separately, `file_name` is the logical name written into the + // EPContext model's ep_cache_context attribute. Validate it as a safe relative name so a generated model cannot + // contain an unsafe logical reference that later reaches the read-side resolver. + std::filesystem::path logical_path; + RETURN_IF_ERROR(ValidateEpContextDataName(api, file_name, logical_path)); + + if (fallback_file_name == nullptr || fallback_file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data fallback file name must not be empty"); + } + + std::filesystem::path data_path; + RETURN_IF_ERROR(ResolveEpContextDataPath(api, fallback_file_name, graph, data_path)); + return WriteEpContextDataToResolvedFile(api, data_path, buffer, buffer_size); +} + +// Writes EPContext binary data. If the compilation configured an OrtWriteNamedBufferFunc (carried by +// `ep_context_config`), it is used and `file_name` is passed through unmodified as the logical name. Otherwise the +// data is written to a file at `fallback_file_name`, which is resolved against the model directory when `graph` is +// non-null (and rejected if absolute or rooted in that case). `graph == nullptr` denotes a trusted caller that may +// supply an absolute physical path. `buffer` may be null only when `buffer_size` is 0. +inline OrtStatus* WriteEpContextDataWithFileFallback( + const OrtApi& api, + const OrtEpContextConfig* ep_context_config, + const char* file_name, const char* fallback_file_name, + const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + OrtWriteNamedBufferFunc write_func = nullptr; + void* write_state = nullptr; + if (ep_context_config != nullptr) { + auto get_write_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28_Fn(&api); + if (get_write_func == nullptr) { + return api.CreateStatus(ORT_NOT_IMPLEMENTED, + "OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc is not available"); + } + RETURN_IF_ERROR(get_write_func(ep_context_config, &write_func, &write_state)); + } + return WriteEpContextDataWithFileFallback(api, write_func, write_state, file_name, fallback_file_name, graph, buffer, + buffer_size); +} + +// Convenience overload that uses `file_name` as both the logical callback name and the file-fallback path. +// Because `file_name` doubles as the fallback path, it must be a safe relative name (this overload validates it and +// rejects absolute/rooted paths and `..` traversal). To write the file fallback to an absolute physical path (a +// trusted caller with `graph == nullptr`), use the overload above that takes a separate `fallback_file_name`. +inline OrtStatus* WriteEpContextDataWithFileFallback( + const OrtApi& api, + const OrtEpContextConfig* ep_context_config, + const char* file_name, const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + return WriteEpContextDataWithFileFallback(api, ep_context_config, file_name, file_name, graph, buffer, buffer_size); +} + +} // namespace ep_context_data_utils diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index fecf7ac9a4038..90ad4e7976824 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -4,15 +4,17 @@ #include "ep.h" #include +#include #include #include -#include +#include #include #include #include +#include #include -#include +#include "../ep_context_data_utils.h" #include "ep_factory.h" #include "ep_stream_support.h" @@ -167,13 +169,15 @@ struct EpContextNodeComputeInfo : NodeComputeInfoBase { ExampleEp& ep; }; -ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger) +ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger, + Ort::Experimental::EpContextConfig ep_context_config) : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized ApiPtrs{static_cast(factory)}, factory_{factory}, name_{name}, config_{config}, - logger_{logger} { + logger_{logger}, + ep_context_config_{std::move(ep_context_config)} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. // Initialize the execution provider's function table @@ -193,8 +197,6 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C ORT_FILE, __LINE__, __FUNCTION__)); } -ExampleEp::~ExampleEp() = default; - /*static*/ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept { const auto* ep = static_cast(this_ptr); @@ -409,6 +411,26 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const auto fused_node_name = fused_node.GetName(); if (is_ep_context_node) { + Ort::ConstOpAttr embed_mode_attr; + RETURN_IF_ERROR(nodes[0].GetAttributeByName("embed_mode", embed_mode_attr)); + int64_t embed_mode = 1; + RETURN_IF_ERROR(embed_mode_attr.GetValue(embed_mode)); + + if (embed_mode == 0) { + Ort::ConstOpAttr ep_cache_context_attr; + RETURN_IF_ERROR(nodes[0].GetAttributeByName("ep_cache_context", ep_cache_context_attr)); + std::string ep_cache_context; + RETURN_IF_ERROR(ep_cache_context_attr.GetValue(ep_cache_context)); + + // This example only exercises the load-side read flow (callback first, file fallback otherwise) to show how + // an EP retrieves EPContext binary data during compile. A real EP would consume `ep_context_data` (e.g., + // initialize a kernel/engine from it); here it is intentionally read and then discarded. + std::vector ep_context_data; + RETURN_IF_ERROR(ep_context_data_utils::ReadEpContextDataWithFileFallback( + ep->ort_api, ep->ep_context_config_.get(), ep_cache_context.c_str(), ort_graphs[0], + ep_context_data)); + } + // Create EpContextKernel for EPContext nodes - clearly separates from MulKernel ep->ep_context_kernels_.emplace(fused_node_name, std::make_unique(ep->ort_api, ep->logger_)); @@ -449,7 +471,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const // Create EpContext nodes for the fused nodes we compiled (only for Mul, not EPContext). if (ep->config_.enable_ep_context) { assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + RETURN_IF_ERROR(ep->CreateEpContextNodes(ort_graphs[0], gsl::span(fused_nodes, count), gsl::span(ep_context_nodes, count))); } } @@ -479,7 +501,8 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, // Creates EPContext nodes from the given fused nodes. // This is an example implementation that can be used to generate an EPContext model. However, this example EP // cannot currently run the EPContext model. -OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes, +OrtStatus* ExampleEp::CreateEpContextNodes(const OrtGraph* graph, + gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes) { try { assert(fused_nodes.size() == ep_context_nodes.size()); @@ -512,11 +535,32 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes collect_input_output_names(fused_node_outputs, /*out*/ output_names); int64_t is_main_context = (i == 0); - int64_t embed_mode = 1; + int64_t embed_mode = config_.embed_ep_context_in_model ? 1 : 0; // Create node attributes. The CreateNode() function copies the attributes. std::array attributes = {}; - std::string ep_ctx = "binary_data"; + std::string ep_ctx = config_.embed_ep_context_in_model ? "binary_data" : fused_node_name + ".ctx"; + if (!config_.embed_ep_context_in_model) { + const std::string ep_context_data = "binary_data"; + std::string fallback_ep_ctx = ep_ctx; + const OrtGraph* fallback_graph = graph; + if (!config_.ep_context_output_model_path.empty()) { + std::filesystem::path output_model_path; + RETURN_IF_ERROR(ep_context_data_utils::Utf8Path(ort_api, config_.ep_context_output_model_path.c_str(), + output_model_path)); + const std::filesystem::path output_model_dir = output_model_path.parent_path(); + if (!output_model_dir.empty()) { + std::filesystem::path ep_ctx_path; + RETURN_IF_ERROR(ep_context_data_utils::Utf8Path(ort_api, ep_ctx.c_str(), ep_ctx_path)); + RETURN_IF_ERROR(ep_context_data_utils::PathToUtf8String(ort_api, output_model_dir / ep_ctx_path, + fallback_ep_ctx)); + } + fallback_graph = nullptr; + } + RETURN_IF_ERROR(ep_context_data_utils::WriteEpContextDataWithFileFallback( + ort_api, ep_context_config_.get(), ep_ctx.c_str(), fallback_ep_ctx.c_str(), fallback_graph, + ep_context_data.data(), ep_context_data.size())); + } attributes[0] = Ort::OpAttr("ep_cache_context", ep_ctx.data(), static_cast(ep_ctx.size()), ORT_OP_ATTR_STRING); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 5dcd9f07bef1f..4112abb723d39 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -6,6 +6,7 @@ #include #include "../plugin_ep_utils.h" +#include "onnxruntime_experimental_cxx_api.h" class ExampleEpFactory; @@ -61,13 +62,16 @@ class ExampleEp : public OrtEp, public ApiPtrs { public: struct Config { bool enable_ep_context = false; + bool embed_ep_context_in_model = false; bool enable_weightless_ep_context_nodes = false; + std::string ep_context_output_model_path; // Other EP configs (typically extracted from OrtSessionOptions or OrtHardwareDevice(s)) }; - ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger); + ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger, + Ort::Experimental::EpContextConfig ep_context_config); - ~ExampleEp(); + ~ExampleEp() = default; std::unordered_map>& MulKernels() { return mul_kernels_; @@ -108,7 +112,8 @@ class ExampleEp : public OrtEp, public ApiPtrs { static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr, _Outptr_ const OrtMemoryDevice** device) noexcept; - OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, + OrtStatus* CreateEpContextNodes(const OrtGraph* graph, + gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); // Returns true if the EP should save constant initializers so that they are available during inference. @@ -122,6 +127,7 @@ class ExampleEp : public OrtEp, public ApiPtrs { std::string name_; Config config_{}; const OrtLogger& logger_; + Ort::Experimental::EpContextConfig ep_context_config_; std::unordered_map> mul_kernels_; std::unordered_map> ep_context_kernels_; std::unordered_map float_initializers_; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 875f70bd29f3c..a323ec0e8c15e 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -196,6 +196,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, const OrtSessionOptions* session_options, const OrtLogger* logger, OrtEp** ep) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN auto* factory = static_cast(this_ptr); *ep = nullptr; @@ -219,20 +220,33 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, // Create EP configuration from session options, if needed. // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. std::string ep_context_enable; + std::string ep_context_embed_mode; + std::string ep_context_output_model_path; std::string weightless_ep_context_nodes_enable; RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextEnable, "0", ep_context_enable)); + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextEmbedMode, "0", + ep_context_embed_mode)); + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextFilePath, "", + ep_context_output_model_path)); RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpEnableWeightlessEpContextNodes, "0", weightless_ep_context_nodes_enable)); ExampleEp::Config config = {}; config.enable_ep_context = ep_context_enable == "1"; + config.embed_ep_context_in_model = ep_context_embed_mode == "1"; + config.ep_context_output_model_path = std::move(ep_context_output_model_path); config.enable_weightless_ep_context_nodes = weightless_ep_context_nodes_enable == "1"; - auto dummy_ep = std::make_unique(*factory, factory->ep_name_, config, *logger); - + // The EpContextConfig wrapper extracts the EPContext callbacks from the session options and owns the handle. It + // throws if the experimental functions are unavailable or extraction fails; EXCEPTION_TO_RETURNED_STATUS_END + // converts that (and any other exception thrown in this function) into an OrtStatus. + auto dummy_ep = std::make_unique( + *factory, factory->ep_name_, config, *logger, + Ort::Experimental::EpContextConfig{Ort::ConstSessionOptions{session_options}}); *ep = dummy_ep.release(); return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END } /*static*/ diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 93633f9a375bb..e95918c719324 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include +#include #include // #include #include @@ -12,11 +14,14 @@ #include "core/graph/constants.h" #include "core/graph/onnx_protobuf.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_experimental_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "nlohmann/json.hpp" +#include "test/autoep/ep_context_data_callbacks.h" #include "test/autoep/test_autoep_utils.h" +#include "test/autoep/library/ep_context_data_utils.h" #include "test/autoep/library/example_plugin_ep/ep_test_hooks.h" #include "test/shared_lib/utils.h" #include "test/util/include/api_asserts.h" @@ -29,6 +34,51 @@ namespace test { namespace { +// Invokes the experimental EPContext read setter on the public C API. +void SetEpContextDataReadFunc(Ort::SessionOptions& session_options, OrtReadNamedBufferFunc read_func, void* state) { + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&Ort::GetApi()); + ASSERT_ORTSTATUS_OK(set_read_func(session_options, read_func, state)); +} + +// Invokes the experimental EPContext write setter on the public C API. +void SetEpContextDataWriteFunc(Ort::ModelCompilationOptions& compile_options, OrtWriteNamedBufferFunc write_func, + void* state) { + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &Ort::GetApi()); + ASSERT_ORTSTATUS_OK(set_write_func(compile_options, write_func, state)); +} + +void LoadModelProtoFromFile(const ORTCHAR_T* model_file, ONNX_NAMESPACE::ModelProto& model_proto) { + std::ifstream model_stream{std::filesystem::path(model_file), std::ios::binary}; + ASSERT_TRUE(model_stream.is_open()); + ASSERT_TRUE(model_proto.ParseFromIstream(&model_stream)); +} + +std::vector GetEpContextNodes(const ONNX_NAMESPACE::ModelProto& model_proto) { + std::vector ep_context_nodes; + + for (const auto& node : model_proto.graph().node()) { + if (node.domain() == kMSDomain && node.op_type() == "EPContext") { + ep_context_nodes.push_back(&node); + } + } + + return ep_context_nodes; +} + +const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const ONNX_NAMESPACE::NodeProto& node, + std::string_view attribute_name) { + for (const auto& attribute : node.attribute()) { + if (attribute.name() == attribute_name) { + return &attribute; + } + } + + return nullptr; +} + void RunMulModelWithPluginEp(const ORTCHAR_T* model_path, const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, model_path, session_options); @@ -521,6 +571,222 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { } } +TEST(OrtEpLibrary, PluginEp_GenEpContextModel_EmbedModeDoesNotUseCallbacks) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_embedded_ctx.onnx"); + std::filesystem::remove(output_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(output_model_file); }); + + EpContextDataCallbackState write_callback_state; + EpContextDataCallbackState compile_read_callback_state; + { + Ort::SessionOptions session_options; + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataReadFunc(session_options, LoadEpContextDataCallback, &compile_read_callback_state)); + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(true); + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataWriteFunc(compile_options, StoreEpContextDataCallback, &write_callback_state)); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + } + + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + EXPECT_FALSE(write_callback_state.write_called); + EXPECT_FALSE(compile_read_callback_state.read_called); + + ONNX_NAMESPACE::ModelProto compiled_model; + ASSERT_NO_FATAL_FAILURE(LoadModelProtoFromFile(output_model_file, compiled_model)); + + auto ep_context_nodes = GetEpContextNodes(compiled_model); + ASSERT_EQ(ep_context_nodes.size(), 1u); + + const ONNX_NAMESPACE::AttributeProto* embed_mode_attr = GetNodeAttribute(*ep_context_nodes[0], "embed_mode"); + ASSERT_NE(embed_mode_attr, nullptr); + EXPECT_EQ(embed_mode_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + EXPECT_EQ(embed_mode_attr->i(), 1); + + const ONNX_NAMESPACE::AttributeProto* ep_cache_context_attr = GetNodeAttribute(*ep_context_nodes[0], + "ep_cache_context"); + ASSERT_NE(ep_cache_context_attr, nullptr); + EXPECT_EQ(ep_cache_context_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); + EXPECT_EQ(ep_cache_context_attr->s(), "binary_data"); + + EpContextDataCallbackState load_read_callback_state; + { + Ort::SessionOptions session_options; + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataReadFunc(session_options, LoadEpContextDataCallback, &load_read_callback_state)); + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, output_model_file, session_options); + } + + EXPECT_FALSE(load_read_callback_state.read_called); +} + +TEST(OrtEpLibrary, PluginEp_GenAndLoadEpContextModel_ExternalDataUsesFileFallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_file_ctx.onnx"); + std::vector files_to_cleanup{std::filesystem::path{output_model_file}}; + for (const auto& path : files_to_cleanup) { + std::filesystem::remove(path); + } + auto cleanup = gsl::finally([&]() { + for (const auto& path : files_to_cleanup) { + std::filesystem::remove(path); + } + }); + + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(false); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + } + + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + + ONNX_NAMESPACE::ModelProto compiled_model; + ASSERT_NO_FATAL_FAILURE(LoadModelProtoFromFile(output_model_file, compiled_model)); + + auto ep_context_nodes = GetEpContextNodes(compiled_model); + ASSERT_EQ(ep_context_nodes.size(), 1u); + + const ONNX_NAMESPACE::AttributeProto* embed_mode_attr = GetNodeAttribute(*ep_context_nodes[0], "embed_mode"); + ASSERT_NE(embed_mode_attr, nullptr); + EXPECT_EQ(embed_mode_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + EXPECT_EQ(embed_mode_attr->i(), 0); + + const ONNX_NAMESPACE::AttributeProto* ep_cache_context_attr = GetNodeAttribute(*ep_context_nodes[0], + "ep_cache_context"); + ASSERT_NE(ep_cache_context_attr, nullptr); + EXPECT_EQ(ep_cache_context_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); + ASSERT_FALSE(ep_cache_context_attr->s().empty()); + + const std::filesystem::path output_model_dir = std::filesystem::path{output_model_file}.parent_path(); + std::filesystem::path ep_cache_context_rel; + ASSERT_ORTSTATUS_OK( + ep_context_data_utils::Utf8Path(Ort::GetApi(), ep_cache_context_attr->s().c_str(), ep_cache_context_rel)); + const std::filesystem::path context_data_path = output_model_dir / ep_cache_context_rel; + files_to_cleanup.push_back(context_data_path); + ASSERT_TRUE(std::filesystem::exists(context_data_path)); + + std::vector context_data; + std::string context_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(Ort::GetApi(), context_data_path, + context_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataFromFile(Ort::GetApi(), context_data_file_name.c_str(), + nullptr, context_data)); + EXPECT_EQ(std::string(context_data.begin(), context_data.end()), "binary_data"); + + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, output_model_file, session_options); + } +} + +TEST(OrtEpLibrary, PluginEp_GenEpContextModel_ExternalDataUsesWriteCallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_external_ctx.onnx"); + std::filesystem::remove(output_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(output_model_file); }); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + EpContextDataCallbackState callback_state; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(false); + ASSERT_NO_FATAL_FAILURE(SetEpContextDataWriteFunc(compile_options, StoreEpContextDataCallback, &callback_state)); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + ASSERT_TRUE(callback_state.write_called); + EXPECT_FALSE(callback_state.write_file_name.empty()); + EXPECT_EQ(std::string(callback_state.payload.begin(), callback_state.payload.end()), "binary_data"); +} + +TEST(OrtEpLibrary, PluginEp_LoadEpContextModel_ExternalDataUsesReadCallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* compiled_model_file = ORT_TSTR("plugin_ep_mul_1_external_ctx_load.onnx"); + std::filesystem::remove(compiled_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(compiled_model_file); }); + + EpContextDataCallbackState write_callback_state; + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(compiled_model_file); + compile_options.SetEpContextEmbedMode(false); + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataWriteFunc(compile_options, StoreEpContextDataCallback, &write_callback_state)); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(compiled_model_file)); + ASSERT_TRUE(write_callback_state.write_called); + } + + EpContextDataCallbackState read_callback_state; + read_callback_state.payload = write_callback_state.payload; + { + Ort::SessionOptions session_options; + ASSERT_NO_FATAL_FAILURE(SetEpContextDataReadFunc(session_options, LoadEpContextDataCallback, &read_callback_state)); + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, compiled_model_file, session_options); + } + + ASSERT_TRUE(read_callback_state.read_called); + EXPECT_EQ(read_callback_state.read_file_name, write_callback_state.write_file_name); +} + TEST(OrtEpLibrary, PluginEp_GenWeightlessEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); diff --git a/onnxruntime/test/autoep/test_model_package.cc b/onnxruntime/test/autoep/test_model_package.cc index ee5c8bb567e1e..c9065878da1ce 100644 --- a/onnxruntime/test/autoep/test_model_package.cc +++ b/onnxruntime/test/autoep/test_model_package.cc @@ -621,6 +621,11 @@ TEST(ModelPackageTest, CheckCompiledModelCompatibilityInfo) { compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); + // Embed the EPContext binary data inside the compiled model so the model is self-contained. + // This test copies only the compiled .onnx into the model package, so it must not rely on a + // separate sidecar EPContext data file (which non-embedded mode would produce). + compile_options.SetEpContextEmbedMode(true); + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); ASSERT_TRUE(std::filesystem::exists(output_model_file)); } diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index d2001cfb9f2bd..e9e7e20271090 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -446,6 +446,29 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { tester.RunWithConfig(); } +TEST(BeamSearchTest, DummyWhisperWithSequenceInputIds) { + // dummy_whisper_with_sequence_input_ids.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_whisper_model_generator.py \ + // --output-path dummy_whisper_with_sequence_input_ids.onnx --sequence-as-input + // The decoder subgraph leaves input_ids second dim symbolic, so the decoder feeds are built from the + // running sequence (use_sequence_as_input_ids_ == true), exercising the multi-token initial feed path. + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_whisper_with_sequence_input_ids.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("input_features", {1, 8, 5}, + {-0.3f, -0.2f, -0.1f, 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, -0.1f, + 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, -0.1f, 0.0f, 0.1f, 0.2f, + 0.3f, -0.3f, -0.2f, -0.1f, 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, + -0.1f, 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, -0.1f, 0.0f, 0.1f}); + tester.AddInput("decoder_input_ids", {1, 2}, {2, 5}); + tester.AddOutput("sequences", {1, 1, 10}, {2, 5, 1, 1, 1, 1, 1, 1, 1, 1}); + tester.AddOutput("scores", {1, 1}, {-0.05625312775373459f}, false /* sort_output */, 1e-4f /* rel_error */, + 1e-4f /* abs_error */); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + TEST(BeamSearchTest, DummyT5PointerGenerator) { // dummy_t5_pointer_generator.onnx model generated using following command: // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 645564d01abc0..28c41b5bf5ed4 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -1569,6 +1569,109 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_Prompt_CUDA) { ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Rotary_Prompt_CUDA_vs_CPU"); } +// CUDA: out-of-range (negative) seqlens_k must not drive an out-of-bounds KV-cache write. +// On the CUDA EP seqlens_k is device-resident, so the host-side range check in the operator is +// skipped and the derived append offset is clamped on the device instead. With sequence_length > 1 +// the non-fast-decode path is taken, exercising both the derived-length clamp and the cache-store +// bound. The run must complete and yield finite outputs. This is a memory-safety regression that is +// most precisely observed under compute-sanitizer, where the pre-clamp code reported an invalid +// device write at this site. +TEST(GroupQueryAttentionTest, NegativeSeqlensK_CacheAppend_NoOOB_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + + constexpr int batch_size = 1; + constexpr int sequence_length = 2; // > 1 forces the non-fast-decode path + constexpr int past_seq_len = 4; + constexpr int num_heads = 2; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; // must be a multiple of 16 for rotary + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int total_sequence_length = past_seq_len + sequence_length; + constexpr int present_seq_len = total_sequence_length; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + std::vector query_data(batch_size * sequence_length * hidden_size); + std::vector key_data(batch_size * sequence_length * kv_hidden_size); + std::vector value_data(batch_size * sequence_length * kv_hidden_size); + std::vector past_key_data(batch_size * kv_num_heads * past_seq_len * head_size); + std::vector past_value_data(batch_size * kv_num_heads * past_seq_len * head_size); + for (size_t i = 0; i < query_data.size(); ++i) query_data[i] = 0.05f * static_cast(i % 7 + 1); + for (size_t i = 0; i < key_data.size(); ++i) key_data[i] = 0.04f * static_cast(i % 5 + 1); + for (size_t i = 0; i < value_data.size(); ++i) value_data[i] = 0.03f * static_cast(i % 3 + 1); + for (size_t i = 0; i < past_key_data.size(); ++i) past_key_data[i] = 0.02f * static_cast(i % 11 + 1); + for (size_t i = 0; i < past_value_data.size(); ++i) past_value_data[i] = 0.01f * static_cast(i % 13 + 1); + + tester.AddInput("query", {batch_size, sequence_length, hidden_size}, ToFloat16(query_data)); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, ToFloat16(key_data)); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, ToFloat16(value_data)); + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, ToFloat16(past_key_data)); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, + ToFloat16(past_value_data)); + + // seqlens_k is negative, so the derived past length, (max(seqlens_k, 0) + 1) - sequence_length, is + // negative (here 0 + 1 - 2 = -1). The device-side derivation must neutralize this so the cache append + // for the new tokens stays within the present buffer instead of indexing before its start. + tester.AddInput("seqlens_k", {batch_size}, {-1}); + // Marked as an initializer so shape inference can read the value at graph-build time and size + // present_kv to max(past_seq_len, total_sequence_length), matching the declared present outputs below. + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true); + + const int max_seq_len = total_sequence_length + 8; + const int half_rotary = head_size / 2; + std::vector cos_cache(max_seq_len * half_rotary); + std::vector sin_cache(max_seq_len * half_rotary); + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int d = 0; d < half_rotary; ++d) { + const float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / static_cast(head_size)); + cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq); + sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq); + } + } + tester.AddInput("cos_cache", {max_seq_len, half_rotary}, ToFloat16(cos_cache)); + tester.AddInput("sin_cache", {max_seq_len, half_rotary}, ToFloat16(sin_cache)); + + // Valid position_ids so the rotary index path is well-formed and only the cache-store bound is stressed. + std::vector position_ids(batch_size * sequence_length); + for (int s = 0; s < sequence_length; ++s) { + position_ids[s] = static_cast(past_seq_len + s); + } + tester.AddInput("position_ids", {batch_size, sequence_length}, position_ids); + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, MLFloat16(0.0f))); + const int present_size = batch_size * kv_num_heads * present_seq_len * head_size; + tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(present_size, MLFloat16(0.0f))); + tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(present_size, MLFloat16(0.0f))); + + // The malformed seqlens_k drives the derived past length negative, which is the condition under test. + // That leaves the KV length under-specified for the query, so the attention is degenerate and its + // outputs may be non-finite; this is expected and intentionally not asserted. The regression point is + // that the cache append and attention complete without indexing outside their buffers (which a + // sanitizer build would otherwise flag), so only the output shape is verified. + tester.SetOutputTolerance(1e6f); + tester.SetCustomOutputVerifier([](const std::vector& fetches, + const std::string& /*provider*/) { + ASSERT_FALSE(fetches.empty()); + ASSERT_TRUE(fetches[0].IsTensor()); + EXPECT_EQ(fetches[0].Get().Shape().Size(), static_cast(output_size)); + }); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // --------------------------------------------------------------------------- // Quantized KV cache tests for CPU GroupQueryAttention // --------------------------------------------------------------------------- @@ -2510,7 +2613,8 @@ static std::vector RunGQAPackedQKVRotaryPrefill( int head_size, const std::vector& seqlens_k_data, const std::vector& packed_qkv_data, - GqaTargetEp target_ep = GqaTargetEp::kCpu) { + GqaTargetEp target_ep = GqaTargetEp::kCpu, + bool smooth_softmax = false) { const int hidden_size = num_heads * head_size; const int kv_hidden_size = kv_num_heads * head_size; const int qkv_hidden = hidden_size + 2 * kv_hidden_size; @@ -2529,6 +2633,12 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); tester.AddAttribute("do_rotary", static_cast(1)); + if (smooth_softmax) { + // smooth_softmax disqualifies the WebGPU FlashAttention path via the outer + // gating in GroupQueryAttention::ComputeInternal, routing this case through + // ApplyAttention instead. + tester.AddAttribute("smooth_softmax", static_cast(1)); + } // Packed QKV: pass through `query` input, leave key/value as optional edges. if (use_fp16) { @@ -2619,7 +2729,9 @@ static std::vector RunGQAPackedQKVRotaryPrefill( // output matches its single-prompt reference. Both reference and batched runs // go through the same EP, so this validates per-batch consistency within each // EP rather than cross-EP equivalence. -static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { +static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, + const std::vector& real_lens = {4, 2, 6}, + bool smooth_softmax = false) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2628,10 +2740,10 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { constexpr int kv_hidden_size = kv_num_heads * head_size; constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; - // Real prompt lengths per batch; max = sequence_length (right-padding extends + // Per-batch real prompt lengths; max = sequence_length (right-padding extends // shorter batches up to this length). The bug only manifests when at least // one batch is shorter than sequence_length. - const std::vector real_lens = {4, 2, 6}; + ASSERT_EQ(static_cast(real_lens.size()), batch_size); const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); std::vector packed_batched; @@ -2652,7 +2764,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { /*batch_size=*/1, /*sequence_length=*/real_len, num_heads, kv_num_heads, head_size, /*seqlens_k_data=*/{static_cast(real_len - 1)}, - packed_single, target_ep); + packed_single, target_ep, smooth_softmax); } // Now run all batches together with right-padding. @@ -2662,7 +2774,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { } const auto batched_output = RunGQAPackedQKVRotaryPrefill( batch_size, sequence_length, num_heads, kv_num_heads, head_size, - seqlens_k_data, packed_batched, target_ep); + seqlens_k_data, packed_batched, target_ep, smooth_softmax); // Guard the regression deterministically: every element of the batched output // (including padding rows) must be finite. The CPU root cause is uninitialized @@ -2715,5 +2827,50 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu); } +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with per-batch +// real_lens whose max crosses the prefill threshold (sequence_length >= 32) so +// the WebGPU EP picks FlashAttentionProgram (single-kernel prefill path with +// subgroup shuffles) instead of the split-reduce decode path. This exercises +// the prefill flash-attention kernel under right-padded batches with do_rotary, +// which is the path used by Phi-style models during batched prefill. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // sequence_length = max(real_lens) = 33 > 32 -> FlashAttentionProgram path. + // Mixed shorter batches (12, 20) ensure right-padding is non-trivial. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 33}); +} + +// Stress the FlashAttention prefill path with a per-batch spread that exceeds +// the indirect-dispatch tile size (64). batch 0 has the SHORTEST real length; +// batch 2 has the LONGEST. This is the data pattern that would surface the +// indirect-dispatch undersizing bug when graph capture is enabled (where the +// dispatch grid is sized from a GPU buffer rather than the host scalar). +// OpTester does not toggle graph capture, so this test exercises the new +// total_sequence_length_input shader plumbing on the non-graph-capture path; +// the graph-capture path is covered end-to-end by phi4-graph-prune verification. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // spread = 96 - 20 = 76 > tile_size(64), batch 0 is not the max. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 96}); +} + +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with +// smooth_softmax=1 so the WebGPU EP bypasses CanApplyFlashAttention and routes +// through ApplyAttention (non-flash path). Covers right-padded batched prefill +// on the non-flash attention path (used by e.g. Phi-4 attention variants). +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {4, 2, 6}, /*smooth_softmax=*/true); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 9deb064a90853..8e133caa15d55 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -3,6 +3,7 @@ #ifndef ORT_MINIMAL_BUILD +#include #include #include "gtest/gtest.h" @@ -26,6 +27,9 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" +#include "core/graph/model.h" +#include "test/util/include/inference_session_wrapper.h" +#include "test/util/include/test/test_environment.h" #include "core/providers/webgpu/webgpu_provider_options.h" #ifdef USE_WEBGPU #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" @@ -461,6 +465,117 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256_Bias) { TestMatMul2BitsLutGemm(32, 256, 256, 32, /*has_zero_point=*/true, /*has_bias=*/true); } +// Regression test for the LUT GEMM pre-pack + prepacked-save path. A 2-bit MatMulNBits node pre-packed +// via the LUT path must record its packed B buffer exactly once. A prior bug appended packed_b_ twice +// on the LUT path (inside the LUT branch and again in the shared append at the end of the B block), so +// the second entry was a moved-from/null buffer paired with a non-zero packed_b_size_. The pre-packed +// content hash skips null buffers, so cross-session sharing appeared to work, but saving pre-packed +// initializers iterates every recorded buffer and writes buffer_sizes_[i] bytes from buffers_[i].get(), +// dereferencing the null pointer when mlas.use_lut_gemm=1. This drives mlas.use_lut_gemm=1 together with +// session.save_external_prepacked_constant_initializers=1 and a non-empty optimized_model_filepath, and +// asserts that initialization (which performs the save) and a subsequent run both succeed. +TEST(MatMulNBitsLutGemm, Float32_2Bits_PrepackSaveDoesNotCrash) { + constexpr int64_t M = 1, N = 128, K = 128, block_size = 32; + if (!MlasIsLutGemmAvailable(static_cast(N), static_cast(K), 2, static_cast(block_size))) { + GTEST_SKIP() << "LUT GEMM not available on this platform"; + } + + // Quantize random weights into valid 2-bit MatMulNBits B/scales/zero_points initializers. + RandomValueGenerator random{1234}; + std::vector b_fp32(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); + + int q_rows = 0, q_cols = 0; + MlasBlockwiseQuantizedShape(static_cast(block_size), /*columnwise*/ true, + static_cast(K), static_cast(N), q_rows, q_cols); + size_t q_data_size_in_bytes = 0, q_scale_size = 0, q_zp_size_in_bytes = 0; + MlasBlockwiseQuantizedBufferSizes(static_cast(block_size), /*columnwise*/ true, + static_cast(K), static_cast(N), + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + std::vector b_data(q_data_size_in_bytes); + std::vector scales(q_scale_size); + std::vector zp(q_zp_size_in_bytes); + + auto& ortenv = **ort_env.get(); + onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); + MlasQuantizeBlockwise(b_data.data(), scales.data(), zp.data(), b_fp32.data(), + static_cast(block_size), /*columnwise*/ true, + static_cast(K), static_cast(N), + static_cast(N), tp); + + // Single-node MatMulNBits model: A is a runtime input; B/scales/zero_points are constant initializers + // (so they are pre-packed at session initialization). + const int64_t k_blocks = (K + block_size - 1) / block_size; + const std::unordered_map domain_to_version{{"", 21}, {kMSDomain, 1}}; + Model model("matmul_2bits_lut_prepack_save", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + + ONNX_NAMESPACE::TypeProto float_2d; + float_2d.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); + float_2d.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(M); + float_2d.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(K); + NodeArg* A = &graph.GetOrCreateNodeArg("A", &float_2d); + NodeArg* Y = &graph.GetOrCreateNodeArg("Y", nullptr); + + NodeArg* B = builder.MakeInitializer( + {static_cast(q_cols), k_blocks, static_cast(q_rows) / k_blocks}, b_data); + NodeArg* scales_arg = builder.MakeInitializer({N, static_cast(q_scale_size) / N}, scales); + NodeArg* zero_points = + builder.MakeInitializer({N, static_cast(q_zp_size_in_bytes) / N}, zp); + + Node& node = builder.AddNode("MatMulNBits", {A, B, scales_arg, zero_points}, {Y}, kMSDomain); + node.AddAttribute("K", K); + node.AddAttribute("N", N); + node.AddAttribute("block_size", block_size); + node.AddAttribute("bits", static_cast(QBits)); + node.AddAttribute("accuracy_level", static_cast(0)); + + graph.SetOutputs(std::vector{Y}); + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_bytes; + ASSERT_TRUE(model.ToProto().SerializeToString(&model_bytes)); + + // Save the optimized model + pre-packed initializers into a unique temp dir. Writing the prepacked + // initializers is the path that dereferenced the duplicate null buffer before the fix. + namespace fs = std::filesystem; + const fs::path tmp_dir = fs::temp_directory_path() / "ort_matmul2bits_lut_prepack_save_test"; + std::error_code ec; + fs::remove_all(tmp_dir, ec); + ASSERT_TRUE(fs::create_directories(tmp_dir, ec)) << ec.message(); + const fs::path optimized_model_path = tmp_dir / "optimized.onnx"; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1")); + so.optimized_model_filepath = optimized_model_path.native(); + + std::vector fetches; + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + // Initialization performs the LUT pre-pack and writes the optimized model with external + // pre-packed initializers. Before the fix this dereferenced the duplicate null packed buffer. + ASSERT_STATUS_OK(session.Initialize()); + + auto cpu_allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + std::vector a_data = random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f); + OrtValue a_value; + CreateMLValue(cpu_allocator, AsSpan({M, K}), a_data, &a_value); + NameMLValMap feeds{{"A", a_value}}; + + ASSERT_STATUS_OK(session.Run(RunOptions{}, feeds, std::vector{"Y"}, &fetches)); + } + + ASSERT_EQ(fetches.size(), static_cast(1)); + EXPECT_TRUE(fs::exists(optimized_model_path)); + + fs::remove_all(tmp_dir, ec); +} + // Float zero point tests — directed QAD scenario (zp=1.5) void RunTest2BitsFloatZP(int64_t M, int64_t N, int64_t K, int64_t block_size, float zp_value) { RandomValueGenerator random{1234}; diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 07b275b813aa7..bedf035d320f8 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -22,10 +22,14 @@ #include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" +#include "test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" #include "core/providers/webgpu/webgpu_provider_options.h" +#include "core/framework/prepacked_weights_container.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "test/util/include/test/test_environment.h" extern std::unique_ptr ort_env; @@ -87,6 +91,10 @@ struct TestOptions { bool legacy_shape{false}; // for backward compatibility + // When set, RunTest validates cross-session sharing of the pre-packed weights instead of doing a + // single run. The model is run in two sessions that use the same pre-packed weights container. + std::optional prepack_sharing_mode{}; + std::optional output_abs_error{}; std::optional output_rel_error{}; }; @@ -269,6 +277,13 @@ void RunTest(const TestOptions& opts, test.SetOutputRelErr("Y", *opts.output_rel_error); } + if (opts.prepack_sharing_mode.has_value()) { + // Pre-packed weight sharing is a CPU-EP-only feature; the helper runs the model on the CPU EP + // in two sessions and validates the sharing counters. + CheckSharedPrepackedWeights(test, *opts.prepack_sharing_mode, {N, k_blocks, blob_size}, input1_vals); + return; + } + if (!explicit_eps.empty()) { test.ConfigEps(std::move(explicit_eps)); } @@ -597,6 +612,55 @@ TEST(MatMulNBits, Float32_4b_Accuracy4_Batch) { RunTest(opts); } +#ifndef ENABLE_TRAINING +// Pre-packing (and therefore cross-session sharing of pre-packed weights) is disabled in a full +// training build, so there is nothing to exercise there. + +namespace { +// Builds a representative MatMulNBits TestOptions for the pre-packed weight sharing tests. +TestOptions MakeSharingTestOptions(int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, + bool has_zero_point, bool has_bias, PrepackSharingMode mode) { + TestOptions opts{}; + opts.M = 8; + opts.N = N; + opts.K = K; + opts.block_size = block_size; + opts.accuracy_level = accuracy_level; + opts.has_zero_point = has_zero_point; + opts.zp_is_4bit = true; + opts.has_bias = has_bias; + opts.prepack_sharing_mode = mode; + opts.output_abs_error = 0.1f; + opts.output_rel_error = 0.02f; + return opts; +} +} // namespace + +// Legacy sharing path: the weight B is registered as a shared initializer via +// SessionOptions::AddInitializer. Covers float and float16 activations, symmetric/asymmetric, +/- bias. +TEST(MatMulNBits, SharedPrepackedWeights_AddInitializer) { + for (bool has_zero_point : {false, true}) { + for (bool has_bias : {false, true}) { + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, has_zero_point, + has_bias, PrepackSharingMode::kAddInitializer)); + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, has_zero_point, + has_bias, PrepackSharingMode::kAddInitializer)); + } + } +} + +// Negative control: with the shared container present but neither opt-in mechanism enabled, no +// pre-packed weights are shared across sessions. +TEST(MatMulNBits, SharedPrepackedWeights_NotSharedWithoutOptIn) { + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, /*has_zero_point*/ true, + /*has_bias*/ true, PrepackSharingMode::kNoSharing)); + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, + /*has_zero_point*/ false, /*has_bias*/ false, + PrepackSharingMode::kNoSharing)); +} + +#endif // !ENABLE_TRAINING + #endif #endif diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index f99334c4f33ef..411e83536c190 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -21,6 +21,7 @@ #include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" +#include "test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -51,6 +52,10 @@ struct TestOptions8Bits { bool has_g_idx{false}; bool has_bias{false}; + // When set, RunTest8Bits validates cross-session sharing of the pre-packed weights instead of + // doing a single run. The model is run in two CPU sessions that use the same container. + std::optional prepack_sharing_mode{}; + std::optional output_abs_error{}; std::optional output_rel_error{}; }; @@ -221,6 +226,14 @@ void RunTest8Bits(const TestOptions8Bits& opts) { test.SetOutputRelErr("Y", *opts.output_rel_error); } + if (opts.prepack_sharing_mode.has_value()) { + // Pre-packed weight sharing is a CPU-EP-only feature; the helper runs the model on the CPU EP + // in two sessions and validates the sharing counters. + CheckSharedPrepackedWeights(test, *opts.prepack_sharing_mode, + {q_cols, k_blocks, q_rows / k_blocks}, input1_vals); + return; + } + std::vector> execution_providers; #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); @@ -671,6 +684,56 @@ TEST(MatMulNBits, BFloat16_Int8_Chunked_BFloat16ZeroPoint) { } #endif +#if !defined(USE_CUDA) && !defined(USE_WEBGPU) +#ifndef ENABLE_TRAINING +// Pre-packing (and therefore cross-session sharing of pre-packed weights) is disabled in a full +// training build and is only implemented for the CPU EP, so these tests are CPU-only. + +namespace { +// Builds a representative 8-bit MatMulNBits TestOptions for the pre-packed weight sharing tests. +// accuracy_level 4 selects the int8 compute type (SQNBIT_CompInt8 / HQNBIT_CompInt8), which is the +// 8-bit path that pre-packs the quantized B weight. +TestOptions8Bits MakeSharingTestOptions8Bits(int64_t block_size, bool has_zero_point, bool has_bias, + PrepackSharingMode mode) { + TestOptions8Bits opts{}; + opts.M = 8; + opts.N = 32; + opts.K = 256; + opts.block_size = block_size; + opts.accuracy_level = 4; + opts.has_zero_point = has_zero_point; + opts.has_bias = has_bias; + opts.prepack_sharing_mode = mode; + opts.output_abs_error = 0.1f; + opts.output_rel_error = 0.02f; + return opts; +} +} // namespace + +// Legacy sharing path for 8-bit weights: B is registered as a shared initializer via +// SessionOptions::AddInitializer. +TEST(MatMulNBits, SharedPrepackedWeights_8b_AddInitializer) { + for (bool has_zero_point : {false, true}) { + for (bool has_bias : {false, true}) { + RunTest8Bits(MakeSharingTestOptions8Bits(32, has_zero_point, has_bias, + PrepackSharingMode::kAddInitializer)); + RunTest8Bits(MakeSharingTestOptions8Bits(32, has_zero_point, has_bias, + PrepackSharingMode::kAddInitializer)); + } + } +} + +// Negative control for 8-bit weights: with the shared container present but neither opt-in mechanism +// enabled, no pre-packed weights are shared across sessions. +TEST(MatMulNBits, SharedPrepackedWeights_8b_NotSharedWithoutOptIn) { + RunTest8Bits(MakeSharingTestOptions8Bits(32, /*has_zero_point*/ true, /*has_bias*/ true, + PrepackSharingMode::kNoSharing)); + RunTest8Bits(MakeSharingTestOptions8Bits(32, /*has_zero_point*/ false, /*has_bias*/ false, + PrepackSharingMode::kNoSharing)); +} +#endif // !ENABLE_TRAINING +#endif // !USE_CUDA && !USE_WEBGPU + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc new file mode 100644 index 0000000000000..97566afe02489 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +#include "core/framework/tensor.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime { +namespace test { + +void CheckSharedPrepackedWeights(OpTester& test, PrepackSharingMode mode, + const std::vector& b_dims, + std::vector& b_data) { + SessionOptions so; + OrtValue b_ortvalue; + + switch (mode) { + case PrepackSharingMode::kAddInitializer: + // Register B as an explicitly shared initializer (the pre-existing sharing mechanism). + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(b_dims), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b_ortvalue); + ASSERT_STATUS_OK(so.AddInitializer("B", &b_ortvalue)); + break; + case PrepackSharingMode::kNoSharing: + // Neither opt-in mechanism is used. + break; + } + + // Have all sessions created by this OpTester use the same pre-packed weights container. + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited to the CPU EP, so the sharing behavior is only exercised there. + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_1, + &number_of_shared_pre_packed_weights_counter); + // Nothing can be shared yet because this is the first session. + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + const auto number_of_elements_in_shared_container = test.GetNumPrePackedWeightsShared(); + + if (mode == PrepackSharingMode::kNoSharing) { + // Without opting in, pre-packed weights must not be placed in the shared container. + ASSERT_EQ(number_of_elements_in_shared_container, static_cast(0)); + } + + // On some platforms/architectures MLAS may choose not to pre-pack, in which case there is nothing + // to share and we cannot meaningfully continue. + if (number_of_pre_packed_weights_counter_session_1 == 0) { + return; + } + + if (mode != PrepackSharingMode::kNoSharing) { + // At least the quantized weight B is content-addressed into the shared container. Some + // architectures (e.g. ARM64 KleidiAI) additionally pre-pack scales, but in the AddInitializer + // mode only the explicitly-registered B participates, so the container can hold fewer elements + // than the total number of pre-packed weights. + ASSERT_GT(number_of_elements_in_shared_container, static_cast(0)); + ASSERT_LE(number_of_elements_in_shared_container, number_of_pre_packed_weights_counter_session_1); + } + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_2, + &number_of_shared_pre_packed_weights_counter); + + // The same number of weights is pre-packed in both sessions. + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Every weight stored in the shared container is served from it (i.e. shared) in the second + // session. For the no-sharing control this is zero; otherwise it matches the container size. + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, number_of_elements_in_shared_container); + + if (mode == PrepackSharingMode::kNoSharing) { + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } else { + ASSERT_GT(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h new file mode 100644 index 0000000000000..1de0bbaa4bb85 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include + +namespace onnxruntime { +namespace test { + +class OpTester; + +// How two sessions are configured to share the pre-packed weights of a MatMulNBits node. +enum class PrepackSharingMode { + // Legacy path: the weight is explicitly registered as a shared initializer via + // SessionOptions::AddInitializer. + kAddInitializer, + // Negative control: the shared container exists but neither opt-in mechanism is used, so no + // cross-session sharing must happen. + kNoSharing, +}; + +// Runs the already-configured MatMulNBits OpTester in two CPU sessions that share the same +// pre-packed weights container and asserts that the pre-packed weights are shared as expected. +// This logic is independent of the weight bit width, so it is shared by the 4-bit and 8-bit tests. +// `b_dims`/`b_data` describe the quantized B initializer and are only needed for the +// PrepackSharingMode::kAddInitializer path (to register B as a shared initializer). +void CheckSharedPrepackedWeights(OpTester& test, PrepackSharingMode mode, + const std::vector& b_dims, + std::vector& b_data); + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc b/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc index 71d8c34353f02..d7953442d738e 100644 --- a/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc @@ -24,10 +24,10 @@ namespace test { namespace { -void RunSparseAttentionInvalidInputTest(const std::vector& total_key_lengths_data, - const std::vector& total_key_lengths_dims, - const std::string& expected_error, - int32_t total_sequence_length = 4) { +void RunSparseAttentionInvalidKeyLengthsTest(const std::vector& total_key_lengths_data, + const std::vector& total_key_lengths_dims, + const std::string& expected_error, + int32_t total_sequence_length = 4) { OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); test.AddAttribute("num_heads", 2); test.AddAttribute("kv_num_heads", 2); @@ -42,7 +42,7 @@ void RunSparseAttentionInvalidInputTest(const std::vector& total_key_le test.AddInput("past_key", {1, 2, 4, 8}, std::vector(64, 0.0f)); test.AddInput("past_value", {1, 2, 4, 8}, std::vector(64, 0.0f)); test.AddInput("block_row_indices", {1, 5}, {0, 1, 2, 3, 4}); - test.AddInput("block_col_indices", {1, 1}, {0}); + test.AddInput("block_col_indices", {1, 4}, {0, 1, 2, 3}); test.AddInput("total_sequence_length", {1}, {total_sequence_length}); test.AddInput("key_total_sequence_lengths", total_key_lengths_dims, total_key_lengths_data); test.AddOptionalInputEdge(); @@ -209,17 +209,17 @@ void RunSparseAttentionPromptInputTest(const std::vector& total_key_len } // namespace TEST(SparseAttentionTest, RejectsOutOfRangeKeyTotalSequenceLengths) { - RunSparseAttentionInvalidInputTest({-5}, {1}, "key_total_sequence_lengths value -5 at batch index 0 is out of range [1, 4]"); + RunSparseAttentionInvalidKeyLengthsTest({-5}, {1}, "key_total_sequence_lengths value -5 at batch index 0 is out of range [1, 4]"); } TEST(SparseAttentionTest, RejectsKeyTotalSequenceLengthsShapeMismatch) { - RunSparseAttentionInvalidInputTest({4, 4}, {2}, "key_total_sequence_lengths must have shape (batch_size)"); + RunSparseAttentionInvalidKeyLengthsTest({4, 4}, {2}, "key_total_sequence_lengths must have shape (batch_size)"); } TEST(SparseAttentionTest, RejectsPromptKeyTotalSequenceLengthsShorterThanSequenceLength) { - RunSparseAttentionInvalidInputTest({0}, {1}, - "key_total_sequence_lengths value 0 at batch index 0 is out of range [1, 1]", - 1); + RunSparseAttentionInvalidKeyLengthsTest({0}, {1}, + "key_total_sequence_lengths value 0 at batch index 0 is out of range [1, 1]", + 1); } TEST(SparseAttentionTest, AcceptsPromptKeyTotalSequenceLengthsForPaddedBatch) { @@ -258,5 +258,278 @@ TEST(SparseAttentionTest, RejectsZeroDimBlockRowIndices) { {}, nullptr, &execution_providers); } +// Helper for CSR value-validation tests. +// Uses: num_heads=2, kv_num_heads=2, sparse_block_size=16, head_size=8. +// block_row_indices shape: (1, max_blocks+1), block_col_indices shape: (1, col_count). +// max_sequence_length = max_blocks * 16 must be >= total_sequence_length. +// These tests validate that element values in block_row_indices and block_col_indices are checked. +// Note: these tests expect failure via a returned Status (ORT_MAKE_STATUS), so they are safe in +// both exceptions-enabled and no-exceptions builds. +static void RunSparseAttentionCSRValidationTest( + const std::vector& block_row_indices_data, + const std::vector& block_row_indices_dims, + const std::vector& block_col_indices_data, + const std::vector& block_col_indices_dims, + const std::string& expected_error) { + OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); + test.AddAttribute("num_heads", 2); + test.AddAttribute("kv_num_heads", 2); + test.AddAttribute("sparse_block_size", 16); + test.AddAttribute("scale", 1.0f); + test.AddAttribute("do_rotary", 0); + test.AddAttribute("rotary_interleaved", 0); + + // head_size=8, num_heads=2 => hidden_size=16 + // sequence_length=1, batch_size=1 + test.AddInput("query", {1, 1, 16}, std::vector(16, 0.0f)); + test.AddInput("key", {1, 1, 16}, std::vector(16, 0.0f)); + test.AddInput("value", {1, 1, 16}, std::vector(16, 0.0f)); + // past_key/value: (batch_size=1, kv_num_heads=2, max_cache_seq_len=32, head_size=8) + test.AddInput("past_key", {1, 2, 32, 8}, std::vector(512, 0.0f)); + test.AddInput("past_value", {1, 2, 32, 8}, std::vector(512, 0.0f)); + test.AddInput("block_row_indices", block_row_indices_dims, block_row_indices_data); + test.AddInput("block_col_indices", block_col_indices_dims, block_col_indices_data); + test.AddInput("total_sequence_length", {1}, {2}); + test.AddInput("key_total_sequence_lengths", {1}, {2}); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + test.AddOutput("output", {1, 1, 16}, std::vector(16, 0.0f)); + test.AddOutput("present_key", {1, 2, 32, 8}, std::vector(512, 0.0f)); + test.AddOutput("present_value", {1, 2, 32, 8}, std::vector(512, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, expected_error, {}, nullptr, &execution_providers); +} + +// block_row_indices[0][0] must be 0. +TEST(SparseAttentionTest, RejectsBlockRowIndicesFirstElementNonZero) { + // shape (1, 3) => max_blocks=2, max_sequence_length=32 + RunSparseAttentionCSRValidationTest( + {1, 1, 2}, {1, 3}, // row indices: first element is 1, not 0 + {0, 1}, {1, 2}, // col indices: valid + "block_row_indices[0][0] must be 0"); +} + +// block_row_indices must be monotonically non-decreasing. +TEST(SparseAttentionTest, RejectsBlockRowIndicesNonMonotonic) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 2, 1}, {1, 3}, // row indices: 2 > 1 at row 1 (non-monotonic) + {0, 1}, {1, 2}, // col indices: valid + "block_row_indices values are not monotonically non-decreasing"); +} + +// block_row_indices values must not exceed block_col_indices column count. +TEST(SparseAttentionTest, RejectsBlockRowIndicesExceedsColCount) { + // shape (1, 3) => max_blocks=2, col_count=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 3}, {1, 3}, // row indices: last element 3 > col_count=2 + {0, 1}, {1, 2}, // col indices shape (1, 2) + "block_row_indices values are not monotonically non-decreasing"); +} + +// block_col_indices values must be in [0, max_blocks). +TEST(SparseAttentionTest, RejectsBlockColIndicesOutOfRange) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 2}, {1, 3}, // row indices: valid + {0, 2}, {1, 2}, // col indices: value 2 >= max_blocks=2 + "block_col_indices[0][1]=2 is out of valid range [0, 2)"); +} + +// block_col_indices negative values must be rejected. +TEST(SparseAttentionTest, RejectsBlockColIndicesNegative) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 2}, {1, 3}, // row indices: valid + {0, -1}, {1, 2}, // col indices: negative value + "block_col_indices[0][1]=-1 is out of valid range [0, 2)"); +} + +// block_row_indices with negative values. +TEST(SparseAttentionTest, RejectsBlockRowIndicesNegative) { + RunSparseAttentionCSRValidationTest( + {0, -1, 2}, {1, 3}, // row indices: negative value at index 1 + {0, 1}, {1, 2}, // col indices: valid + "block_row_indices values are not monotonically non-decreasing"); +} + +// block_col_indices with large OOB value (the original vulnerability scenario). +TEST(SparseAttentionTest, RejectsBlockColIndicesLargeValue) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 2, 2}, {1, 3}, // row indices: valid CSR format + {0, 1048576}, {1, 2}, // col indices: 0x100000 far out of range + "block_col_indices[0][1]=1048576 is out of valid range [0, 2)"); +} + +// Multi-layout: invalid col index in second layout only. +TEST(SparseAttentionTest, RejectsBlockColIndicesInvalidInSecondLayout) { + // shape (2, 3) => num_layout=2, max_blocks=2 + // num_heads=2, so num_heads % num_layout == 0 + RunSparseAttentionCSRValidationTest( + {0, 1, 2, 0, 1, 2}, {2, 3}, // row indices: valid for both layouts + {0, 1, 0, 5}, {2, 2}, // col indices: layout 0 valid, layout 1 has 5 >= max_blocks=2 + "block_col_indices[1][1]=5 is out of valid range [0, 2)"); +} + +// Multi-layout: invalid row pointer in second layout only. +TEST(SparseAttentionTest, RejectsBlockRowIndicesInvalidInSecondLayout) { + // shape (2, 3) => num_layout=2, max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 2, 1, 1, 2}, {2, 3}, // row indices: layout 0 valid, layout 1 starts with 1 != 0 + {0, 1, 0, 1}, {2, 2}, // col indices: valid + "block_row_indices[1][0] must be 0"); +} + +// Col index invalid within NNZ range but padding would be fine. +// row pointers say nnz=1, col[0] is invalid, col[1] is padding (not checked). +TEST(SparseAttentionTest, RejectsBlockColIndicesInvalidWithinNNZ) { + // shape (1, 3) => max_blocks=2, row indices: {0, 1, 1} means row 0 has 1 entry, row 1 has 0 + // nnz = r[max_blocks] = r[2] = 1, so only col[0] is validated + RunSparseAttentionCSRValidationTest( + {0, 1, 1}, {1, 3}, // row indices: valid, nnz=1 + {99, 0}, {1, 2}, // col[0]=99 is out of range, col[1]=0 is padding (not checked) + "block_col_indices[0][0]=99 is out of valid range [0, 2)"); +} + +#if defined(USE_CUDA) +// CUDA-specific CSR validation tests. +// CUDA SparseAttention requires head_size=128, sparse_block_size=64, and MLFloat16 inputs. +// These tests verify that the device-side ValidateCSRIndicesOnDevice kernel correctly +// rejects invalid CSR indices. Error messages are less detailed than CPU (no per-element info) +// because the CUDA kernel reports via a single error code. +// Note: OpTester does not share past/present buffers (no IOBinding), but that is fine here +// because the CSR validation runs before the shared-buffer check in ComputeInternal. +// These tests expect failure from validation, not from compute. +static void RunSparseAttentionCudaCSRValidationTest( + const std::vector& block_row_indices_data, + const std::vector& block_row_indices_dims, + const std::vector& block_col_indices_data, + const std::vector& block_col_indices_dims, + const std::string& expected_error) { + OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); + test.AddAttribute("num_heads", 1); + test.AddAttribute("kv_num_heads", 1); + test.AddAttribute("sparse_block_size", 64); + test.AddAttribute("scale", 1.0f); + test.AddAttribute("do_rotary", 0); + test.AddAttribute("rotary_interleaved", 0); + + // head_size=128, num_heads=1 => hidden_size=128 + // sequence_length=1, batch_size=1 + const int64_t hidden_size = 128; + const int64_t max_cache_seq_len = 128; + test.AddInput("query", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("key", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("value", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("past_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddInput("past_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddInput("block_row_indices", block_row_indices_dims, block_row_indices_data); + test.AddInput("block_col_indices", block_col_indices_dims, block_col_indices_data); + test.AddInput("total_sequence_length", {1}, {2}); + test.AddInput("key_total_sequence_lengths", {1}, {2}); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + test.AddOutput("output", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + + // Run only on CUDA EP. CPU EP does not register MLFloat16 for SparseAttention with these params. + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, expected_error, {}, nullptr, &execution_providers); +} + +// CUDA: block_row_indices first element must be 0. +TEST(SparseAttentionTest, CudaRejectsBlockRowIndicesFirstElementNonZero) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCudaCSRValidationTest( + {1, 1, 2}, {1, 3}, + {0, 1}, {1, 2}, + "block_row_indices first element must be 0 for all layouts"); +} + +// CUDA: block_row_indices must be monotonically non-decreasing. +TEST(SparseAttentionTest, CudaRejectsBlockRowIndicesNonMonotonic) { + RunSparseAttentionCudaCSRValidationTest( + {0, 2, 1}, {1, 3}, + {0, 1}, {1, 2}, + "block_row_indices values are not monotonically non-decreasing"); +} + +// CUDA: block_col_indices values must be in range. +TEST(SparseAttentionTest, CudaRejectsBlockColIndicesOutOfRange) { + RunSparseAttentionCudaCSRValidationTest( + {0, 1, 2}, {1, 3}, + {0, 99}, {1, 2}, + "block_col_indices value is out of valid range"); +} + +// CUDA: block_col_indices with large OOB value. +TEST(SparseAttentionTest, CudaRejectsBlockColIndicesLargeValue) { + RunSparseAttentionCudaCSRValidationTest( + {0, 2, 2}, {1, 3}, + {0, 1048576}, {1, 2}, + "block_col_indices value is out of valid range"); +} + +// CUDA: key_total_sequence_lengths out of range. +TEST(SparseAttentionTest, CudaRejectsKeyLengthOutOfRange) { + OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); + test.AddAttribute("num_heads", 1); + test.AddAttribute("kv_num_heads", 1); + test.AddAttribute("sparse_block_size", 64); + test.AddAttribute("scale", 1.0f); + test.AddAttribute("do_rotary", 0); + test.AddAttribute("rotary_interleaved", 0); + + const int64_t hidden_size = 128; + const int64_t max_cache_seq_len = 128; + test.AddInput("query", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("key", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("value", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("past_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddInput("past_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + // Valid CSR: shape (1, 3) => max_blocks=2 + test.AddInput("block_row_indices", {1, 3}, {0, 1, 2}); + test.AddInput("block_col_indices", {1, 2}, {0, 1}); + test.AddInput("total_sequence_length", {1}, {4}); + // Invalid key length: -5 is out of range [1, 4] + test.AddInput("key_total_sequence_lengths", {1}, {-5}); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + test.AddOutput("output", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "key_total_sequence_lengths value is out of valid range", + {}, nullptr, &execution_providers); +} +#endif // USE_CUDA + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index fef185e20f341..4820f4e5c8898 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -7,12 +7,14 @@ #include #include #include +#include #include #include #include "gsl/gsl" #include "gtest/gtest.h" #include "core/common/logging/sinks/file_sink.h" +#include "core/common/path_string.h" #include "core/framework/config_options.h" #include "core/framework/kernel_def_builder.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 38a0f15f2301d..e15098d9c8c3c 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -6,6 +6,7 @@ #include "core/framework/endian_utils.h" #include "core/framework/prepacked_weights.h" #include "core/framework/prepacked_weights_container.h" +#include "core/framework/tensor.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "test/util/include/asserts.h" @@ -225,6 +226,65 @@ TEST(TensorProtoUtilsTest, UnpackTensor) { EXPECT_FALSE(status.IsOK()); } +// A bool initializer supplied through raw_data is copied verbatim, so its bytes are not +// restricted to {0, 1}. UnpackTensor must normalize them so downstream consumers (which assume +// canonical bool values) all observe the same result regardless of how they read the byte. +TEST(TensorProtoUtilsTest, UnpackBoolTensorWithRawDataNormalizesToZeroOne) { + std::filesystem::path model_path; + TensorProto bool_tensor_proto; + bool_tensor_proto.set_data_type(TensorProto_DataType_BOOL); + bool_tensor_proto.add_dims(4); + + // Bytes outside {0, 1}: 0x00 -> 0, 0x01 -> 1, 0x02 -> 1, 0xFF -> 1. + const unsigned char raw_bytes[] = {0x00, 0x01, 0x02, 0xFF}; + bool_tensor_proto.set_raw_data(std::string(reinterpret_cast(raw_bytes), sizeof(raw_bytes))); + + bool bool_data[4]; + auto status = UnpackTensor(bool_tensor_proto, model_path, bool_data, 4); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + const auto* bytes = reinterpret_cast(bool_data); + EXPECT_EQ(bytes[0], 0); + EXPECT_EQ(bytes[1], 1); + EXPECT_EQ(bytes[2], 1); + EXPECT_EQ(bytes[3], 1); +} + +// NormalizeBoolTensorIfNeeded normalizes a CPU bool tensor's bytes to {0, 1} in place. It backs the +// external-initializer device-copy path (session_state_utils.cc), where bool bytes that may be +// non-canonical are normalized in a writable CPU staging copy before being copied to the device. +TEST(TensorProtoUtilsTest, NormalizeBoolTensorIfNeededNormalizesToZeroOne) { + auto cpu_allocator = std::make_shared(); + + // Bool tensor: write non-canonical bytes, then normalize. + Tensor bool_tensor(DataTypeImpl::GetType(), TensorShape({4}), cpu_allocator); + unsigned char* bool_bytes = reinterpret_cast(bool_tensor.MutableDataRaw()); + bool_bytes[0] = 0x00; + bool_bytes[1] = 0x01; + bool_bytes[2] = 0x02; + bool_bytes[3] = 0xFF; + + NormalizeBoolTensorIfNeeded(bool_tensor); + + EXPECT_EQ(bool_bytes[0], 0); + EXPECT_EQ(bool_bytes[1], 1); + EXPECT_EQ(bool_bytes[2], 1); + EXPECT_EQ(bool_bytes[3], 1); + + // Non-bool tensor: bytes must be left untouched. + Tensor int32_tensor(DataTypeImpl::GetType(), TensorShape({3}), cpu_allocator); + int32_t* int32_data = int32_tensor.MutableData(); + int32_data[0] = 0; + int32_data[1] = 2; + int32_data[2] = 255; + + NormalizeBoolTensorIfNeeded(int32_tensor); + + EXPECT_EQ(int32_data[0], 0); + EXPECT_EQ(int32_data[1], 2); + EXPECT_EQ(int32_data[2], 255); +} + namespace { template std::vector CreateValues() { @@ -348,6 +408,42 @@ TEST(TensorProtoUtilsTest, UnpackTensorWithExternalData) { TestUnpackExternalTensor(TensorProto_DataType_BOOL, model_path); } +// A bool initializer supplied through external data is copied verbatim, so its bytes are not +// restricted to {0, 1}. UnpackTensor must normalize them so downstream consumers (which assume +// canonical bool values) all observe the same result regardless of how they read the byte. +TEST(TensorProtoUtilsTest, UnpackBoolTensorWithExternalDataNormalizesToZeroOne) { + std::filesystem::path model_path; + + // Bytes outside {0, 1}: 0x00 -> 0, 0x01 -> 1, 0x02 -> 1, 0xFF -> 1. + const unsigned char raw_bytes[] = {0x00, 0x01, 0x02, 0xFF}; + + std::basic_string filename(ORT_TSTR("bool_tensor_XXXXXX")); + FILE* fp; + CreateTestFile(fp, filename); + ASSERT_EQ(sizeof(raw_bytes), fwrite(raw_bytes, 1, sizeof(raw_bytes), fp)); + ASSERT_EQ(0, fclose(fp)); + std::unique_ptr file_deleter(const_cast(filename.c_str()), + DeleteFileFromDisk); + + TensorProto bool_tensor_proto; + onnx::StringStringEntryProto* location = bool_tensor_proto.mutable_external_data()->Add(); + location->set_key("location"); + location->set_value(ToUTF8String(filename)); + bool_tensor_proto.add_dims(4); + bool_tensor_proto.set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + bool_tensor_proto.set_data_type(TensorProto_DataType_BOOL); + + auto arr = std::make_unique(4); + auto status = utils::UnpackTensor(bool_tensor_proto, model_path, arr.get(), 4); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + const auto* bytes = reinterpret_cast(arr.get()); + EXPECT_EQ(bytes[0], 0); + EXPECT_EQ(bytes[1], 1); + EXPECT_EQ(bytes[2], 1); + EXPECT_EQ(bytes[3], 1); +} + template static NodeProto CreateConstantNode(const std::string& attrib_name, AttributeProto_AttributeType type, std::function add_data) { diff --git a/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc index 8aa4c88052742..47e08802c9e20 100644 --- a/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc +++ b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc @@ -4,17 +4,24 @@ // Unit tests for the DQMatMulNBitsFusion graph transformer. // Tests Pattern 1: DQ(3D,axis=2)->Reshape->Transpose([1,0])->[Cast]->MatMul/Gemm -> MatMulNBits // Tests Pattern 2: DQ(2D,axis=0)->MatMul/Gemm -> MatMulNBits +#include #include "core/common/span_utils.h" #include "core/framework/int4.h" +#include "core/framework/prepacked_weights_container.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/dq_matmulnbits_fusion.h" +#include "core/session/inference_session.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/test_environment.h" #include "test/unittest_util/framework_test_utils.h" #include "test/unittest_util/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" #include "gtest/gtest.h" @@ -354,6 +361,166 @@ TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_NoZP) { TransformerLevel::Level1, 1, pre_check, post_check)); } +// Validates the cross-session-sharing tag the fusion attaches to the generated B weight. The tag is a +// stable, content-derived enrollment identity: identical source quantization groups yield the SAME +// identity, while a semantic difference -- here, different zero points -- yields a DIFFERENT identity. +// (The tag only enrolls B into the shared container; the actual sharing is keyed by the packed-bytes +// hash, so a stable, content-distinct tag just keeps enrollment deterministic across sessions.) +TEST_F(DQMatMulNBitsFusionTest, TagsGeneratedWeightWithStableContentIdentity) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(N * num_blocks * block_size)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + std::vector scale(static_cast(N * num_blocks)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + // Non-default (non-8) zero points so the fusion keeps them (it elides uniform-8 zero points). + std::vector zp_a(static_cast(N * num_blocks), 3); + std::vector zp_b(zp_a.size(), 5); + + // Runs the fusion on a Pattern-1 model built from the given zero points and returns the sharing + // identity tagged onto the generated MatMulNBits B weight. + auto tag_for = [&](const std::vector& zp) -> std::string { + std::string captured; + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, /*with_zp*/ true, /*with_cast*/ false, + /*use_gemm*/ false, &weight, &scale, &zp); + }; + auto pre_check = [](Graph&) -> Status { return Status::OK(); }; + auto post_check = [&](Graph& graph) -> Status { + int matmulnbits = 0; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + ++matmulnbits; + const std::string& b_name = node.InputDefs()[1]->Name(); // input 1 == quantized B + const std::string* id = graph.GetSharedPrepackInitializerId(b_name); + EXPECT_NE(id, nullptr) << "generated B weight was not tagged for cross-session sharing"; + if (id != nullptr) { + captured = *id; + } + } + } + EXPECT_EQ(matmulnbits, 1); + return Status::OK(); + }; + auto transformer = std::make_unique(4); + EXPECT_TRUE(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_check, post_check) + .IsOK()); + return captured; + }; + + const std::string id_a1 = tag_for(zp_a); + const std::string id_a2 = tag_for(zp_a); + const std::string id_b = tag_for(zp_b); + + ASSERT_FALSE(id_a1.empty()); + EXPECT_EQ(id_a1, id_a2); // stable: identical source quantization group -> identical identity + EXPECT_NE(id_a1, id_b); // collision-safe: different zero points -> different identity +} + +// Builds and serializes a Pattern-1 DQ->Reshape->Transpose->MatMul model (UINT4 constant weight). When +// loaded into a session with the DQ->MatMulNBits fusion enabled, it becomes a MatMulNBits whose B is +// tagged for cross-session sharing. +static void SerializeDQMatMulModel(int64_t M, int64_t N, int64_t K, int64_t block_size, + const std::vector& weight, const std::vector& scale, + const std::vector& zp, std::string& model_bytes) { + const std::unordered_map domain_to_version{{"", 21}, {kMSDomain, 1}}; + Model model("dq_matmulnbits_share", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), DefaultLoggingManager().DefaultLogger()); + ModelTestBuilder builder(model.MainGraph()); + BuildPattern1Graph(builder, M, N, K, block_size, /*with_zp*/ true, /*with_cast*/ false, + /*use_gemm*/ false, &weight, &scale, &zp); + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + ASSERT_TRUE(model.ToProto().SerializeToString(&model_bytes)); +} + +// Loads the serialized model on the CPU EP with the DQ->MatMulNBits fusion enabled and the supplied +// shared container. Reports whether the fusion produced a MatMulNBits and how many pre-packed weights +// this session served from the container. +static void RunSharedFusionSession(const std::string& model_bytes, PrepackedWeightsContainer& container, + bool& produced_matmulnbits, size_t& used_shared_count) { + SessionOptions so; + // This test exercises prepack-weight sharing, not parallel execution. Cap the intra-op thread pool + // to a single thread so we don't spin up one worker per core: under AddressSanitizer each thread adds + // fake-stack and thread-local allocator overhead, which on a high-core CI runner multiplies across the + // sessions every test creates (the sibling SessionStatePrepackingTest caps it for the same reason). + so.intra_op_param.thread_pool_size = 1; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableDQMatMulNBitsFusion, "1")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.AddPrePackedWeightsContainer(&container)); + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + ASSERT_STATUS_OK(session.Initialize()); + + produced_matmulnbits = false; + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "MatMulNBits") { + produced_matmulnbits = true; + break; + } + } + used_shared_count = session.GetSessionState().GetUsedSharedPrePackedWeightCounter(); +} + +// End-to-end: two sessions optimizing the same DQ+MatMul model share the fused MatMulNBits B weight +// through a common container WITHOUT any session option -- the fusion tags it to enroll it, and +// SessionState keys the sharing by the packed-bytes hash. A model whose quantized weight differs packs +// to different bytes, so it gets a different key and must NOT share. (A zero-point-only difference is +// intentionally NOT used: on the CompFp32 path the zero points are not folded into the packed B, so two +// such models pack identically and would correctly share a byte-identical buffer.) +TEST_F(DQMatMulNBitsFusionTest, SharesFusedWeightAcrossSessionsViaTag) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(N * num_blocks * block_size)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + // A different quantized weight -> different packed B on every compute type (unlike a zp-only change). + std::vector weight_other(weight.size()); + for (size_t i = 0; i < weight_other.size(); ++i) { + weight_other[i] = static_cast((i + 7) % 16); + } + std::vector scale(static_cast(N * num_blocks)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp(static_cast(N * num_blocks), 3); + + std::string model_a, model_other; + SerializeDQMatMulModel(M, N, K, block_size, weight, scale, zp, model_a); + SerializeDQMatMulModel(M, N, K, block_size, weight_other, scale, zp, model_other); + + PrepackedWeightsContainer container; + bool fused1 = false, fused2 = false, fused_other = false; + size_t used1 = 0, used2 = 0, used_other = 0; + + RunSharedFusionSession(model_a, container, fused1, used1); + ASSERT_TRUE(fused1) << "DQ -> MatMulNBits fusion did not run"; + if (container.GetNumberOfElements() == 0) { + GTEST_SKIP() << "MatMulNBits B was not pre-packed on this platform"; + } + EXPECT_EQ(used1, static_cast(0)); // first session: nothing to share yet + + // Second session over the SAME model shares the tagged B from the container. + RunSharedFusionSession(model_a, container, fused2, used2); + ASSERT_TRUE(fused2); + EXPECT_GT(used2, static_cast(0)); + + // A model with a different quantized weight packs to different bytes -> different key, so it must NOT + // reuse the buffer (on any compute type). + RunSharedFusionSession(model_other, container, fused_other, used_other); + ASSERT_TRUE(fused_other); + EXPECT_EQ(used_other, static_cast(0)); +} + TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithDefaultZP8) { constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index a1c0f8adfffb7..b53577a81ff4a 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -2,10 +2,14 @@ // Licensed under the MIT License. #include +#include #include "core/common/span_utils.h" #include "core/common/float16.h" #include "core/framework/int4.h" +#include "core/framework/prepacked_weights_container.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" @@ -1462,6 +1466,237 @@ TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_Alpha) { 1e-5, 2e-5); } +// --------------------------------------------------------------------------- +// Cross-session pre-pack sharing for the DEFAULT DQ->MatMulNBits path +// --------------------------------------------------------------------------- +// DQMatMulToMatMulNBitsAction (in the QDQ selector/action transformer) runs without the +// session.enable_dq_matmulnbits_fusion flag and synthesizes the MatMulNBits B/scales/zp initializers +// with names that are NOT stable across sessions. It tags the generated B weight with a sharing +// identity that SessionState treats as the enrollment signal opting the buffer into the cross-session +// container; the actual sharing is keyed by the packed-bytes hash (only byte-identical packed buffers +// are reused, exactly like the AddInitializer path), so packings that differ by compute type/options +// are never falsely shared. + +// Packs uint4 nibble values (row-major, 2 per byte) into UInt4x2 storage. +static std::vector PackUint4Nibbles(const std::vector& values) { + const size_t num_pairs = UInt4x2::CalcNumInt4Pairs(values.size()); + std::vector packed(num_pairs); + for (size_t i = 0; i < values.size(); i += 2) { + const uint8_t lo = values[i] & 0x0F; + const uint8_t hi = (i + 1 < values.size()) ? (values[i + 1] & 0x0F) : 0; + packed[i / 2] = UInt4x2(lo, hi); + } + return packed; +} + +// Builds a default-path model: a constant UINT4 weight [K, N] block-quantized along axis 0 feeding a +// DequantizeLinear whose output is the second input to a single MatMul. The QDQ selector/action +// transformer converts this into a MatMulNBits. Explicit weight/scale/zp give a deterministic identity. +static void BuildDefaultPathDQMatMul(ModelTestBuilder& builder, int64_t M, int64_t N, int64_t K, + int64_t block_size, const std::vector& weight, + const std::vector& scale, const std::vector& zp) { + const int64_t num_blocks = (K + block_size - 1) / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer({K, N}, PackUint4Nibbles(weight)); + auto* scale_arg = builder.MakeInitializer({num_blocks, N}, scale); + auto* zp_arg = builder.MakeInitializer({num_blocks, N}, PackUint4Nibbles(zp)); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + + builder.AddNode("MatMul", {input_a, dq_output}, {output}); +} + +// Serializes a default-path DQ->MatMul model built from explicit quantization data. +static void SerializeDefaultPathModel(int64_t M, int64_t N, int64_t K, int64_t block_size, + const std::vector& weight, const std::vector& scale, + const std::vector& zp, std::string& model_bytes) { + const std::unordered_map domain_to_version{{"", 21}, {kMSDomain, 1}}; + Model model("dq_matmul_default_share", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), DefaultLoggingManager().DefaultLogger()); + ModelTestBuilder builder(model.MainGraph()); + BuildDefaultPathDQMatMul(builder, M, N, K, block_size, weight, scale, zp); + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + ASSERT_TRUE(model.ToProto().SerializeToString(&model_bytes)); +} + +// Loads the model on the CPU EP with the given shared container and DEFAULT options (no fusion flag). +// Reports whether a MatMulNBits was produced, the sharing identity tagged onto its B weight, and how +// many pre-packed weights this session served from the container. +static void RunDefaultPathSession(const std::string& model_bytes, PrepackedWeightsContainer& container, + bool& produced_matmulnbits, std::string& b_tag, size_t& used_shared_count, + int accuracy_level = -1) { + SessionOptions so; + // This test exercises prepack-weight sharing, not parallel execution. Cap the intra-op thread pool + // to a single thread so we don't spin up one worker per core: under AddressSanitizer each thread adds + // fake-stack and thread-local allocator overhead, which on a high-core CI runner multiplies across the + // sessions every test creates (the sibling SessionStatePrepackingTest caps it for the same reason). + so.intra_op_param.thread_pool_size = 1; + if (accuracy_level >= 0) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str())); + } + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.AddPrePackedWeightsContainer(&container)); + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + ASSERT_STATUS_OK(session.Initialize()); + + produced_matmulnbits = false; + b_tag.clear(); + const Graph& graph = session.GetGraph(); + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + produced_matmulnbits = true; + const std::string& b_name = node.InputDefs()[1]->Name(); // input 1 == quantized B + if (const std::string* id = graph.GetSharedPrepackInitializerId(b_name); id != nullptr) { + b_tag = *id; + } + break; + } + } + used_shared_count = session.GetSessionState().GetUsedSharedPrePackedWeightCounter(); +} + +// Verifies the default DQ->MatMulNBits path tags its generated B weight with a stable, content-derived +// enrollment identity: identical quantization data yields the SAME identity, while different zero points +// yield a DIFFERENT identity. (The tag only enrolls the buffer for sharing; the container keys by the +// packed-bytes hash. A stable, content-distinct tag keeps enrollment deterministic across sessions.) +TEST(QDQTransformerTests, DefaultPath_TagsGeneratedWeightWithStableContentIdentity) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(K * N)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + std::vector scale(static_cast(num_blocks * N)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp_a(static_cast(num_blocks * N), 3); + std::vector zp_b(zp_a.size(), 5); + + auto tag_for = [&](const std::vector& zp) -> std::string { + std::string model_bytes; + SerializeDefaultPathModel(M, N, K, block_size, weight, scale, zp, model_bytes); + PrepackedWeightsContainer container; + bool produced = false; + std::string tag; + size_t used = 0; + RunDefaultPathSession(model_bytes, container, produced, tag, used); + EXPECT_TRUE(produced) << "DQ -> MatMulNBits conversion did not run on the default path"; + return tag; + }; + + const std::string id_a1 = tag_for(zp_a); + const std::string id_a2 = tag_for(zp_a); + const std::string id_b = tag_for(zp_b); + + ASSERT_FALSE(id_a1.empty()) << "generated B weight was not tagged for cross-session sharing"; + EXPECT_EQ(id_a1, id_a2); // stable: identical quantization data -> identical identity + EXPECT_NE(id_a1, id_b); // collision-safe: different zero points -> different identity +} + +// End-to-end: two sessions converting the same model via the default path share the MatMulNBits B +// pre-packed buffer through a common container (no session option). A model whose quantized weight +// differs packs to different bytes -> different container key, so it must not reuse the buffer. (A +// zero-point-only difference is intentionally NOT used here: on the CompFp32 path the zero points are +// applied at compute time and left out of the packed B, so two such models pack identically and would +// correctly share -- packed-bytes keying only ever reuses byte-identical buffers.) +TEST(QDQTransformerTests, DefaultPath_SharesWeightAcrossSessionsViaTag) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(K * N)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + // A different quantized weight -> different packed B on every compute type (unlike a zp-only change). + std::vector weight_other(weight.size()); + for (size_t i = 0; i < weight_other.size(); ++i) { + weight_other[i] = static_cast((i + 7) % 16); + } + std::vector scale(static_cast(num_blocks * N)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp(static_cast(num_blocks * N), 3); + + std::string model_a, model_other; + SerializeDefaultPathModel(M, N, K, block_size, weight, scale, zp, model_a); + SerializeDefaultPathModel(M, N, K, block_size, weight_other, scale, zp, model_other); + + PrepackedWeightsContainer container; + bool produced1 = false, produced2 = false, produced_other = false; + std::string tag1, tag2, tag_other; + size_t used1 = 0, used2 = 0, used_other = 0; + + RunDefaultPathSession(model_a, container, produced1, tag1, used1); + ASSERT_TRUE(produced1) << "DQ -> MatMulNBits conversion did not run on the default path"; + if (container.GetNumberOfElements() == 0) { + GTEST_SKIP() << "MatMulNBits B was not pre-packed on this platform"; + } + EXPECT_EQ(used1, static_cast(0)); // first session: nothing to share yet + + // Second session over the SAME model reuses the tagged B from the container. + RunDefaultPathSession(model_a, container, produced2, tag2, used2); + ASSERT_TRUE(produced2); + EXPECT_GT(used2, static_cast(0)); + + // A model with a different quantized weight packs to different bytes -> different key, so it must NOT + // reuse the buffer (on any compute type). + RunDefaultPathSession(model_other, container, produced_other, tag_other, used_other); + ASSERT_TRUE(produced_other); + EXPECT_EQ(used_other, static_cast(0)); +} + +// accuracy_level participates in the enrollment identity, so the same weights requested at different +// accuracy levels get distinct identities. Whether the two sessions then share the packed buffer is +// platform-dependent (level 4 may pack as CompInt8 -- different bytes, no share -- or fall back to the +// same CompFp32 packing as level 0 and benignly reuse the byte-identical buffer); packed-bytes keying +// makes either outcome safe, so this asserts the identity is distinct, not a fixed sharing count. +TEST(QDQTransformerTests, DefaultPath_DifferentAccuracyLevelGetsDistinctIdentity) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(K * N)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + std::vector scale(static_cast(num_blocks * N)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp(static_cast(num_blocks * N), 3); + + std::string model_bytes; + SerializeDefaultPathModel(M, N, K, block_size, weight, scale, zp, model_bytes); + + PrepackedWeightsContainer container; + bool produced0 = false, produced4 = false; + std::string tag0, tag4; + size_t used0 = 0, used4 = 0; + + RunDefaultPathSession(model_bytes, container, produced0, tag0, used0, /*accuracy_level*/ 0); + ASSERT_TRUE(produced0) << "DQ -> MatMulNBits conversion did not run on the default path"; + + // Same model/weights, different accuracy level, sharing the same container. + RunDefaultPathSession(model_bytes, container, produced4, tag4, used4, /*accuracy_level*/ 4); + ASSERT_TRUE(produced4); + + ASSERT_FALSE(tag0.empty()); + ASSERT_FALSE(tag4.empty()); + EXPECT_NE(tag0, tag4); // accuracy_level participates in the enrollment identity +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index cf76ca0fa00f8..03a47664c632a 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -2332,6 +2332,45 @@ TEST(AttentionTest, Attention_Causal_NonPadKVSeqLen_Decode_BottomRight) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// q_len=1 with nonpad_kv_seqlen keeps bottom-right decode behavior: the single +// query attends every valid key. If the no-nonpad upper-left overlay is applied +// here, this returns 1.0 instead of the expected 1/6. +TEST(AttentionTest, Attention_Causal_NonPadKVSeqLen_SingleQueryKeepsBottomRight_CPU) { + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + test.AddAttribute("is_causal", static_cast(1)); + + constexpr int batch_size = 1; + constexpr int q_num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int q_sequence_length = 1; + constexpr int kv_sequence_length = 6; + constexpr int head_size = 8; + + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::vector v(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::fill_n(v.begin(), head_size, 1.0f); + + test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, q); + test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, k); + test.AddInput("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, v); + test.AddOptionalInputEdge(); // attn_mask + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + test.AddInput("nonpad_kv_seqlen", {batch_size}, {kv_sequence_length}); + + std::vector expected_y(batch_size * q_num_heads * q_sequence_length * head_size, + 1.0f / static_cast(kv_sequence_length)); + test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, head_size}, expected_y, false, 0, + 1e-4f); + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // Continued / chunked prefill (S_q=2) into a partially-filled static cache. // nonpad=[4], S_q=2 -> offset = 4 - 2 = 2: query 0 attends keys {0,1,2}, query 1 // attends {0,1,2,3}. The old top-left alignment would mask everything past the @@ -3087,9 +3126,41 @@ TEST(AttentionTest, Attention4DSoftCapOutputQkRawLogits) { // ============================================================================ // Causal alignment tests: verify upper-left (no past) vs lower-right (with past) -// These are CUDA-only tests that validate the causal masking fix. +// These tests validate causal mask alignment across CPU and CUDA. // ============================================================================ +// Test: Causal + cross-attention (S_q=1, S_kv=6, no past) +// ONNX spec mandates upper-left alignment: q0 attends only to kv[0]. +// This covers GitHub issue #29020, where CPU skipped causal masking for S_q=1. +TEST(AttentionTest, Attention4DCausalSingleQueryCrossAttentionUpperLeft) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 1; + int head_size = 8; + int kv_sequence_length = 6; + int kv_num_heads = 1; + int v_head_size = 8; + int past_sequence_length = 0; + + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::vector v(batch_size * kv_num_heads * kv_sequence_length * v_head_size, 0.0f); + for (int i = 0; i < v_head_size; ++i) { + v[i] = 1.0f; + } + std::vector y(batch_size * q_num_heads * q_sequence_length * v_head_size, 1.0f); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, + v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), + std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, + TensorType::kFloat, + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + // Test: Causal + cross-attention (S_q=3, S_kv=5, no past) // ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. // V is identity-like so output directly reveals which KV positions were attended. diff --git a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc index c3d91100605e9..495fea5735b32 100644 --- a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc @@ -4,6 +4,14 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#ifdef USE_CUDA +#include "core/graph/model.h" +#include "core/session/inference_session.h" +#include "test/test_environment.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/util/include/default_providers.h" +#endif + namespace onnxruntime { namespace test { @@ -148,5 +156,81 @@ TEST(CompressTest, Compress_3dims_neg_axis) { test.Run(); } +#ifdef USE_CUDA +// Regression test for the CUDA Compress prefix-sum sizing path. A bool condition byte may hold a +// non-canonical value (e.g. 0xFF) at runtime — initializers are normalized to {0, 1} on unpack, +// but runtime-produced bool tensors are not. Without normalizing the byte before the prefix sum, +// 0xFF would be summed as 255 (sizing the output for 255 selected elements) while _CompressKernel +// selects it as a single element via truthiness, so sizing and selection would disagree. This +// test feeds a raw 0xFF condition byte (which OpTester cannot produce, since it normalizes bool +// inputs to {0, 1}) and asserts the output is sized by truthiness. +TEST(CompressTest, Compress_cuda_non_canonical_bool_condition) { + // Build: output = Compress(input, condition, axis=0) + auto model = std::make_unique("compress_non_canonical_bool", false, + DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + + ONNX_NAMESPACE::TypeProto tensor_float; + tensor_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ONNX_NAMESPACE::TypeProto tensor_bool; + tensor_bool.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + + auto& input_arg = graph.GetOrCreateNodeArg("input", &tensor_float); + auto& condition_arg = graph.GetOrCreateNodeArg("condition", &tensor_bool); + auto& output_arg = graph.GetOrCreateNodeArg("output", &tensor_float); + + std::vector input_defs{&input_arg, &condition_arg}; + std::vector output_defs{&output_arg}; + auto& node = graph.AddNode("compress", "Compress", "Compress", input_defs, output_defs, nullptr, + onnxruntime::kOnnxDomain); + node.AddAttribute("axis", static_cast(0)); + ASSERT_STATUS_OK(graph.Resolve()); + + SessionOptions so; + so.session_logid = "CompressTest.Compress_cuda_non_canonical_bool_condition"; + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); + + std::string serialized_model; + ASSERT_TRUE(model->ToProto().SerializeToString(&serialized_model)); + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + AllocatorPtr cpu_allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + + OrtValue input_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({3, 2}), cpu_allocator, input_value); + const float input_data[6] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + memcpy(input_value.GetMutable()->MutableData(), input_data, sizeof(input_data)); + + // Condition {false, non-canonical-true, true}: write a raw 0xFF byte for the middle element to + // emulate a runtime-produced bool tensor outside the canonical {0, 1} set. + OrtValue condition_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({3}), cpu_allocator, condition_value); + auto* condition_bytes = + reinterpret_cast(condition_value.GetMutable()->MutableDataRaw()); + condition_bytes[0] = 0x00; + condition_bytes[1] = 0xFF; + condition_bytes[2] = 0x01; + + std::vector fetches; + ASSERT_STATUS_OK(session_object.Run( + std::unordered_map{{"input", input_value}, {"condition", condition_value}}, + std::vector{"output"}, &fetches)); + + ASSERT_EQ(fetches.size(), 1u); + const Tensor& output = fetches[0].Get(); + // Two non-zero condition bytes select two rows along axis 0 (not 256). + EXPECT_EQ(output.Shape(), TensorShape({2, 2})); + const auto output_span = output.DataAsSpan(); + const std::vector expected{3.0f, 4.0f, 5.0f, 6.0f}; + ASSERT_EQ(output_span.size(), expected.size()); + for (size_t i = 0; i < expected.size(); ++i) { + EXPECT_EQ(output_span[i], expected[i]); + } +} +#endif // USE_CUDA + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc new file mode 100644 index 0000000000000..b75de767bb7f6 --- /dev/null +++ b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc @@ -0,0 +1,373 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Tests that the CUDA plugin EP supports combining a user-provided compute stream +// (user_compute_stream) with CUDA graph capture/replay (enable_cuda_graph). +// +// Historically the plugin EP rejected this combination with ORT_INVALID_ARGUMENT. +// It now captures and replays the CUDA graph on the user-provided stream (the same +// stream the kernels are issued to), matching the bundled CUDA EP behavior. These +// tests verify: +// 1. Session creation succeeds with both options set (regression for the removed +// validation). +// 2. Capture + replay on the user stream produce correct results. +// 3. Replay after an in-place input update (on the user stream) is correct. + +#if defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/file_util.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { +namespace { + +constexpr const char* kCudaPluginEpRegistrationName = "CudaPluginUserStreamGraphTest"; +constexpr const char* kCudaPluginEpName = "CudaPluginExecutionProvider"; + +// Resolve the CUDA plugin EP shared library path. +std::filesystem::path GetCudaPluginLibraryPath() { + return GetSharedLibraryFileName(ORT_TSTR("onnxruntime_providers_cuda_plugin")); +} + +// RAII handle that registers/unregisters the CUDA plugin EP library. +class ScopedCudaPluginRegistration { + public: + ScopedCudaPluginRegistration(Ort::Env& env, const char* registration_name) + : env_(env), name_(registration_name) { + auto lib_path = GetCudaPluginLibraryPath(); + if (!std::filesystem::exists(lib_path)) { + available_ = false; + return; + } + env_.RegisterExecutionProviderLibrary(name_.c_str(), lib_path.c_str()); + available_ = true; + } + + ~ScopedCudaPluginRegistration() { + if (available_) { + try { + env_.UnregisterExecutionProviderLibrary(name_.c_str()); + } catch (...) { + } + } + } + + bool IsAvailable() const { return available_; } + + ScopedCudaPluginRegistration(const ScopedCudaPluginRegistration&) = delete; + ScopedCudaPluginRegistration& operator=(const ScopedCudaPluginRegistration&) = delete; + + private: + Ort::Env& env_; + std::string name_; + bool available_ = false; +}; + +// Find the CUDA plugin EP device after registration. +Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { + auto ep_devices = env.GetEpDevices(); + for (const auto& device : ep_devices) { + if (strcmp(device.EpName(), kCudaPluginEpName) == 0) { + return device; + } + } + return Ort::ConstEpDevice{nullptr}; +} + +// Dummy external allocator callbacks. They are only used to make the external-allocator +// configuration non-null; the plugin EP rejects the combination with user_compute_stream +// before either is ever invoked. +void* DummyExternalAlloc(size_t /*size*/) { return nullptr; } +void DummyExternalFree(void* /*ptr*/) {} + +} // namespace + +class CudaPluginUserStreamGraphTest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "No CUDA device available."; + } + + registration_ = std::make_unique( + *ort_env, kCudaPluginEpRegistrationName); + if (!registration_->IsAvailable()) { + GTEST_SKIP() << "CUDA plugin EP library not found."; + } + + cuda_device_ = FindCudaPluginDevice(*ort_env); + if (!cuda_device_) { + GTEST_SKIP() << "No CUDA plugin EP device found after registration."; + } + } + + void TearDown() override { + registration_.reset(); + cudaDeviceSynchronize(); + } + + // Build session options that select the plugin EP with CUDA graph capture enabled + // and the user-provided stream supplied as a pointer-sized address string. + Ort::SessionOptions CreateUserStreamGraphSessionOptions(cudaStream_t user_stream) { + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"enable_cuda_graph", "1"}, + {"user_compute_stream", + std::to_string(reinterpret_cast(user_stream))}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + return so; + } + + // Allocate device input/output, bind them, and run `iterations` times on `stream`, verifying + // Y = X * W each run. The input is uploaded once up front and then left constant: when CUDA graph + // capture is enabled, issuing host->device work on the stream immediately before the capture run + // would interfere with cudaStreamBeginCapture, so the buffers are populated and synchronized + // before any capture happens. When `graph_ids` is non-empty, run i sets gpu_graph_id to + // graph_ids[i % size] to exercise CUDA graph annotation-id switching. mul_1.onnx computes + // Y = X * W with W = [1..6] (shape 3x2). + void RunAndVerifyOnStream(Ort::Session& session, cudaStream_t stream, int iterations, + const std::vector& graph_ids = {}) { + auto device_memory_info = cuda_device_.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); + ASSERT_NE(allocator, nullptr); + + constexpr size_t kNumElements = 6; + constexpr size_t kBytes = kNumElements * sizeof(float); + const std::array shape = {3, 2}; + const std::array w_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + const std::array x_values = {2.0f, 3.0f, 5.0f, 7.0f, 11.0f, 13.0f}; + + // Fixed device buffers so captured CUDA graphs keep valid IO addresses across replays. + void* input_gpu = allocator.Alloc(kBytes); + void* output_gpu = allocator.Alloc(kBytes); + ASSERT_NE(input_gpu, nullptr); + ASSERT_NE(output_gpu, nullptr); + + // Populate the input once and synchronize, so no host-issued work is pending on `stream` + // when graph capture begins on a later run. + ASSERT_EQ(cudaSuccess, + cudaMemcpyAsync(input_gpu, x_values.data(), kBytes, cudaMemcpyHostToDevice, stream)); + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(stream)); + + Ort::Value input_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(input_gpu), kNumElements, + shape.data(), shape.size()); + Ort::Value output_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(output_gpu), kNumElements, + shape.data(), shape.size()); + + Ort::IoBinding binding(session); + binding.BindInput("X", input_tensor); + binding.BindOutput("Y", output_tensor); + + for (int i = 0; i < iterations; ++i) { + Ort::RunOptions run_options; + if (!graph_ids.empty()) { + run_options.AddConfigEntry("gpu_graph_id", graph_ids[i % graph_ids.size()].c_str()); + } + session.Run(run_options, binding); + + // Kernels run on `stream`; wait for them before copying the result back. + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(stream)); + std::array y{}; + ASSERT_EQ(cudaSuccess, cudaMemcpy(y.data(), output_gpu, kBytes, cudaMemcpyDeviceToHost)); + for (size_t j = 0; j < kNumElements; ++j) { + EXPECT_FLOAT_EQ(y[j], x_values[j] * w_values[j]) << "mismatch at iteration " << i << " index " << j; + } + } + + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); + allocator.Free(input_gpu); + allocator.Free(output_gpu); + } + + std::unique_ptr registration_; + Ort::ConstEpDevice cuda_device_{nullptr}; +}; + +// Regression: creating a session with both user_compute_stream and enable_cuda_graph +// used to fail with ORT_INVALID_ARGUMENT. It must now succeed. +TEST_F(CudaPluginUserStreamGraphTest, SessionCreatesWithUserStreamAndCudaGraph) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + { + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + ASSERT_NO_THROW({ + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + (void)session; + }); + } + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +// Full capture + replay on the user stream, including replay after an in-place input +// update. mul_1.onnx computes Y = X * W with W = [1, 2, 3, 4, 5, 6] (shape 3x2). +TEST_F(CudaPluginUserStreamGraphTest, CaptureAndReplayOnUserStream) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Device allocator backing the plugin EP's default memory. + auto device_memory_info = cuda_device_.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); + ASSERT_NE(allocator, nullptr); + + constexpr size_t kNumElements = 6; + constexpr size_t kBytes = kNumElements * sizeof(float); + const std::array shape = {3, 2}; + const std::array w_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Pre-allocate device input/output buffers (required for CUDA graph IO binding). + void* input_gpu = allocator.Alloc(kBytes); + void* output_gpu = allocator.Alloc(kBytes); + ASSERT_NE(input_gpu, nullptr); + ASSERT_NE(output_gpu, nullptr); + + auto upload_input = [&](const std::array& host_values) { + ASSERT_EQ(cudaSuccess, + cudaMemcpyAsync(input_gpu, host_values.data(), kBytes, + cudaMemcpyHostToDevice, user_stream)); + }; + + auto read_output = [&](std::array& host_values) { + // Kernels run on the user stream; wait for them before copying the result back. + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(user_stream)); + ASSERT_EQ(cudaSuccess, + cudaMemcpy(host_values.data(), output_gpu, kBytes, cudaMemcpyDeviceToHost)); + }; + + Ort::Value input_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(input_gpu), kNumElements, + shape.data(), shape.size()); + Ort::Value output_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(output_gpu), kNumElements, + shape.data(), shape.size()); + + Ort::IoBinding binding(session); + binding.BindInput("X", input_tensor); + binding.BindOutput("Y", output_tensor); + + // First run: warmup + capture + first replay on the user stream. + const std::array x0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + upload_input(x0); + session.Run(Ort::RunOptions{}, binding); + + std::array y{}; + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x0[i] * w_values[i]) << "capture mismatch at " << i; + } + + // Second run: pure graph replay (same inputs) on the user stream. + session.Run(Ort::RunOptions{}, binding); + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x0[i] * w_values[i]) << "replay mismatch at " << i; + } + + // Update the input in place on the user stream and replay again. + const std::array x1 = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; + upload_input(x1); + session.Run(Ort::RunOptions{}, binding); + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x1[i] * w_values[i]) << "updated-input replay mismatch at " << i; + } + + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); + allocator.Free(input_gpu); + allocator.Free(output_gpu); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +// Negative: a user_compute_stream combined with an external GPU allocator +// (gpu_external_alloc/gpu_external_free) is not supported and must be rejected at session +// creation with an error rather than silently ignored. +TEST_F(CudaPluginUserStreamGraphTest, RejectsUserStreamWithExternalAllocator) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"user_compute_stream", std::to_string(reinterpret_cast(user_stream))}, + {"gpu_external_alloc", std::to_string(reinterpret_cast(&DummyExternalAlloc))}, + {"gpu_external_free", std::to_string(reinterpret_cast(&DummyExternalFree))}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + + EXPECT_THROW( + { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + (void)session; + }, + Ort::Exception); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +// Edge case: cudaStream_t(0) (the CUDA default stream) is a valid user-provided stream. Because +// user_compute_stream parses to nullptr, the caller must set has_user_compute_stream explicitly, +// otherwise the stream would be treated as "not provided". Session creation must succeed and +// inference must run correctly on the default stream. +// +// Note: CUDA graph capture is intentionally NOT enabled here. The legacy default stream (stream 0) +// cannot be captured (cudaStreamBeginCapture returns cudaErrorStreamCaptureUnsupported), so this +// test exercises only that stream 0 is honored as the compute stream for non-graph execution. +TEST_F(CudaPluginUserStreamGraphTest, DefaultStreamAsUserStream) { + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"has_user_compute_stream", "1"}, + {"user_compute_stream", "0"}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Run several iterations on the default stream (stream 0) and verify correctness. + RunAndVerifyOnStream(session, /*stream=*/nullptr, /*iterations=*/4); +} + +// Switching the CUDA graph annotation id (gpu_graph_id) between runs while using a user stream +// must capture/replay a distinct graph per id without crashing and keep producing correct results. +TEST_F(CudaPluginUserStreamGraphTest, GraphAnnotationIdSwitchingWithUserStream) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Alternate between annotation ids "1" and "2". With min_num_runs_before_cuda_graph_capture == 2, + // 8 iterations let each id accumulate warmup runs, capture, and then replay on the user stream. + RunAndVerifyOnStream(session, user_stream, /*iterations=*/8, /*graph_ids=*/{"1", "2"}); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +} // namespace test +} // namespace onnxruntime + +#endif // ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP diff --git a/onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py b/onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py new file mode 100644 index 0000000000000..a8ce794f5fdd3 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# pylint: disable=C0114,C0115,C0116,W0212 +import unittest + +import onnxruntime + + +class TestGetNvidiaDllPaths(unittest.TestCase): + """Unit tests for the private _get_nvidia_dll_paths helper that locates CUDA/cuDNN + libraries inside the NVIDIA site-packages folders. + + NVIDIA restructured the CUDA Python wheels starting with CUDA 13: the per-component + packages (cublas, cufft, cuda_runtime, ...) were consolidated into a single + "nvidia/cu{major}" tree. These tests pin down the expected relative paths for the + old (CUDA 12) and new (CUDA 13) layouts on both Windows and Linux. + """ + + def _paths(self, **kwargs): + return onnxruntime._get_nvidia_dll_paths(**kwargs) + + # ---- CUDA 12 (legacy per-component layout) -------------------------------------- + def test_cuda12_windows(self): + paths = self._paths(is_windows=True, build_cuda_version="12.4", cudnn=False) + self.assertIn(("nvidia", "cublas", "bin", "cublasLt64_12.dll"), paths) + self.assertIn(("nvidia", "cublas", "bin", "cublas64_12.dll"), paths) + self.assertIn(("nvidia", "cufft", "bin", "cufft64_11.dll"), paths) + self.assertIn(("nvidia", "cuda_runtime", "bin", "cudart64_12.dll"), paths) + + def test_cuda12_linux(self): + paths = self._paths(is_windows=False, build_cuda_version="12.4", cudnn=False) + self.assertIn(("nvidia", "cublas", "lib", "libcublasLt.so.12"), paths) + self.assertIn(("nvidia", "cublas", "lib", "libcublas.so.12"), paths) + self.assertIn(("nvidia", "cuda_nvrtc", "lib", "libnvrtc.so.12"), paths) + self.assertIn(("nvidia", "curand", "lib", "libcurand.so.10"), paths) + self.assertIn(("nvidia", "cufft", "lib", "libcufft.so.11"), paths) + self.assertIn(("nvidia", "cuda_runtime", "lib", "libcudart.so.12"), paths) + + # ---- CUDA 13 (consolidated "cu13" layout) --------------------------------------- + def test_cuda13_windows_x86_64(self): + paths = self._paths(is_windows=True, build_cuda_version="13.2", cudnn=False, arch="x86_64") + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cublasLt64_13.dll"), paths) + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cublas64_13.dll"), paths) + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cufft64_12.dll"), paths) + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cudart64_13.dll"), paths) + + def test_cuda13_windows_arch_override(self): + paths = self._paths(is_windows=True, build_cuda_version="13.2", cudnn=False, arch="arm64") + self.assertIn(("nvidia", "cu13", "bin", "arm64", "cudart64_13.dll"), paths) + + def test_cuda13_linux_is_flat(self): + paths = self._paths(is_windows=False, build_cuda_version="13.2", cudnn=False) + # Linux consolidated layout has no architecture sub-folder (flat "lib"). + self.assertIn(("nvidia", "cu13", "lib", "libcublasLt.so.13"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcublas.so.13"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libnvrtc.so.13"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcurand.so.10"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcufft.so.12"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcudart.so.13"), paths) + + # ---- cuDNN keeps its own package/layout in both schemes ------------------------- + def test_cudnn_layout_unchanged(self): + for build_cuda_version in ("12.4", "13.2"): + win = self._paths(is_windows=True, build_cuda_version=build_cuda_version, cuda=False) + self.assertIn(("nvidia", "cudnn", "bin", "cudnn64_9.dll"), win) + + linux = self._paths(is_windows=False, build_cuda_version=build_cuda_version, cuda=False) + self.assertEqual(linux, [("nvidia", "cudnn", "lib", "libcudnn.so.9")]) + + # ---- toggles -------------------------------------------------------------------- + def test_cuda_and_cudnn_toggles(self): + self.assertEqual(self._paths(is_windows=False, build_cuda_version="13.2", cuda=False, cudnn=False), []) + + cuda_only = self._paths(is_windows=False, build_cuda_version="13.2", cuda=True, cudnn=False) + self.assertTrue(all(p[1] == "cu13" for p in cuda_only)) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py index 77ac08cf50d6c..7dbcb16a75973 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py @@ -106,6 +106,70 @@ def create_quantized_gqa_graph( return model.SerializeToString() +def create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with a non-quantized FP32 KV cache.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + ] + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + domain="com.microsoft", + ) + + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "past_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "present_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + ] + + graph = helper.make_graph([node], "BenchGQA", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + def benchmark_gqa( batch_size, seq_len, @@ -117,6 +181,7 @@ def benchmark_gqa( past_seq_len=0, warmup=5, repeats=20, + non_quantized=False, ): """Benchmark a single GQA configuration. Returns elapsed time in ms.""" hidden_size = num_heads * head_size @@ -126,54 +191,76 @@ def benchmark_gqa( total_seqlen = past_seq_len + seq_len buffer_seq_len = total_seqlen - onnx_model_str = create_quantized_gqa_graph( - batch_size, - seq_len, - num_heads, - kv_num_heads, - head_size, - quant_type, - bit_width, - buffer_seq_len=buffer_seq_len, - ) - sess_options = SessionOptions() sess_options.intra_op_num_threads = 8 - sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - # Generate inputs np.random.seed(42) query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) key = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) value = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) - - cache_dtype = np.uint8 if bit_width == 4 else np.int8 - past_k = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - past_v = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - seqlens_k = np.array([total_seqlen - 1] * batch_size, dtype=np.int32) total_seq = np.array([total_seqlen], dtype=np.int32) - per_channel = quant_type == "PER_CHANNEL" - scale_size = kv_num_heads * head_size if per_channel else 1 - k_scale = np.full(scale_size, 0.01, dtype=np.float32) - v_scale = np.full(scale_size, 0.01, dtype=np.float32) - - feeds = { - "query": query, - "key": key, - "value": value, - "past_key": past_k, - "past_value": past_v, - "seqlens_k": seqlens_k, - "total_sequence_length": total_seq, - "k_scale": k_scale, - "v_scale": v_scale, - } + if non_quantized: + onnx_model_str = create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + past_k = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + past_v = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + } + else: + onnx_model_str = create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + cache_dtype = np.uint8 if bit_width == 4 else np.int8 + past_k = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + past_v = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + + per_channel = quant_type == "PER_CHANNEL" + scale_size = kv_num_heads * head_size if per_channel else 1 + k_scale = np.full(scale_size, 0.01, dtype=np.float32) + v_scale = np.full(scale_size, 0.01, dtype=np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "k_scale": k_scale, + "v_scale": v_scale, + } # Warmup for _ in range(warmup): @@ -242,20 +329,21 @@ def run_benchmarks(args): "past_seq_len": 2048, } ) - # INT4 prefill - configs.append( - { - "label": "Prefill S=2048 INT4", - "batch_size": 1, - "seq_len": 2048, - "num_heads": 16, - "kv_num_heads": 8, - "head_size": 128, - "quant_type": "PER_TENSOR", - "bit_width": 4, - "past_seq_len": 0, - } - ) + # INT4 prefill (quantized mode only) + if not args.fp32: + configs.append( + { + "label": "Prefill S=2048 INT4", + "batch_size": 1, + "seq_len": 2048, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 4, + "past_seq_len": 0, + } + ) warmup = args.warmup repeats = args.repeats @@ -263,13 +351,15 @@ def run_benchmarks(args): # Save and restore env var to avoid side effects on callers saved_env = os.environ.get("ORT_GQA_DISABLE_FLASH_ATTENTION") + kv_mode = "FP32 (non-quantized)" if args.fp32 else "INT8/INT4 quantized" print("\nBenchmark: CPU GroupQueryAttention — Flash vs Naive") - print(f"Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") + print(f"KV cache: {kv_mode}, Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") print(f"{'Config':<25} {'Naive (ms)':>12} {'Flash (ms)':>12} {'Speedup':>10}") print("-" * 62) for cfg in configs: label = cfg.pop("label") + cfg["non_quantized"] = args.fp32 # Flash path (default) os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) @@ -296,5 +386,6 @@ def run_benchmarks(args): parser.add_argument("--repeats", type=int, default=20, help="Measurement iterations") parser.add_argument("--decode_only", action="store_true", help="Only run decode benchmarks") parser.add_argument("--prompt_only", action="store_true", help="Only run prompt benchmarks") + parser.add_argument("--fp32", action="store_true", help="Use non-quantized FP32 KV cache instead of quantized") args = parser.parse_args() run_benchmarks(args) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 19968db98edd7..c438790ab5950 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -19,6 +19,7 @@ import torch from bert_padding import pad_input, unpad_input from einops import rearrange, repeat +from env_var_helper import scoped_env_var from onnx import TensorProto, helper from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -2645,6 +2646,69 @@ def test_gqa_past(self): qk_output, ) + def test_gqa_decode_flash_vs_naive_parity(self): + # The FP32 flash gate enables the dedicated GEMV decode kernel (and the + # flash-decoding KV-split reduction) for sequence_length == 1 with + # total_sequence_length > 1. Run the same decode configs against the + # reference twice: once with the flash path enabled (default) and once + # with it disabled via ORT_GQA_DISABLE_FLASH_ATTENTION=1 (naive path). + # If both paths match the reference, the decode kernel and KV-split + # reduction are correct -- including the bias and local-window cases. + print("-------- TEST GQA DECODE FLASH VS NAIVE PARITY ---------") + + # FP32 only: the GEMV decode kernel and flash gate are float-only. + torch_type = torch.float32 + numpy_type = numpy.float32 + ort_type = TensorProto.FLOAT + rtol = 1e-3 + atol = 1e-3 + + batches = [1, 3] + # (sequence_length == 1) decode. Include a long KV length so that the + # flash-decoding KV-split path (kv_chunk_count > 1) is exercised. + seqs = [(1, 128), (1, 2048)] + num_h = [(9, 3)] + h_sizes = [64, 128] + + # "0" keeps the flash path enabled; "1" forces the naive path. Reseed per + # phase so both paths are validated against the reference on identical + # inputs, independent of test execution order. + for env_value in ["0", "1"]: + with scoped_env_var("ORT_GQA_DISABLE_FLASH_ATTENTION", env_value): + print(f" flash {'disabled (naive path)' if env_value == '1' else 'enabled'}") + random.seed(69) + torch.manual_seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for has_attn in [False, True]: + config = Config( + b, + s, + s2, + 0, + n, + n2, + h, + False, + has_attn, + False, + QKOutputType.NO_OUTPUT, + ) + all_close = parity_check_gqa_past( + config, + torch_type=torch_type, + numpy_type=numpy_type, + ort_type=ort_type, + local=local, + past_format=Formats.BNSH, + rtol=rtol, + atol=atol, + ) + self.assertTrue(all_close) + def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") batches = [1] diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py index 349f68f0ac5ce..d1a5d833d12a4 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py @@ -666,12 +666,11 @@ def attention_past_func( # ################################################################################################# -def construct_causal_mask(seqlen_q, seqlen_k, device): - """Construct a causal mask for attention.""" +def construct_causal_mask(seqlen_q, seqlen_k, device, past_seqlen=0): + """Construct a causal mask for ONNX Attention upper-left alignment.""" row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - # Causal: positions can only attend to earlier positions - return col_idx > row_idx + seqlen_k - seqlen_q + return col_idx > row_idx + past_seqlen def attention_ref( @@ -682,6 +681,7 @@ def attention_ref( attn_bias=None, causal=False, softcap=0.0, + past_seqlen=0, ): """ Reference implementation of scaled dot-product attention with GQA support. @@ -694,6 +694,7 @@ def attention_ref( attn_bias: Additive attention bias [broadcastable to batch, num_heads, seq_q, seq_k] causal: Whether to apply causal masking softcap: Softcap value for attention scores (0.0 = disabled) + past_seqlen: Number of past K/V tokens before q[0] for causal masking. Returns: output: Attention output [batch, seq_q, num_heads, head_size] @@ -724,7 +725,7 @@ def attention_ref( scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - causal_mask = construct_causal_mask(seqlen_q, seqlen_k, q.device) + causal_mask = construct_causal_mask(seqlen_q, seqlen_k, q.device, past_seqlen=past_seqlen) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index 542094a9ac4ee..eca86e429b597 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -289,6 +289,7 @@ def parity_check_gqa_past( key_padding_mask=key_padding_mask, causal=causal, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -548,6 +549,7 @@ def parity_check_gqa_past_with_padding( key_padding_mask=key_padding_mask, causal=config.is_causal == 1, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) # --- ONNX Runtime Path --- diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index e2d9acbd0c500..16a00085f224f 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -298,6 +298,7 @@ def parity_check_mha_past( attn_bias=attn_bias_ref, causal=causal, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -2393,7 +2394,13 @@ def test_mha_past_asymmetric_v_head_size(self): full_k_bsnh = full_k_bnsh.transpose(1, 2) full_v_bsnh = full_v_bnsh.transpose(1, 2) - out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, causal=True) + out_ref, _ = attention_ref( + q=q, + k=full_k_bsnh, + v=full_v_bsnh, + causal=True, + past_seqlen=config.past_kv_sequence_length, + ) # ORT path — should fall back to unfused (not crash in MEA) out_ort, present_k, present_v = attention_past_func( diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index 53745cae9d803..a58f82deab4ff 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -54,8 +54,11 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { EXPECT_NE(allocated, nullptr) << "KernelContext_GetAllocator() can successfully allocate some memory"; allocator.Free(allocated); + OrtSyncStream* sync_stream = ctx.GetSyncStream(); + // Do computation #ifdef USE_CUDA + EXPECT_NE(sync_stream, nullptr) << "KernelContext_GetSyncStream() returns the kernel compute stream"; // Launch on stream 0 or user provided stream void* stream; Ort::ThrowOnError(ort_.KernelContext_GetGPUComputeStream(context, &stream)); @@ -70,6 +73,7 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { // and use the same compute stream to launch the custom op. // Here, an example for (1) is shown (See test_inference.cc to see how this custom op is used.) #else + EXPECT_EQ(sync_stream, nullptr) << "CPU custom ops do not have a compute stream"; ORT_UNUSED_PARAMETER(ort_); for (int64_t i = 0; i < size; i++) { out[i] = X[i] + Y[i]; diff --git a/onnxruntime/test/shared_lib/test_ep_context_data_api.cc b/onnxruntime/test/shared_lib/test_ep_context_data_api.cc new file mode 100644 index 0000000000000..ec8107f92aa7a --- /dev/null +++ b/onnxruntime/test/shared_lib/test_ep_context_data_api.cc @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_experimental_cxx_api.h" + +#include "gmock/gmock.h" +#include "gsl/gsl" +#include "gtest/gtest.h" +#include "test/util/include/api_asserts.h" + +namespace { + +void ExpectFailureOrtStatus(OrtStatus* status_ptr, OrtErrorCode expected_code, const char* expected_message) { + Ort::Status status{status_ptr}; + ASSERT_NE(status_ptr, nullptr) << "Expected a failure status, but the API returned nullptr (OK)."; + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), expected_code); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr(expected_message)); +} + +struct EpContextReadCallbackState { + bool called = false; + std::string file_name; + std::vector payload; +}; + +OrtStatus* ORT_API_CALL EpContextReadCallback(void* state, const char* file_name, OrtAllocator* allocator, + void** buffer, size_t* data_size) { + auto* read_state = static_cast(state); + read_state->called = true; + read_state->file_name = file_name; + + *buffer = nullptr; + *data_size = read_state->payload.size(); + + if (read_state->payload.empty()) { + return nullptr; + } + + OrtStatus* status = Ort::GetApi().AllocatorAlloc(allocator, read_state->payload.size(), buffer); + if (status != nullptr) { + return status; + } + + std::memcpy(*buffer, read_state->payload.data(), read_state->payload.size()); + return nullptr; +} + +struct EpContextWriteCallbackState { + bool called = false; + std::string file_name; + std::vector payload; +}; + +OrtStatus* ORT_API_CALL EpContextWriteCallback(void* state, const char* file_name, const void* buffer, + size_t buffer_size) { + auto* write_state = static_cast(state); + write_state->called = true; + write_state->file_name = file_name; + write_state->payload.clear(); + if (buffer_size != 0) { + if (buffer == nullptr) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, + "EpContextWriteCallback received a null buffer for non-empty data"); + } + + const char* buffer_bytes = static_cast(buffer); + write_state->payload.assign(buffer_bytes, buffer_bytes + buffer_size); + } + + return nullptr; +} + +} // namespace + +TEST(EpContextDataApiTest, ReadFuncIsReturnedByEpApi) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{ + false, + {}, + {'e', 'p', 'c', 't', 'x'}, + }; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + OrtReadNamedBufferFunc read_func = nullptr; + void* callback_state_out = nullptr; + ep_context_config.GetReadFunc(read_func, callback_state_out); + ASSERT_EQ(read_func, EpContextReadCallback); + ASSERT_EQ(callback_state_out, &callback_state); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ASSERT_ORTSTATUS_OK(read_func(callback_state_out, "context.bin", allocator, &buffer, &buffer_size)); + auto release_buffer = gsl::finally([&]() { + if (buffer != nullptr) { + allocator.Free(buffer); + } + }); + + ASSERT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "context.bin"); + ASSERT_EQ(buffer_size, callback_state.payload.size()); + EXPECT_TRUE(std::equal(callback_state.payload.begin(), callback_state.payload.end(), + static_cast(buffer))); +} + +TEST(EpContextDataApiTest, ApiRejectsInvalidArguments) { + const auto& ort_api = Ort::GetApi(); + + auto* get_config = Ort::Experimental::Get_OrtEpApi_SessionOptions_GetEpContextConfig_SinceV28_FnOrThrow(&ort_api); + auto* release_config_func = + Ort::Experimental::Get_OrtEpApi_ReleaseEpContextConfig_SinceV28_FnOrThrow(&ort_api); + auto* get_read_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + auto* get_write_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28_FnOrThrow(&ort_api); + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + Ort::SessionOptions session_options; + OrtEpContextConfig* ep_context_config = nullptr; + ExpectFailureOrtStatus(get_config(nullptr, &ep_context_config), ORT_INVALID_ARGUMENT, "OrtSessionOptions is NULL"); + ExpectFailureOrtStatus(get_config(session_options, nullptr), ORT_INVALID_ARGUMENT, + "Output OrtEpContextConfig is NULL"); + + ExpectFailureOrtStatus(set_read_func(nullptr, EpContextReadCallback, nullptr), ORT_INVALID_ARGUMENT, + "'options' parameter must not be NULL"); + + ASSERT_ORTSTATUS_OK(get_config(session_options, &ep_context_config)); + auto release_config = gsl::finally([&]() { release_config_func(ep_context_config); }); + + OrtReadNamedBufferFunc read_func = nullptr; + OrtWriteNamedBufferFunc write_func = nullptr; + void* state = nullptr; + ExpectFailureOrtStatus(get_read_func(nullptr, &read_func, &state), ORT_INVALID_ARGUMENT, + "OrtEpContextConfig is NULL"); + ExpectFailureOrtStatus(get_read_func(ep_context_config, nullptr, &state), ORT_INVALID_ARGUMENT, + "Output read_func is NULL"); + ExpectFailureOrtStatus(get_read_func(ep_context_config, &read_func, nullptr), ORT_INVALID_ARGUMENT, + "Output state is NULL"); + ExpectFailureOrtStatus(get_write_func(nullptr, &write_func, &state), ORT_INVALID_ARGUMENT, + "OrtEpContextConfig is NULL"); + ExpectFailureOrtStatus(get_write_func(ep_context_config, nullptr, &state), ORT_INVALID_ARGUMENT, + "Output write_func is NULL"); + ExpectFailureOrtStatus(get_write_func(ep_context_config, &write_func, nullptr), ORT_INVALID_ARGUMENT, + "Output state is NULL"); + +#if !defined(ORT_MINIMAL_BUILD) + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataApiRejectsInvalidArguments"}; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + ExpectFailureOrtStatus(set_write_func(nullptr, EpContextWriteCallback, nullptr), ORT_INVALID_ARGUMENT, + "OrtModelCompilationOptions is NULL"); + // A null write_func is allowed: it clears any previously set callback (covered by WriteFuncCanBeCleared), so it is + // not rejected here. +#endif // !defined(ORT_MINIMAL_BUILD) +} + +TEST(EpContextDataApiTest, AccessorsReturnNullWhenCallbacksUnset) { + Ort::SessionOptions session_options; + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + + OrtReadNamedBufferFunc read_func = EpContextReadCallback; + OrtWriteNamedBufferFunc write_func = EpContextWriteCallback; + void* state = reinterpret_cast(0x1); + + ep_context_config.GetReadFunc(read_func, state); + EXPECT_EQ(read_func, nullptr); + EXPECT_EQ(state, nullptr); + + state = reinterpret_cast(0x1); + ep_context_config.GetWriteFunc(write_func, state); + EXPECT_EQ(write_func, nullptr); + EXPECT_EQ(state, nullptr); +} + +TEST(EpContextDataApiTest, ConfigReturnsConfiguredCallbacks) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; + ep_context_config.GetReadFunc(read_func, read_state); + EXPECT_EQ(read_func, EpContextReadCallback); + EXPECT_EQ(read_state, &callback_state); + + OrtWriteNamedBufferFunc write_func = nullptr; + void* write_state = nullptr; + ep_context_config.GetWriteFunc(write_func, write_state); + EXPECT_EQ(write_func, nullptr); + EXPECT_EQ(write_state, nullptr); +} + +TEST(EpContextDataApiTest, ReadFuncCanBeCleared) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + ASSERT_ORTSTATUS_OK(set_read_func(session_options, nullptr, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + OrtReadNamedBufferFunc read_func = EpContextReadCallback; + void* read_state = reinterpret_cast(0x1); + ep_context_config.GetReadFunc(read_func, read_state); + EXPECT_EQ(read_func, nullptr); + EXPECT_EQ(read_state, nullptr); +} + +#if !defined(ORT_MINIMAL_BUILD) +TEST(EpContextDataApiTest, WriteFuncCanBeSetOnModelCompilationOptions) { + const auto& ort_api = Ort::GetApi(); + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataWriteFuncCanBeSetOnModelCompilationOptions"}; + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + + EpContextWriteCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, EpContextWriteCallback, &callback_state)); + + const std::vector payload{'b', 'i', 'n', 'a', 'r', 'y'}; + ASSERT_ORTSTATUS_OK(EpContextWriteCallback(&callback_state, "engine.bin", payload.data(), payload.size())); + + ASSERT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "engine.bin"); + EXPECT_EQ(callback_state.payload, payload); +} + +TEST(EpContextDataApiTest, WriteFuncCanBeCleared) { + const auto& ort_api = Ort::GetApi(); + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataWriteFuncCanBeCleared"}; + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + + EpContextWriteCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, EpContextWriteCallback, &callback_state)); + + // A null write_func clears the previously set callback (symmetric with the read setter) and must be accepted + // rather than rejected with ORT_INVALID_ARGUMENT. + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, nullptr, &callback_state)); +} + +TEST(EpContextDataApiTest, WriteFuncCanBeUsedWithEpContextBinaryInformation) { + const auto& ort_api = Ort::GetApi(); + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataWriteFuncCanBeUsedWithEpContextBinaryInformation"}; + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + + // The EPContext write callback and the EPContext binary information may be configured together; neither call + // rejects the other. + ASSERT_NO_THROW(compilation_options.SetEpContextBinaryInformation(ORT_TSTR("ep_context_dir/"), + ORT_TSTR("compiled_model.onnx"))); + + EpContextWriteCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, EpContextWriteCallback, &callback_state)); + + const std::vector payload{'c', 't', 'x'}; + ASSERT_ORTSTATUS_OK(EpContextWriteCallback(&callback_state, "logical_context.bin", payload.data(), payload.size())); + + ASSERT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "logical_context.bin"); + EXPECT_EQ(callback_state.payload, payload); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +TEST(EpContextDataApiTest, ReturnedReadFuncAllowsEmptyPayloads) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; + ep_context_config.GetReadFunc(read_func, read_state); + ASSERT_EQ(read_func, EpContextReadCallback); + ASSERT_EQ(read_state, &callback_state); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = reinterpret_cast(0x1); + size_t buffer_size = 1; + ASSERT_ORTSTATUS_OK(read_func(read_state, "empty.bin", allocator, &buffer, &buffer_size)); + + EXPECT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "empty.bin"); + EXPECT_EQ(buffer, nullptr); + EXPECT_EQ(buffer_size, 0U); +} diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index a1f2f9102f027..91fc61f19e0df 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -887,11 +887,20 @@ TEST(ModelEditorCompileAPITest, CompileFromModelWithNoGraph_Fails) { EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("graph")); } -// Test validation: model with empty inputs/outputs -TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { - // Create a model with a graph that has no inputs or outputs +// Test validation: model with no outputs (one input but zero outputs). +// 0 outputs is still rejected because compilation produces an output model that +// would have no consumers for any computed values. +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyOutputs_Fails) { Ort::Graph graph; - // Don't set inputs or outputs + + // Provide a single input but no outputs, to isolate the output-count check. + std::vector graph_inputs; + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph.SetInputs(graph_inputs); + // Intentionally do not call SetOutputs. std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; Model model(opsets); @@ -909,8 +918,70 @@ TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with empty inputs/outputs"; - EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("input")); + EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with no outputs"; + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("at least one output")); +} + +// Test: model with zero graph inputs is now accepted by CompileModel. +// Mirrors what CreateSessionFromModel already accepts (e.g., a graph composed of +// zero-input generator ops like RandomNormal that produces output without external input). +// Regression test for https://github.com/microsoft/onnxruntime/issues/28135. +// +// Scope: this test exercises only the ORT-side validation in ModelCompilationOptions::Check(). +// EP-specific validation (e.g., whether the WebNN EP's partitioner accepts a 0-input subgraph) +// is owned by the respective EP and is not covered here. +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputs_Succeeds) { + Ort::Graph graph; + + // Zero graph inputs; one graph output produced by a RandomNormal node. + // RandomNormal takes 0 inputs and produces a tensor with shape specified via attribute. + // Use RandomNormal rather than Constant because Constant nodes are folded into initializers + // at load time (see graph.cc Graph::LoadFromModelEditorApiModel) and would not exercise + // the true 0-input producer path. + std::vector output_dims = {2, 3}; + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + + std::vector graph_outputs; + graph_outputs.emplace_back("Y", output_type_info.GetConst()); + graph.SetOutputs(graph_outputs); + // Intentionally do not call SetInputs (zero graph inputs). + + std::vector attributes; + std::vector shape_attr_value = {2, 3}; + attributes.push_back(OpAttr("shape", shape_attr_value.data(), + static_cast(shape_attr_value.size()), + OrtOpAttrType::ORT_OP_ATTR_INTS)); + + Node node("RandomNormal", onnxruntime::kOnnxDomain, "RandomNormal1", + /*input_names*/ {}, /*output_names*/ {"Y"}, attributes); + graph.AddNode(node); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + // Compile should succeed. No specific compiling EP is required; the default + // kGenerateModel action emits an output model even when no EPContext nodes are produced. + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + // Fatal: a compile that returns OK but produces no artifact bytes is a regression. + ASSERT_NE(output_buffer, nullptr); + ASSERT_GT(output_size, 0u); + + allocator->Free(output_buffer); } // Test: model can be reused after compilation. diff --git a/onnxruntime/test/testdata/dummy_whisper_model_generator.py b/onnxruntime/test/testdata/dummy_whisper_model_generator.py new file mode 100644 index 0000000000000..a90518255193b --- /dev/null +++ b/onnxruntime/test/testdata/dummy_whisper_model_generator.py @@ -0,0 +1,328 @@ +"""Script to generate a dummy ONNX model emulating a Whisper model with BeamSearch op. + +The model is intentionally tiny and produces deterministic (but meaningless) outputs. +Its only purpose is to exercise the WhisperBeamSearch encoder/decoder subgraph plumbing, +in particular the decoder "use sequence as input ids" path that builds the initial decoder +feeds from the full running sequences. +""" + +import argparse + +import numpy as np +import onnx + + +def create_model( + vocab_size: int, + embed_dim: int, + num_heads: int, + head_size: int, + feature_size: int, + beam_size: int, + min_length: int, + max_length: int, + length_penalty: float, + sequence_as_input: bool, +) -> onnx.ModelProto: + encoder_graph = create_encoder(vocab_size, embed_dim, num_heads, head_size, feature_size) + decoder_graph = create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input) + + # Top-level inputs: input_features (audio) and decoder_input_ids (initial transcript tokens). + input_features = onnx.helper.make_tensor_value_info( + "input_features", onnx.TensorProto.FLOAT, ["batch_size", feature_size, "encode_sequence_length"] + ) + decoder_input_ids = onnx.helper.make_tensor_value_info( + "decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "initial_decode_sequence_length"] + ) + + # Outputs: sequences, scores + sequences = onnx.helper.make_tensor_value_info( + "sequences", onnx.TensorProto.INT32, ["batch_size", "num_return_sequences", "decode_sequence_length"] + ) + scores = onnx.helper.make_tensor_value_info( + "scores", onnx.TensorProto.FLOAT, ["batch_size", "num_return_sequences"] + ) + + # Initializers for the BeamSearch parameters. + max_length_t = onnx.numpy_helper.from_array(np.array(max_length, dtype=np.int32), name="max_length") + min_length_t = onnx.numpy_helper.from_array(np.array(min_length, dtype=np.int32), name="min_length") + num_beams_t = onnx.numpy_helper.from_array(np.array(beam_size, dtype=np.int32), name="num_beams") + num_return_sequences_t = onnx.numpy_helper.from_array(np.array(1, dtype=np.int32), name="num_return_sequences") + length_penalty_t = onnx.numpy_helper.from_array( + np.array(length_penalty, dtype=np.float32), name="length_penalty_as_tensor" + ) + + # The Whisper BeamSearch op expects decoder_input_ids at input index 10. The intervening + # optional inputs (repetition_penalty, vocab_mask, prefix_vocab_mask, attention_mask) are + # left empty. + beam_search = onnx.helper.make_node( + "BeamSearch", + [ + "input_features", + "max_length", + "min_length", + "num_beams", + "num_return_sequences", + "length_penalty_as_tensor", + "", + "", + "", + "", + "decoder_input_ids", + ], + ["sequences", "scores"], + decoder_start_token_id=2, + eos_token_id=2, + early_stopping=0, + model_type=2, + pad_token_id=1, + decoder=decoder_graph, + encoder=encoder_graph, + domain="com.microsoft", + ) + + graph = onnx.helper.make_graph( + [beam_search], + "model", + [input_features, decoder_input_ids], + [sequences, scores], + [max_length_t, min_length_t, num_beams_t, num_return_sequences_t, length_penalty_t], + ) + + model = onnx.helper.make_model( + graph, opset_imports=[onnx.helper.make_opsetid("", 17), onnx.helper.make_opsetid("com.microsoft", 1)] + ) + + return model + + +def create_encoder(vocab_size, embed_dim, num_heads, head_size, feature_size) -> onnx.GraphProto: + # Inputs: encoder_input_ids (audio features, float), decoder_input_ids (int32) + encoder_input_ids = onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.FLOAT, ["batch_size", feature_size, "encode_sequence_length"] + ) + decoder_input_ids = onnx.helper.make_tensor_value_info( + "decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "initial_decode_sequence_length"] + ) + + # Outputs: logits, encoder_hidden_states, present_key_self_0, present_value_self_0, + # present_key_cross_0, present_value_cross_0 + logits = onnx.helper.make_tensor_value_info( + "logits", onnx.TensorProto.FLOAT, ["batch_size", "initial_decode_sequence_length", vocab_size] + ) + encoder_hidden_states = onnx.helper.make_tensor_value_info( + "encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", embed_dim] + ) + present_key_self_0 = onnx.helper.make_tensor_value_info( + "present_key_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_value_self_0 = onnx.helper.make_tensor_value_info( + "present_value_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_key_cross_0 = onnx.helper.make_tensor_value_info( + "present_key_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + present_value_cross_0 = onnx.helper.make_tensor_value_info( + "present_value_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + + # Initializers + feature_proj = onnx.numpy_helper.from_array( + np.random.randn(feature_size, embed_dim).astype(np.float32), name="feature_proj" + ) + decoder_embeddings = onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="encoder_decoder_embeddings" + ) + final_proj = onnx.numpy_helper.from_array( + np.random.randn(embed_dim, vocab_size).astype(np.float32), name="encoder_final_proj" + ) + num_heads_and_size = onnx.numpy_helper.from_array( + np.array([num_heads, head_size], dtype=np.int64), name="num_heads_and_size" + ) + self_state_shape = onnx.numpy_helper.from_array( + np.array([-1, 1, num_heads, head_size], dtype=np.int64), name="self_state_shape" + ) + + nodes = [ + # encoder_hidden_states = transpose(features)[B, Es, Fs] @ feature_proj[Fs, E] -> [B, Es, E] + onnx.helper.make_node("Transpose", ["encoder_input_ids"], ["features_t"], perm=[0, 2, 1]), + onnx.helper.make_node("MatMul", ["features_t", "feature_proj"], ["encoder_hidden_states"]), + # cross KV: reshape [B, Es, E] -> [B, Es, num_heads, head_size] -> transpose [B, num_heads, Es, head_size] + onnx.helper.make_node("Shape", ["encoder_hidden_states"], ["enc_batch_seq"], end=2), + onnx.helper.make_node("Concat", ["enc_batch_seq", "num_heads_and_size"], ["enc_cross_shape"], axis=0), + onnx.helper.make_node("Reshape", ["encoder_hidden_states", "enc_cross_shape"], ["enc_cross_reshaped"]), + onnx.helper.make_node("Transpose", ["enc_cross_reshaped"], ["present_key_cross_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["enc_cross_reshaped"], ["present_value_cross_0"], perm=[0, 2, 1, 3]), + # decoder hidden states from decoder_input_ids + onnx.helper.make_node("Gather", ["encoder_decoder_embeddings", "decoder_input_ids"], ["decoder_hidden_states"]), + # logits = decoder_hidden_states[B, Ds, E] @ final_proj[E, V] -> [B, Ds, V] + onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["enc_hidden_mean"], axes=[1]), + onnx.helper.make_node("Add", ["decoder_hidden_states", "enc_hidden_mean"], ["decoder_sum"]), + onnx.helper.make_node("MatMul", ["decoder_sum", "encoder_final_proj"], ["logits"]), + # self KV (length 1): reduce decoder hidden over Ds -> [B, 1, E] -> [B, 1, Hn, Hs] -> [B, Hn, 1, Hs] + onnx.helper.make_node("ReduceMean", ["decoder_sum"], ["self_hidden_mean"], axes=[1]), + onnx.helper.make_node("Reshape", ["self_hidden_mean", "self_state_shape"], ["self_state"]), + onnx.helper.make_node("Transpose", ["self_state"], ["present_key_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["self_state"], ["present_value_self_0"], perm=[0, 2, 1, 3]), + ] + + graph = onnx.helper.make_graph( + nodes, + "encoder", + [encoder_input_ids, decoder_input_ids], + [ + logits, + encoder_hidden_states, + present_key_self_0, + present_value_self_0, + present_key_cross_0, + present_value_cross_0, + ], + [feature_proj, decoder_embeddings, final_proj, num_heads_and_size, self_state_shape], + ) + return graph + + +def create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input) -> onnx.GraphProto: + # Inputs: input_ids, encoder_hidden_states, past_key_self_0, past_value_self_0, + # past_key_cross_0, past_value_cross_0 + inputs = [ + onnx.helper.make_tensor_value_info( + "input_ids", onnx.TensorProto.INT32, ["batch_size", "decode_sequence_length" if sequence_as_input else 1] + ), + onnx.helper.make_tensor_value_info( + "encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", embed_dim] + ), + onnx.helper.make_tensor_value_info( + "past_key_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "past_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "past_value_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "past_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "past_key_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ), + onnx.helper.make_tensor_value_info( + "past_value_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ), + ] + + outputs = [ + onnx.helper.make_tensor_value_info("logits", onnx.TensorProto.FLOAT, ["batch_size", 1, vocab_size]), + onnx.helper.make_tensor_value_info( + "present_key_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "present_value_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), + ] + + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="decoder_embeddings" + ), + onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name="final_proj"), + onnx.numpy_helper.from_array( + np.array([-1, num_heads, head_size], dtype=np.int64), name="self_state_shape_no_batch" + ), + onnx.numpy_helper.from_array(np.array([-1, 1, embed_dim], dtype=np.int64), name="hidden_mean_shape"), + ] + + nodes = [ + onnx.helper.make_node("Gather", ["decoder_embeddings", "input_ids"], ["decoder_hidden_states"]), + # encoder signal from encoder_hidden_states mean -> [B, 1, E] + onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["enc_hidden_mean"], axes=[1]), + onnx.helper.make_node("Reshape", ["enc_hidden_mean", "hidden_mean_shape"], ["enc_hidden_mean_reshaped"]), + # reduce decoder hidden over the sequence dim -> [B, 1, E] + onnx.helper.make_node("ReduceMean", ["decoder_hidden_states"], ["decoder_hidden_mean"], axes=[1]), + onnx.helper.make_node("Add", ["decoder_hidden_mean", "enc_hidden_mean_reshaped"], ["decoder_sum"]), + onnx.helper.make_node("MatMul", ["decoder_sum", "final_proj"], ["logits"]), + # self KV for this step (length 1) concatenated with the running past + onnx.helper.make_node("Shape", ["decoder_sum"], ["decoder_batch"], end=1), + onnx.helper.make_node( + "Concat", ["decoder_batch", "self_state_shape_no_batch"], ["self_state_shape_dec"], axis=0 + ), + onnx.helper.make_node("Reshape", ["decoder_sum", "self_state_shape_dec"], ["self_state"]), + onnx.helper.make_node("Transpose", ["self_state"], ["single_key_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["self_state"], ["single_value_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Concat", ["past_key_self_0", "single_key_self_0"], ["present_key_self_0"], axis=2), + onnx.helper.make_node("Concat", ["past_value_self_0", "single_value_self_0"], ["present_value_self_0"], axis=2), + ] + + graph = onnx.helper.make_graph(nodes, "decoder", inputs, outputs, initializers) + return graph + + +def run_model(model_path, feature_size): + # Imported lazily so model *generation* only depends on `onnx`; running needs `onnxruntime`. + import onnxruntime as ort # noqa: PLC0415 + + ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + encode_length = 5 + # Fixed, deterministic inputs so a C++ regression test can reproduce the exact golden outputs. + input_features = (((np.arange(feature_size * encode_length, dtype=np.float32) % 7) - 3.0) * 0.1).reshape( + 1, feature_size, encode_length + ) + decoder_input_ids = np.array([[2, 5]], dtype=np.int32) + sequences, scores = ort_session.run( + None, {"input_features": input_features, "decoder_input_ids": decoder_input_ids} + ) + print("input_features (flat):", input_features.flatten().tolist()) + print("decoder_input_ids:", decoder_input_ids.tolist()) + print("sequences shape:", sequences.shape) + print("sequences:", sequences.tolist()) + print("scores:", scores.tolist()) + return sequences, scores + + +def arg_parser(): + parser = argparse.ArgumentParser(description="Generate a dummy ONNX model emulating Whisper with BeamSearch op.") + parser.add_argument("--output-path", type=str, default="dummy_whisper.onnx", help="Model output path") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--vocab-size", type=int, default=20, help="Vocab size") + parser.add_argument("--embed-dim", type=int, default=8, help="Embedding dimension") + parser.add_argument("--num-heads", type=int, default=2, help="Number of heads") + parser.add_argument("--head-size", type=int, default=4, help="Head size") + parser.add_argument("--feature-size", type=int, default=8, help="Encoder input feature size") + parser.add_argument("--beam-size", type=int, default=3, help="Beam size") + parser.add_argument("--min-length", type=int, default=1, help="Min length") + parser.add_argument("--max-length", type=int, default=10, help="Max length") + parser.add_argument("--length-penalty", type=float, default=1.1, help="Length penalty") + parser.add_argument("--sequence-as-input", action="store_true", help="Use sequence as input ids") + parser.add_argument( + "--no-run", + action="store_true", + help="Only generate and save the model; skip running it (avoids needing an onnxruntime install)", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = arg_parser() + np.random.seed(args.seed) + + model = create_model( + args.vocab_size, + args.embed_dim, + args.num_heads, + args.head_size, + args.feature_size, + args.beam_size, + args.min_length, + args.max_length, + args.length_penalty, + args.sequence_as_input, + ) + onnx.save(model, args.output_path) + + if not args.no_run: + run_model(args.output_path, args.feature_size) diff --git a/onnxruntime/test/testdata/dummy_whisper_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_whisper_with_sequence_input_ids.onnx new file mode 100644 index 0000000000000..ad37e9d265bd8 Binary files /dev/null and b/onnxruntime/test/testdata/dummy_whisper_with_sequence_input_ids.onnx differ diff --git a/rust/onnxruntime/examples/issue22.rs b/rust/onnxruntime/examples/issue22.rs index 6c96e899fa774..1fb7fe28ff123 100644 --- a/rust/onnxruntime/examples/issue22.rs +++ b/rust/onnxruntime/examples/issue22.rs @@ -51,5 +51,5 @@ fn main() { let outputs = session.run(inputs).unwrap(); - print!("outputs: {:#?}", outputs[0].float_array().unwrap()); + print!("outputs: {:#?}", outputs[0].float_array().unwrap().view()); } diff --git a/rust/onnxruntime/examples/sample.rs b/rust/onnxruntime/examples/sample.rs index 9af5cf733ccae..b6f351b2082ed 100644 --- a/rust/onnxruntime/examples/sample.rs +++ b/rust/onnxruntime/examples/sample.rs @@ -73,10 +73,11 @@ fn run() -> Result<(), Error> { let outputs = session.run(input_tensor_values)?; let output = outputs[0].float_array().unwrap(); + let view = output.view(); - assert_eq!(output.shape(), output0_shape.as_slice()); + assert_eq!(view.shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, view[[0, i, 0, 0]]); } Ok(()) diff --git a/rust/onnxruntime/src/session.rs b/rust/onnxruntime/src/session.rs index 326426e35982c..d475d1b724111 100644 --- a/rust/onnxruntime/src/session.rs +++ b/rust/onnxruntime/src/session.rs @@ -410,10 +410,10 @@ impl Session { /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus /// used for the input data here. - pub fn run<'input, 'output>( - &'output self, + pub fn run<'input>( + &self, mut input_arrays: impl AsMut<[Box]> + 'input, - ) -> Result>> { + ) -> Result> { let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; diff --git a/rust/onnxruntime/src/tensor/ort_output_tensor.rs b/rust/onnxruntime/src/tensor/ort_output_tensor.rs index 83663c0d303f8..727cae1db0ef4 100644 --- a/rust/onnxruntime/src/tensor/ort_output_tensor.rs +++ b/rust/onnxruntime/src/tensor/ort_output_tensor.rs @@ -71,22 +71,41 @@ impl Drop for OrtOutputTensor { } /// An Output tensor with the ptr and the item that will copy from the ptr. -#[derive(Debug)] -pub struct WithOutputTensor<'a, T> { - #[allow(dead_code)] +/// +/// The view is materialized on each access via [`view()`](Self::view) to ensure the +/// borrowed lifetime is tied to `&self`, preventing the view from outliving the +/// underlying buffer owned by the `OrtOutputTensor`. +pub struct WithOutputTensor { pub(crate) tensor: OrtOutputTensor, - item: ArrayView<'a, T, ndarray::IxDyn>, + data_ptr: *const T, + shape: Vec, } -impl<'a, T> std::ops::Deref for WithOutputTensor<'a, T> { - type Target = ArrayView<'a, T, ndarray::IxDyn>; +impl Debug for WithOutputTensor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WithOutputTensor") + .field("tensor", &self.tensor) + .field("data_ptr", &self.data_ptr) + .field("shape", &self.shape) + .finish() + } +} - fn deref(&self) -> &Self::Target { - &self.item +// SAFETY: The data pointer is derived from OrtOutputTensor which owns the allocation. +// Access is only possible through &self (via view()), so Send/Sync follow from T: Send/Sync. +unsafe impl Send for WithOutputTensor {} +unsafe impl Sync for WithOutputTensor {} + +impl WithOutputTensor { + /// Returns an [`ArrayView`] over the output tensor data. + /// + /// The returned view borrows `self`, so it cannot outlive the tensor owner. + pub fn view(&self) -> ArrayView<'_, T, ndarray::IxDyn> { + unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&self.shape), self.data_ptr) } } } -impl<'a, T> TryFrom for WithOutputTensor<'a, T> +impl TryFrom for WithOutputTensor where T: TypeToTensorElementDataType, { @@ -110,45 +129,45 @@ where status_to_result(status).map_err(OrtError::IsTensor)?; assert_ne!(output_array_ptr, std::ptr::null_mut()); - let array_view = - unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&value.shape), output_array_ptr) }; + let shape = value.shape.clone(); Ok(WithOutputTensor { tensor: value, - item: array_view, + data_ptr: output_array_ptr, + shape, }) } } /// The onnxruntime Run output type. -pub enum OrtOutput<'a> { +pub enum OrtOutput { /// Tensor of f32s - Float(WithOutputTensor<'a, f32>), + Float(WithOutputTensor), /// Tensor of f64s - Double(WithOutputTensor<'a, f64>), + Double(WithOutputTensor), /// Tensor of u8s - UInt8(WithOutputTensor<'a, u8>), + UInt8(WithOutputTensor), /// Tensor of u16s - UInt16(WithOutputTensor<'a, u16>), + UInt16(WithOutputTensor), /// Tensor of u32s - UInt32(WithOutputTensor<'a, u32>), + UInt32(WithOutputTensor), /// Tensor of u64s - UInt64(WithOutputTensor<'a, u64>), + UInt64(WithOutputTensor), /// Tensor of i8s - Int8(WithOutputTensor<'a, i8>), + Int8(WithOutputTensor), /// Tensor of i16s - Int16(WithOutputTensor<'a, i16>), + Int16(WithOutputTensor), /// Tensor of i32s - Int32(WithOutputTensor<'a, i32>), + Int32(WithOutputTensor), /// Tensor of i64s - Int64(WithOutputTensor<'a, i64>), + Int64(WithOutputTensor), /// Tensor of Strings - String(WithOutputTensor<'a, String>), + String(WithOutputTensor), } -impl<'a> OrtOutput<'a> { - /// Return `WithOutputTensor<'a, f32>` which derefs into an `ArrayView`. - pub fn float_array(&self) -> Option<&WithOutputTensor<'a, f32>> { +impl OrtOutput { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn float_array(&self) -> Option<&WithOutputTensor> { if let Self::Float(item) = self { Some(item) } else { @@ -156,8 +175,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, f64>` which derefs into an `ArrayView`. - pub fn double_array(&self) -> Option<&WithOutputTensor<'a, f64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn double_array(&self) -> Option<&WithOutputTensor> { if let Self::Double(item) = self { Some(item) } else { @@ -165,8 +184,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u8>` which derefs into an `ArrayView`. - pub fn uint8_array(&self) -> Option<&WithOutputTensor<'a, u8>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint8_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt8(item) = self { Some(item) } else { @@ -174,8 +193,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u16>` which derefs into an `ArrayView`. - pub fn uint16_array(&self) -> Option<&WithOutputTensor<'a, u16>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint16_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt16(item) = self { Some(item) } else { @@ -183,8 +202,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u32>` which derefs into an `ArrayView`. - pub fn uint32_array(&self) -> Option<&WithOutputTensor<'a, u32>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint32_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt32(item) = self { Some(item) } else { @@ -192,8 +211,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u64>` which derefs into an `ArrayView`. - pub fn uint64_array(&self) -> Option<&WithOutputTensor<'a, u64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint64_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt64(item) = self { Some(item) } else { @@ -201,8 +220,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i8>` which derefs into an `ArrayView`. - pub fn int8_array(&self) -> Option<&WithOutputTensor<'a, i8>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int8_array(&self) -> Option<&WithOutputTensor> { if let Self::Int8(item) = self { Some(item) } else { @@ -210,8 +229,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i16>` which derefs into an `ArrayView`. - pub fn int16_array(&self) -> Option<&WithOutputTensor<'a, i16>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int16_array(&self) -> Option<&WithOutputTensor> { if let Self::Int16(item) = self { Some(item) } else { @@ -219,8 +238,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i32>` which derefs into an `ArrayView`. - pub fn int32_array(&self) -> Option<&WithOutputTensor<'a, i32>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int32_array(&self) -> Option<&WithOutputTensor> { if let Self::Int32(item) = self { Some(item) } else { @@ -228,8 +247,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i64>` which derefs into an `ArrayView`. - pub fn int64_array(&self) -> Option<&WithOutputTensor<'a, i64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int64_array(&self) -> Option<&WithOutputTensor> { if let Self::Int64(item) = self { Some(item) } else { @@ -237,8 +256,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, String>` which derefs into an `ArrayView`. - pub fn string_array(&self) -> Option<&WithOutputTensor<'a, String>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn string_array(&self) -> Option<&WithOutputTensor> { if let Self::String(item) = self { Some(item) } else { @@ -247,10 +266,10 @@ impl<'a> OrtOutput<'a> { } } -impl<'a> TryFrom for OrtOutput<'a> { +impl TryFrom for OrtOutput { type Error = OrtError; - fn try_from(value: OrtOutputTensor) -> Result> { + fn try_from(value: OrtOutputTensor) -> Result { unsafe { let mut shape_info = std::ptr::null_mut(); diff --git a/rust/onnxruntime/tests/integration_tests.rs b/rust/onnxruntime/tests/integration_tests.rs index 7843fe269e5e4..1c096400eccf7 100644 --- a/rust/onnxruntime/tests/integration_tests.rs +++ b/rust/onnxruntime/tests/integration_tests.rs @@ -112,6 +112,7 @@ mod download { // and iterate on resulting probabilities, creating an index to later access labels. let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -209,6 +210,7 @@ mod download { let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -301,6 +303,7 @@ mod download { let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -398,6 +401,7 @@ mod download { let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -515,7 +519,7 @@ mod download { let output = outputs[0].float_array().unwrap(); // The image should have doubled in size - assert_eq!(output.shape(), [1, 448, 448, 3]); + assert_eq!(output.view().shape(), [1, 448, 448, 3]); } } diff --git a/setup.py b/setup.py index 3b8bb9b81d20a..62ced38819f2c 100644 --- a/setup.py +++ b/setup.py @@ -817,7 +817,9 @@ def reformat_run_count(count_str): # Adding CUDA Runtime as dependency for NV TensorRT RTX python wheel if package_name == "onnxruntime-trt-rtx": major = cuda_major_version or "12" # Default to CUDA 12 - install_requires.append(f"nvidia-cuda-runtime-cu{major}~={major}.0") + # CUDA 13 dropped the "-cuNN" suffix from the CUDA Runtime package name. + runtime_pkg = "nvidia-cuda-runtime" if int(major) >= 13 else f"nvidia-cuda-runtime-cu{major}" + install_requires.append(f"{runtime_pkg}~={major}.0") def save_build_and_package_info(package_name, version_number, cuda_version, qnn_version): @@ -862,13 +864,18 @@ def save_build_and_package_info(package_name, version_number, cuda_version, qnn_ if package_name == "onnxruntime-gpu" and cuda_major_version: # Determine cufft version: CUDA 13 uses cufft 12, CUDA 12 uses cufft 11 cufft_version = "12.0" if cuda_major_version == "13" else "11.0" + + # Starting with CUDA 13, NVIDIA renamed the per-component CUDA Toolkit packages by + # dropping the "-cuNN" suffix (e.g. "nvidia-cuda-runtime-cu12" -> "nvidia-cuda-runtime"). + # cuDNN keeps the suffixed package name ("nvidia-cudnn-cu13"). + cuda_pkg_suffix = "" if int(cuda_major_version) >= 13 else f"-cu{cuda_major_version}" extras_require.update( { "cuda": [ - f"nvidia-cuda-nvrtc-cu{cuda_major_version}~={cuda_major_version}.0", - f"nvidia-cuda-runtime-cu{cuda_major_version}~={cuda_major_version}.0", - f"nvidia-cufft-cu{cuda_major_version}~={cufft_version}", - f"nvidia-curand-cu{cuda_major_version}~=10.0", + f"nvidia-cuda-nvrtc{cuda_pkg_suffix}~={cuda_major_version}.0", + f"nvidia-cuda-runtime{cuda_pkg_suffix}~={cuda_major_version}.0", + f"nvidia-cufft{cuda_pkg_suffix}~={cufft_version}", + f"nvidia-curand{cuda_pkg_suffix}~=10.0", ], "cudnn": [ f"nvidia-cudnn-cu{cuda_major_version}~=9.0",