From b77b31239621c45fe04856e8fcfad396bf84be85 Mon Sep 17 00:00:00 2001 From: adrastogi Date: Wed, 24 Jun 2026 14:47:00 -0700 Subject: [PATCH 01/19] Relax CompileModel validation to accept zero-input OrtModel graphs (#28771) ### Description Relax the input-validation in OrtApi::CompileModel to accept OrtModel instances with zero graph inputs. Previously, ModelCompilationOptions::Check() rejected such models with "OrtModel graph must have at least one input and one output defined." The check now requires only at least one graph output; the zero-input case is legal. Tests in test_model_builder_api.cc are restructured: - The old CompileFromModelWithEmptyInputsOutputs_Fails is renamed to CompileFromModelWithEmptyOutputs_Fails and reshaped to provide 1 input + 0 outputs, isolating the output-only check. - A new regression test CompileFromModelWithEmptyInputs_Succeeds builds a 0-input model with a RandomNormal node and verifies compilation succeeds. ### Motivation and Context Fixes #28135 The original check was too restrictive and impacts callers (e.g., WebNN/Chromium needs to call CompileModel on such models in a separate compiler process (and then load the compiled artifact via CreateSessionFromArray in the GPU process)). --- .../core/session/model_compilation_options.cc | 7 +- .../test/shared_lib/test_model_builder_api.cc | 83 +++++++++++++++++-- 2 files changed, 82 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index a393bb42fe2cb..a17802bdd7573 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -283,9 +283,12 @@ Status ModelCompilationOptions::Check() const { "OrtModel has no graph. Call AddGraphToModel before compilation."); } - if (input_model_->graph->GetNumInputs() == 0 || input_model_->graph->GetNumOutputs() == 0) { + // A model with zero graph inputs is legal (e.g., a graph composed of zero-input + // generator ops like RandomNormal that produces output without external input). + // We still require at least one graph output for the compiled model to be meaningful. + if (input_model_->graph->GetNumOutputs() == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "OrtModel graph must have at least one input and one output defined."); + "OrtModel graph must have at least one output defined."); } if (input_model_->domain_to_version.empty()) { diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index a1f2f9102f027..91fc61f19e0df 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -887,11 +887,20 @@ TEST(ModelEditorCompileAPITest, CompileFromModelWithNoGraph_Fails) { EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("graph")); } -// Test validation: model with empty inputs/outputs -TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { - // Create a model with a graph that has no inputs or outputs +// Test validation: model with no outputs (one input but zero outputs). +// 0 outputs is still rejected because compilation produces an output model that +// would have no consumers for any computed values. +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyOutputs_Fails) { Ort::Graph graph; - // Don't set inputs or outputs + + // Provide a single input but no outputs, to isolate the output-count check. + std::vector graph_inputs; + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph.SetInputs(graph_inputs); + // Intentionally do not call SetOutputs. std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; Model model(opsets); @@ -909,8 +918,70 @@ TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with empty inputs/outputs"; - EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("input")); + EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with no outputs"; + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("at least one output")); +} + +// Test: model with zero graph inputs is now accepted by CompileModel. +// Mirrors what CreateSessionFromModel already accepts (e.g., a graph composed of +// zero-input generator ops like RandomNormal that produces output without external input). +// Regression test for https://github.com/microsoft/onnxruntime/issues/28135. +// +// Scope: this test exercises only the ORT-side validation in ModelCompilationOptions::Check(). +// EP-specific validation (e.g., whether the WebNN EP's partitioner accepts a 0-input subgraph) +// is owned by the respective EP and is not covered here. +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputs_Succeeds) { + Ort::Graph graph; + + // Zero graph inputs; one graph output produced by a RandomNormal node. + // RandomNormal takes 0 inputs and produces a tensor with shape specified via attribute. + // Use RandomNormal rather than Constant because Constant nodes are folded into initializers + // at load time (see graph.cc Graph::LoadFromModelEditorApiModel) and would not exercise + // the true 0-input producer path. + std::vector output_dims = {2, 3}; + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + + std::vector graph_outputs; + graph_outputs.emplace_back("Y", output_type_info.GetConst()); + graph.SetOutputs(graph_outputs); + // Intentionally do not call SetInputs (zero graph inputs). + + std::vector attributes; + std::vector shape_attr_value = {2, 3}; + attributes.push_back(OpAttr("shape", shape_attr_value.data(), + static_cast(shape_attr_value.size()), + OrtOpAttrType::ORT_OP_ATTR_INTS)); + + Node node("RandomNormal", onnxruntime::kOnnxDomain, "RandomNormal1", + /*input_names*/ {}, /*output_names*/ {"Y"}, attributes); + graph.AddNode(node); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + // Compile should succeed. No specific compiling EP is required; the default + // kGenerateModel action emits an output model even when no EPContext nodes are produced. + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + // Fatal: a compile that returns OK but produces no artifact bytes is a regression. + ASSERT_NE(output_buffer, nullptr); + ASSERT_GT(output_size, 0u); + + allocator->Free(output_buffer); } // Test: model can be reused after compilation. From eade9ea991bb1cc52864bd27548d93fc2b1460be Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:08:38 -0700 Subject: [PATCH 02/19] Clamp derived sequence lengths and KV-cache index in CUDA GroupQueryAttention (#29240) ### Description The CUDA `GroupQueryAttention` kernel derives a KV-cache append offset from the `seqlens_k` input (`past_seq_lens = (seqlens_k + 1) - sequence_length`). On the CUDA EP `seqlens_k` is device-resident (only `total_sequence_length` is a CPU input), so the host-side range validation in the operator/helper is skipped. The device kernel `UnpackRoPEAppend` then guarded the cache store with only a one-sided upper bound (`cache_s < max_seqlen`), so an out-of-range `seqlens_k` could produce a negative offset that is sign-extended into the cache-index arithmetic. The CPU operator already validates `seqlens_k` host-side; this change brings the CUDA path to parity by guarding on the device. ### Changes - `group_query_attention_impl.cu` (`GetSequenceLengths`): clamp the negative case at the source so both `total_seq_lens` and the append offset `past_seq_lens` stay non-negative for all downstream consumers. - `group_query_attention_qkv.cuh` (`UnpackRoPEAppend`): make the KV-cache store bound two-sided (`cache_s >= 0 && cache_s < max_seqlen`), mirroring the existing position-index guard a few lines above. This also covers the fast-decode path, where `past_seq_lens` points directly at the raw input and bypasses `GetSequenceLengths`. - Added `NegativeSeqlensK_CacheAppend_NoOOB_CUDA` regression test exercising the KV-cache append path with an out-of-range `seqlens_k` (CUDA-guarded; skips when CUDA EP is unavailable). ### Notes - The two-sided guard matches the pattern introduced for the rotary position index in #27597. - CPU is unaffected (already validated host-side); WebGPU relies on the CPU-validated `total_sequence_length`. The CUDA implementation is shared with ROCm via hipify. - The regression is a device-memory write best observed under `compute-sanitizer`; the test asserts the run completes with finite outputs. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cuda/bert/group_query_attention_impl.cu | 9 +- .../cuda/bert/group_query_attention_qkv.cuh | 4 +- .../group_query_attention_op_test.cc | 103 ++++++++++++++++++ 3 files changed, 113 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 6a2d95f089f2b..aa84ee05c2bd1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -641,13 +641,18 @@ __global__ void GetSequenceLengths(const int* total_seq_lens_minus_one, const bool is_first_prompt) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i < batch_size) { - const int total_len = total_seq_lens_minus_one[i] + 1; + // total_seq_lens_minus_one is the seqlens_k input and is not range-checked on the device. + // Clamp the negative case at the source so the derived lengths below stay non-negative and + // cannot flow as negative offsets into KV-cache or attention index computations. + const int seqlens_k = total_seq_lens_minus_one[i]; + const int total_len = (seqlens_k > 0 ? seqlens_k : 0) + 1; total_seq_lens[i] = total_len; if (is_first_prompt) { past_seq_lens[i] = 0; padded_seq_lens[i] = sequence_length; } else { - past_seq_lens[i] = total_len - sequence_length; + const int past_len = total_len - sequence_length; + past_seq_lens[i] = past_len > 0 ? past_len : 0; padded_seq_lens[i] = 0; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 0c62aef11d53a..1a5dc6e0fca6b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -231,7 +231,9 @@ __global__ void UnpackRoPEAppend( } else { // Store K or V into the KV cache at index (past_seqlen + s) const int cache_s = past_seq_lens[b] + s; - if (cache_s < max_seqlen) { + // Two-sided bound: the lower check mirrors the position guard above and prevents a + // negative offset from being sign-extended into the cache index arithmetic below. + if (cache_s >= 0 && cache_s < max_seqlen) { void* cache_ptr = (head_type == KEY) ? k_cache : v_cache; if (cache_ptr != nullptr) { int64_t cache_idx; diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 645564d01abc0..c6f6e0affc163 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -1569,6 +1569,109 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_Prompt_CUDA) { ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Rotary_Prompt_CUDA_vs_CPU"); } +// CUDA: out-of-range (negative) seqlens_k must not drive an out-of-bounds KV-cache write. +// On the CUDA EP seqlens_k is device-resident, so the host-side range check in the operator is +// skipped and the derived append offset is clamped on the device instead. With sequence_length > 1 +// the non-fast-decode path is taken, exercising both the derived-length clamp and the cache-store +// bound. The run must complete and yield finite outputs. This is a memory-safety regression that is +// most precisely observed under compute-sanitizer, where the pre-clamp code reported an invalid +// device write at this site. +TEST(GroupQueryAttentionTest, NegativeSeqlensK_CacheAppend_NoOOB_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + + constexpr int batch_size = 1; + constexpr int sequence_length = 2; // > 1 forces the non-fast-decode path + constexpr int past_seq_len = 4; + constexpr int num_heads = 2; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; // must be a multiple of 16 for rotary + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int total_sequence_length = past_seq_len + sequence_length; + constexpr int present_seq_len = total_sequence_length; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + std::vector query_data(batch_size * sequence_length * hidden_size); + std::vector key_data(batch_size * sequence_length * kv_hidden_size); + std::vector value_data(batch_size * sequence_length * kv_hidden_size); + std::vector past_key_data(batch_size * kv_num_heads * past_seq_len * head_size); + std::vector past_value_data(batch_size * kv_num_heads * past_seq_len * head_size); + for (size_t i = 0; i < query_data.size(); ++i) query_data[i] = 0.05f * static_cast(i % 7 + 1); + for (size_t i = 0; i < key_data.size(); ++i) key_data[i] = 0.04f * static_cast(i % 5 + 1); + for (size_t i = 0; i < value_data.size(); ++i) value_data[i] = 0.03f * static_cast(i % 3 + 1); + for (size_t i = 0; i < past_key_data.size(); ++i) past_key_data[i] = 0.02f * static_cast(i % 11 + 1); + for (size_t i = 0; i < past_value_data.size(); ++i) past_value_data[i] = 0.01f * static_cast(i % 13 + 1); + + tester.AddInput("query", {batch_size, sequence_length, hidden_size}, ToFloat16(query_data)); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, ToFloat16(key_data)); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, ToFloat16(value_data)); + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, ToFloat16(past_key_data)); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, + ToFloat16(past_value_data)); + + // seqlens_k is negative, so the derived past length, (max(seqlens_k, 0) + 1) - sequence_length, is + // negative (here 0 + 1 - 2 = -1). The device-side derivation must neutralize this so the cache append + // for the new tokens stays within the present buffer instead of indexing before its start. + tester.AddInput("seqlens_k", {batch_size}, {-1}); + // Marked as an initializer so shape inference can read the value at graph-build time and size + // present_kv to max(past_seq_len, total_sequence_length), matching the declared present outputs below. + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true); + + const int max_seq_len = total_sequence_length + 8; + const int half_rotary = head_size / 2; + std::vector cos_cache(max_seq_len * half_rotary); + std::vector sin_cache(max_seq_len * half_rotary); + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int d = 0; d < half_rotary; ++d) { + const float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / static_cast(head_size)); + cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq); + sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq); + } + } + tester.AddInput("cos_cache", {max_seq_len, half_rotary}, ToFloat16(cos_cache)); + tester.AddInput("sin_cache", {max_seq_len, half_rotary}, ToFloat16(sin_cache)); + + // Valid position_ids so the rotary index path is well-formed and only the cache-store bound is stressed. + std::vector position_ids(batch_size * sequence_length); + for (int s = 0; s < sequence_length; ++s) { + position_ids[s] = static_cast(past_seq_len + s); + } + tester.AddInput("position_ids", {batch_size, sequence_length}, position_ids); + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, MLFloat16(0.0f))); + const int present_size = batch_size * kv_num_heads * present_seq_len * head_size; + tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(present_size, MLFloat16(0.0f))); + tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(present_size, MLFloat16(0.0f))); + + // The malformed seqlens_k drives the derived past length negative, which is the condition under test. + // That leaves the KV length under-specified for the query, so the attention is degenerate and its + // outputs may be non-finite; this is expected and intentionally not asserted. The regression point is + // that the cache append and attention complete without indexing outside their buffers (which a + // sanitizer build would otherwise flag), so only the output shape is verified. + tester.SetOutputTolerance(1e6f); + tester.SetCustomOutputVerifier([](const std::vector& fetches, + const std::string& /*provider*/) { + ASSERT_FALSE(fetches.empty()); + ASSERT_TRUE(fetches[0].IsTensor()); + EXPECT_EQ(fetches[0].Get().Shape().Size(), static_cast(output_size)); + }); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // --------------------------------------------------------------------------- // Quantized KV cache tests for CPU GroupQueryAttention // --------------------------------------------------------------------------- From 996cea1e829b9f577fa8ffa37045d26f053b0960 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 24 Jun 2026 15:55:42 -0700 Subject: [PATCH 03/19] Add flash attention for non-quantized CPU GroupQueryAttention (#28962) ## Summary Adds an FP32 flash attention path for the CPU `com.microsoft.GroupQueryAttention` (GQA) contrib op, mirroring the existing quantized-KV flash attention path. The new tiled, online-softmax kernel avoids materializing the full `[S, T]` attention score matrix. It is restricted to prefill / chunked-prefill (`sequence_length > 1`); single-token decode falls back to the naive path. With causal early-termination it is faster than the naive path across all measured prefill lengths while using a fraction of the memory. ## Key changes - **New MLAS kernel** `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (`MlasFlashAttentionGQA`): - Tiled QK / softmax / SV with online-softmax (running max/sum rescaling). - GQA head grouping (`num_heads % kv_num_heads == 0`), causal masking, local window, additive attention bias, and packed-QKV input. - **Causal early-termination**: during prefill, KV blocks that fall entirely in the causally masked upper triangle are skipped (`break` once `ir >= past_seqlen + q_idx + row_size_q`), avoiding the wasted QK/SV GEMMs over roughly half of the square prefill attention matrix. - Per-batch invocation for ragged / shared-buffer `seqlens_k`. - **MLAS API** `onnxruntime/core/mlas/inc/mlas.h`: new `MlasFlashAttentionGQAArgs` struct and `MlasFlashAttentionGQA` declaration. - **Dispatch** `onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h`: new `ApplyAttentionFlash` that concatenates new K/V into the FP32 present cache and invokes the kernel. The per-thread scratch buffer size is computed with `SafeInt` to guard against `size_t` overflow on large/malformed shapes before allocation. - **Wiring** `onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc`: float-only flash dispatch, active only for prefill (`sequence_length > 1`) and when `softcap == 0`, no smooth softmax, no head sink, no QK output; falls back to the naive path otherwise. The existing `ORT_GQA_DISABLE_FLASH_ATTENTION` env var disables it. - **CMake** `cmake/onnxruntime_mlas.cmake`: register the new source file. - **Docs** `docs/contrib_ops/cpu/gqa.md`: document the non-quantized flash attention path, activation conditions, causal early-termination, file list, and FP32 flash-vs-naive benchmark results. - **Benchmark** `onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py`: add an FP32 (non-quantized) mode (`--fp32`) for operator-level flash-vs-naive comparison. ### Why prefill-only (`sequence_length > 1`) Single-token decode (`sequence_length == 1`) produces only a `[1, total_sequence_length]` score row per head, so there is nothing to tile away and the extra online-softmax bookkeeping makes the flash kernel slower and noisier than naive in practice. Restricting the flash path to prefill keeps the consistent prefill win without regressing decode. Because decode is excluded, the two-phase flash-decoding kernels are unreachable and have been removed for a smaller, simpler implementation. `float16` continues to use the naive path (the kernel is float-only, matching the quantized flash constraint). ## Performance Operator-level, AMD EPYC 7763 (16 physical cores), threads=8, FP32 KV cache, `B=1, num_heads=16, kv_num_heads=8, head_size=128`. Flash is faster than naive across all measured prefill lengths (and single-threaded as well, 1.4-1.8x), confirming the gain is algorithmic - the causal early-termination removes the wasted upper-triangle work that previously made flash slower than naive at short sequences. | Prefill Seq Length | Naive (ms) | Flash (ms) | Speedup | |---:|---:|---:|---:| | 512 | 5.8-8.4 | 4.2-5.3 | 1.4-1.6x | | 1024 | 25-29 | 13-18 | 1.6-2.0x | | 2048 | 87-118 | 52-65 | 1.5-2.0x | | 4096 | 365-380 | 213-234 | 1.6-1.7x | The flash path's primary structural benefit is memory: it never allocates the full O(N x S x T) attention matrix (~1 GB at S=4096, N=16) and instead uses an O(S x Bc) per-thread tile. ## Testing - **C++ op tests**: `onnxruntime_provider_test --gtest_filter="GroupQueryAttentionTest.*"` - 38 passed (12 GPU/WebGPU skipped) with flash on (default) and with `ORT_GQA_DISABLE_FLASH_ATTENTION=1`. - **Flash vs. naive parity** (FP32): output of the flash path matches the naive path (max abs diff ~1e-7) across prefill (block-aligned and non-aligned `S`), MHA and GQA head ratios, and local window. Decode now uses the naive path on both sides (diff 0). - **Python parity** (`test_gqa_cpu.py`, flash vs. naive reference): focused FP32 sweep of 600 prompt configurations covering all head sizes (32-256), GQA ratios `(6,6)/(6,3)/(9,9)/(9,3)`, batches `1/3/5`, causal/local window, attention bias, position ids, packed QKV, and with/without KV buffer - all passed. The official `test_gqa_cpu.py` suite passes. Two correctness bugs were found and fixed via the parity sweep while developing this path: 1. Attention-bias batch stride ignored head broadcasting for `[batch, 1, S, T]` bias. 2. Query batch stride was hardcoded to `num_heads * S * H`, which is incorrect for packed-QKV input (correct stride is `(num_heads + 2 * kv_num_heads) * S * H`). --- cmake/onnxruntime_mlas.cmake | 1 + docs/contrib_ops/cpu/gqa.md | 101 +++++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 284 ++++++++++++++++ .../cpu/bert/group_query_attention.cc | 25 ++ onnxruntime/core/mlas/inc/mlas.h | 51 +++ onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 310 ++++++++++++++++++ .../transformers/benchmark_gqa_cpu_flash.py | 197 ++++++++--- 7 files changed, 911 insertions(+), 58 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/flashattn_gqa.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b8ab7142b6b35..d55a4d49fb455 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -56,6 +56,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/flashattn_qkv.cpp + ${MLAS_SRC_DIR}/flashattn_gqa.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp ${MLAS_SRC_DIR}/layernorm.cpp diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index e5a211c9fd11a..8b81fdba8f1a6 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -17,7 +17,12 @@ Quantized KV-cache GEMM helpers are implemented in MLAS: - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp` -- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel) +- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (quantized-KV flash attention tiled kernel) + +The non-quantized flash attention tiled kernel is implemented in MLAS: + +- `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (FP32-KV flash attention tiled kernel) +- `onnxruntime/core/mlas/inc/mlas.h` (`MlasFlashAttentionGQA` declaration and `MlasFlashAttentionGQAArgs`) The operator schema itself is defined in: @@ -48,12 +53,14 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages: The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs. -The quantized path has two execution strategies: +Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, and online-softmax structure. The quantized path additionally provides a two-phase flash-decoding strategy for single-token decode; the non-quantized FP32 path is limited to prefill (`sequence_length > 1`) and uses the naive path for decode. + +The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). ## Supported Cache Modes @@ -144,9 +151,9 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with: As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly. -## Flash Attention Path +## Quantized Flash Attention Path -The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. +The quantized flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. ### Algorithm @@ -204,6 +211,58 @@ The partials buffer is allocated alongside the per-thread scratch in a single al - Per-thread scratch: `scores[Bc]` (one float per KV block element) - Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk) +## Non-Quantized Flash Attention Path + +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, and masking structure. Unlike the quantized path, it is limited to prefill / chunked-prefill (`sequence_length > 1`); single-token decode (`sequence_length == 1`) uses the naive path, which is why there is no flash-decoding variant here. + +### Differences from the Quantized Path + +- **Cache element type**: The present K/V cache is FP32, laid out as BNSH (`[batch, kv_num_heads, seqlen_present, head_size]`). There is no quantize-on-write or dequantize-on-read step. +- **QK GEMM**: Uses the single-threaded SGEMM primitive `MlasSgemmOperation(CblasNoTrans, CblasTrans, ...)` on an FP32 K block instead of `MlasQKGemm`. +- **SV accumulate**: Uses `MlasSgemmOperation(CblasNoTrans, CblasNoTrans, ..., beta)` with `beta = 0` for the first KV block and `beta = 1` afterwards (accumulate) instead of `MlasSVGemm`. +- **Cache concat**: New K/V tokens are appended into the FP32 present cache with `ConcatStateChunkGQA` before the tiled loop runs. + +### Algorithm + +For each (batch, head, q_block) tile: + +1. **QK GEMM** — `MlasSgemmOperation` of the query tile against a block slice of the FP32 K cache (Bc rows at a time) +1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores +2. **Causal + local window masking** — Set masked positions to −∞ before softmax +3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)` +4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile +5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed + +#### Causal early-termination + +During prefill, every KV block whose start index is at or beyond the largest global query +position in the current q_block is fully causally masked and contributes nothing. The kernel +computes a per-q_block bound +`kv_causal_limit = past_seqlen + q_idx + row_size_q` and breaks out of the KV loop once +`ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle +QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the +main reason the FP32 flash path is faster than naive even at short sequence lengths +(see the benchmark results below). + +### Activation Conditions + +The non-quantized flash path is selected when ALL of the following hold: + +- The kernel specialization is `float` (FP16 uses the naive path) +- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) +- `sequence_length > 1` (prefill / chunked-prefill; single-token decode uses the naive path) +- No softcap +- No smooth softmax +- No head sink +- No output QK capture +- `present_key` and `present_value` are provided + +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Sizes and Threading + +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. Because this path is prefill-only, it does not include the quantized path's two-phase flash-decoding strategy for single-token decode. + ## MLAS Dispatch Paths MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table. @@ -428,7 +487,39 @@ Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle | 4096 (N=32) | +2131 | +87 | 24.5x | **Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences. +### Non-Quantized (FP32) Flash Attention vs Naive benchmark results + +Measured on an AMD EPYC 7763 (32 logical / 16 physical cores), threads=8, FP32 KV cache, +`B=1, num_heads=16, kv_num_heads=8, head_size=128`. Operator-level, measured with: +```bash +python onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py \ + --fp32 --prompt_only --warmup 10 --repeats 30 +``` + +#### Latency — Prefill (S = T, prompt phase) + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 512 | 5.8\u20138.4 | 4.2\u20135.3 | 1.4\u20131.6x | +| 1024 | 25\u201329 | 13\u201318 | 1.6\u20132.0x | +| 2048 | 87\u2013118 | 52\u201365 | 1.5\u20132.0x | +| 4096 | 365\u2013380 | 213\u2013234 | 1.6\u20131.7x | + +The FP32 flash path is faster than naive across all measured prefill lengths. With the causal +early-termination described above, roughly half of the QK/SV work (the causally masked +upper triangle of the square prefill attention matrix) is skipped entirely, which more than +offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/output rescale). +The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is +algorithmic rather than purely from threading. + +#### Decode (S = 1, token generation) + +Single-token decode (`sequence_length == 1`) is **not** handled by the FP32 flash path; it falls +back to the naive path. Decode produces only a `[1, total_sequence_length]` score row per head, +so there is nothing to tile away, and the extra online-softmax bookkeeping made the flash kernel +slower and noisier in practice. Restricting the flash path to prefill (`sequence_length > 1`) keeps +the consistent prefill win without regressing decode. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 59313cf527c91..413483756ed5c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -909,6 +909,290 @@ class GQAAttentionBase { return Status::OK(); } + // Non-quantized flash attention path. Only supports T = float. + // Concatenates new K/V into the FP32 present cache, then runs the tiled + // online-softmax kernel MlasFlashAttentionGQA (QK^T + softmax + S*V fused). + Status ApplyAttentionFlash( + const float* Q, // Q data [B, N, S, H] BNSH + const float* K, // K data [B, N_kv, L, H] or nullptr for packed_qkv + const float* V, // V data [B, N_kv, L, H] or nullptr for packed_qkv + const Tensor* attention_bias, // additive bias [B|1, N|1, S, T] or nullptr + const Tensor* past_key, // past K (float) + const Tensor* past_value, // past V (float) + Tensor* output, // output [B, S, N*H] float + Tensor* present_key, // present K (float) + Tensor* present_value, // present V (float) + const Tensor* seqlens_k, + GroupQueryAttentionParameters& parameters, + AllocatorPtr allocator, + OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.total_sequence_length; + + if (kv_sequence_length == 0) { + ORT_ENFORCE(parameters.total_sequence_length <= seqlen_past_kv_cache, + "total_seqlen (", parameters.total_sequence_length, ") exceeds past buffer size (", + seqlen_past_kv_cache, ") in shared KV mode"); + } + + ORT_RETURN_IF(present_key == nullptr || present_value == nullptr, + "present_key and present_value must be provided for flash attention"); + + const float* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + float* present_key_data = present_key->MutableData(); + const float* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + float* present_value_data = present_value->MutableData(); + + bool past_present_share_buffer = (past_key_data == present_key_data) && + (past_value_data == present_value_data); + + const int32_t* seqlens_k_data = seqlens_k->Data(); + + // Attention bias setup + const float* attention_bias_data = nullptr; + int attention_bias_seqlen_stride = 0; + bool attention_bias_broadcast_batch = true; + bool attention_bias_broadcast_head = true; + if (attention_bias != nullptr) { + attention_bias_data = attention_bias->Data(); + auto bias_shape = attention_bias->Shape().GetDims(); + attention_bias_seqlen_stride = static_cast(bias_shape[3]); + attention_bias_broadcast_batch = (bias_shape[0] == 1); + attention_bias_broadcast_head = (bias_shape[1] == 1); + } + + // K/V base pointers (FP32, new tokens) + const float* k_base = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + const float* v_base = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const size_t kv_input_chunk_length = SafeInt(kv_sequence_length) * head_size; + const size_t past_buff_chunk_length = SafeInt(seqlen_past_kv_cache) * head_size; + const size_t present_buff_chunk_length = SafeInt(seqlen_present_kv_cache) * head_size; + + // ---- Phase 1: Concat new K/V into present cache ---- + // We must do this first so the flash attention kernel can read the full present cache. + if (present_key_data && !past_present_share_buffer) { + memset(present_key_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + memset(present_value_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + } + + // Concat K and V caches (parallelize over batch * kv_num_heads) + { + const size_t concat_loop_len = batch_size * kv_num_heads_; + TensorOpCost concat_cost; + concat_cost.compute_cycles = static_cast(kv_sequence_length * head_size); + concat_cost.bytes_loaded = static_cast((past_buff_chunk_length + kv_input_chunk_length) * sizeof(float)); + concat_cost.bytes_stored = static_cast(present_buff_chunk_length * sizeof(float)); + + ThreadPool::TryParallelFor(tp, concat_loop_len, concat_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t kv_idx = begin; kv_idx != end; ++kv_idx) { + const size_t batch_index = kv_idx / kv_num_heads_; + const size_t kv_head_index = kv_idx % kv_num_heads_; + const size_t total_seqlen = SafeInt(seqlens_k_data[batch_index]) + 1; + + size_t past_seqlen; + if (past_key == nullptr) { + past_seqlen = 0; + } else if (kv_sequence_length == 0) { + past_seqlen = total_seqlen; + } else if (is_prompt) { + past_seqlen = 0; + } else { + past_seqlen = total_seqlen - sequence_length; + } + const size_t past_chunk_length = past_seqlen * head_size; + + // Concat K + const float* k_new; + if (packed_qkv) { + k_new = k_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + k_new = k_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_key_data, k_new, present_key_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + + // Concat V + const float* v_new; + if (packed_qkv) { + v_new = v_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + v_new = v_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_value_data, v_new, present_value_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + } + }); + } + + // ---- Phase 2: Flash Attention with FP32 KV cache ---- + // Compute L2-aware block sizes (same formula as MHA flash attention). + const auto& env = Env::Default(); + int l2_cache_size = env.GetL2CacheSize(); + + int kv_block_size = l2_cache_size / (static_cast(sizeof(float)) * 4 * (head_size + head_size)); + kv_block_size = std::max(kv_block_size, 1); + int q_block_size = std::min(kv_block_size, 2 * head_size); + + // The flash kernel uses a single (past_seqlen, total_seqlen) pair for all batch items. + // When batch items have different seqlens_k (ragged), fall back to per-batch invocation + // so each batch item gets its own correct causal offset. + int max_total_seqlen = 0; + int min_total_seqlen = std::numeric_limits::max(); + int common_past_seqlen = 0; + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + max_total_seqlen = std::max(max_total_seqlen, total_sl); + min_total_seqlen = std::min(min_total_seqlen, total_sl); + } + const bool ragged_seqlens = (max_total_seqlen != min_total_seqlen); + + if (ragged_seqlens) { + common_past_seqlen = -1; // sentinel: per-batch + } else if (past_key == nullptr || is_prompt) { + common_past_seqlen = 0; + } else if (kv_sequence_length == 0) { + // Shared buffer mode: each batch item has its own past_seqlen. + common_past_seqlen = -1; // sentinel: per-batch + } else { + common_past_seqlen = max_total_seqlen - sequence_length; + } + + // Cap block sizes + kv_block_size = std::min(kv_block_size, max_total_seqlen); + q_block_size = std::min(q_block_size, sequence_length); + + int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + thread_count = std::max(thread_count, 1); + + // Per-thread scratch: l + m + scores[q_block_size * kv_block_size] + temp_output[q_block_size * head_size] + const size_t buffer_size_per_thread = + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output + sizeof(float); + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count; + auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); + BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + + const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + // If all batch items share the same past_seqlen, use the unified flash kernel. + // Otherwise, fall back to per-batch invocation. + if (common_past_seqlen >= 0) { + MlasFlashAttentionGQAArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = max_total_seqlen; + args.head_size = head_size; + args.past_seqlen = common_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.q_batch_stride = packed_qkv + ? static_cast(packed_batch_stride) + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.k_cache = present_key_data; + args.v_cache = present_value_data; + args.output = output->MutableData(); + args.attention_bias = attention_bias_data; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + + MlasFlashAttentionGQA(&args, tp); + } else { + // Per-batch handling for variable past_seqlen (shared KV buffer mode or ragged seqlens) + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + int batch_past_seqlen = (past_key == nullptr || is_prompt) + ? 0 + : std::max(0, total_sl - sequence_length); + + MlasFlashAttentionGQAArgs args; + args.batch_size = 1; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = total_sl; + args.head_size = head_size; + args.past_seqlen = batch_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = std::min(kv_block_size, total_sl); + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + + // Offset Q and output for this batch + const ptrdiff_t q_batch_stride_elems = packed_batch_stride > 0 + ? packed_batch_stride + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.query = Q + static_cast(b) * static_cast(q_batch_stride_elems); + args.q_batch_stride = SafeInt(num_heads_) * sequence_length * head_size; + args.k_cache = present_key_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.v_cache = present_value_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.output = output->MutableData() + + static_cast(b) * sequence_length * hidden_size; + + // Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside). + // Bias shape is [batch|1, num_heads|1, S, T]; the batch stride uses the actual head + // extent (1 when the head dim is broadcast). + const float* batch_bias = attention_bias_data; + if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) { + const size_t bias_head_extent = attention_bias_broadcast_head ? 1 : static_cast(num_heads_); + batch_bias += static_cast(b) * bias_head_extent * sequence_length * attention_bias_seqlen_stride; + } + args.attention_bias = batch_bias; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = true; // batch offset handled above + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + + MlasFlashAttentionGQA(&args, tp); + } + } + + return Status::OK(); + } + private: // Helper function to compute the attention probs. It does 2 things: // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index e36bdb2de263a..692a5759d1adc 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -343,6 +343,31 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V const T* k_data = packed_qkv ? nullptr : k_rotary; const T* v_data = packed_qkv ? nullptr : V.Get().Data(); + + // Non-quantized flash attention path (float only). Uses the tiled online-softmax + // kernel to avoid materializing the full attention score matrix. Falls back to the + // naive path when an unsupported feature is requested (softcap, smooth softmax, + // head sink, or QK output). + if constexpr (std::is_same_v) { + // Restrict the flash path to prefill / chunked-prefill (query length > 1). Single-token + // decode (sequence_length == 1) has no flash benefit: the naive score matrix is only + // [1, total_sequence_length] per head, so there is nothing to tile away, and the extra + // online-softmax bookkeeping makes it slower in practice. + const bool use_flash = !disable_gqa_flash_ && + parameters.sequence_length > 1 && + softcap_ == 0.0f && + !use_smooth_softmax_ && + head_sink_data == nullptr && + output_qk == nullptr && + present_k != nullptr && present_v != nullptr; + if (use_flash) { + return ApplyAttentionFlash(q_rotary, k_data, v_data, + attention_bias, past_key, past_value, + output, present_k, present_v, seqlens_k, + parameters, allocator, context); + } + } + return ApplyAttention(q_rotary, k_data, v_data, head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, parameters, allocator, context); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 811bad15ebbab..2410fcc83e7cd 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2297,6 +2297,57 @@ MlasFlashAttention( MLAS_THREADPOOL* ThreadPool ); +// +// Flash Attention for non-quantized (FP32) GroupQueryAttention KV cache. +// +// Adapts the online-softmax tiled algorithm to operate on an FP32 present +// K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). +// Supports GQA head grouping, causal masking, local window attention, and +// additive attention bias. Intended for prefill / chunked-prefill +// (sequence_length > 1). +// +struct MlasFlashAttentionGQAArgs { + int batch_size; + int num_heads; // number of query heads + int kv_num_heads; // number of key/value heads (num_heads % kv_num_heads == 0) + int sequence_length; // number of new query tokens (S) + int total_seqlen; // total tokens (past + new) for this invocation (T) + int head_size; // per-head size (H) + int past_seqlen; // causal offset (number of cached tokens before the new ones) + int local_window_size; // -1 disables local window masking + int seqlen_present_kv; // sequence dimension of the present K/V buffer + int q_block_size; // query tile size (Br) + int kv_block_size; // key/value tile size (Bc) + float scale; // QK scaling factor + int thread_count; // number of partitions / threads + float* buffer; // per-thread scratch + size_t buffer_size_per_thread; + + const float* query; // [batch, num_heads, sequence_length, head_size] BNSH + size_t q_batch_stride; // element stride between consecutive batches in `query` + // (num_heads*S*H for unpacked, (num_heads+2*kv_num_heads)*S*H for packed QKV) + const float* k_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + const float* v_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + float* output; // [batch, sequence_length, num_heads, head_size] BSNH + + const float* attention_bias; // [batch|1, num_heads|1, S, T] additive bias, or nullptr + int attention_bias_seqlen_stride; + bool attention_bias_broadcast_batch; + bool attention_bias_broadcast_head; +}; + +/** + * @brief FP32 Flash Attention for GroupQueryAttention with an FP32 KV cache. + * @param args Arguments + * @param ThreadPool Thread pool + */ +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +); + /** * @brief Enumeration of supported GELU algorithm variants. * diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp new file mode 100644 index 0000000000000..4d0ff65733a44 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -0,0 +1,310 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + flashattn_gqa.cpp + +Abstract: + + Flash Attention kernel for the non-quantized (FP32) GroupQueryAttention + KV cache. + + Adapts the online-softmax tiled algorithm from flashattn.cpp to operate on + an FP32 present K/V cache laid out as BNSH + ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head + grouping (num_heads % kv_num_heads == 0), causal masking, local window + attention, and additive attention bias. Intended for prefill / + chunked-prefill (sequence_length > 1). + + QK^T and S*V use the single-threaded SGEMM primitive MlasSgemmOperation; + the outer parallelism is provided by MlasExecuteThreaded. + +--*/ + +#include +#include +#include +#include + +#include "mlasi.h" + +void +MlasFlashAttentionGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t q_block_size = static_cast(args->q_block_size); + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t sequence_length = static_cast(args->sequence_length); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: one per (batch, head, q_block) + const ptrdiff_t q_chunk_count = (sequence_length + q_block_size - 1) / q_block_size; + const ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + // Per-thread buffer layout: + // l[q_block_size] - running sum for online softmax + // m[q_block_size] - running max for online softmax + // scores[q_block_size * kv_block_size] - QK scores (S) + // temp_output[q_block_size * head_size] - accumulated output + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_ptr); + float* m = l + q_block_size; + float* scores = m + q_block_size; + float* temp_output = scores + q_block_size * kv_block_size; + + // Initialize running state + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + l[t] = 0.0f; + } + memset(temp_output, 0, static_cast(q_block_size * head_size) * sizeof(float)); + + const size_t row_size_q = static_cast(std::min(q_block_size, sequence_length - q_idx)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers. Layout: [batch, kv_num_heads, seqlen_present, head_size] + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, seq, head_size]. The batch stride is + // supplied separately (args->q_batch_stride) so the kernel works with both the + // standard BNSH layout and packed-QKV input where Q/K/V are interleaved per batch. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(sequence_length) * static_cast(head_size) + + static_cast(q_idx) * static_cast(head_size); + + // Causal early-termination bound: the largest global query position in this + // q_block is (past_seqlen + q_idx + row_size_q - 1), so it can attend to KV + // positions up to that index inclusive. Any KV block starting at or beyond + // (past_seqlen + q_idx + row_size_q) is fully causally masked for every row in + // the block, so it contributes nothing and can be skipped. This avoids the + // wasted QK/SV GEMMs over the causal upper triangle during prefill. + const ptrdiff_t kv_causal_limit = + past_seqlen + q_idx + static_cast(row_size_q); + + // Iterate over KV blocks + for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + if (ir >= kv_causal_limit) { + break; + } + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Step 1: QK^T GEMM with FP32 K block + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasTrans, + row_size_q, // M + row_size_kv, // N + static_cast(head_size), // K + scale, // alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (FP32 K block) + static_cast(head_size), // ldb + 0.0f, // beta + scores, // C (output scores) + row_size_kv // ldc + ); + + // Step 1b: Apply attention bias (additive) if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = + static_cast(sequence_length) * bias_seqlen_stride; + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch + // stride uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + // Add bias tile: bias[q_idx + irow, ir + jcol] + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + const float* bias_row = args->attention_bias + bias_offset + + (q_idx + irow) * bias_seqlen_stride + ir; + float* s_row = scores + irow * static_cast(row_size_kv); + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + s_row[jcol] += bias_row[jcol]; + } + } + } + + // Step 2: Apply causal mask and Step 3: Online softmax update + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float* p = scores + irow * static_cast(row_size_kv); + const ptrdiff_t global_q_pos = past_seqlen + q_idx + irow; + const ptrdiff_t causal_limit = global_q_pos + 1; // can attend to positions [0, causal_limit) + + // Apply causal masking + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + p[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + p[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Online softmax: find row max, update running max +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv); +#endif + + // If the entire row is masked (all scores are -inf), zero the scores + // so the S*V GEMM contributes nothing and skip the softmax state update. + if (rowmax == std::numeric_limits::lowest()) { + memset(p, 0, row_size_kv * sizeof(float)); + continue; + } + + float m_old = m[irow]; + m[irow] = std::max(m[irow], rowmax); + float m_diff = m_old - m[irow]; // <= 0 + + // Compute exp(score - m_new) for each element + float negmax = -m[irow]; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#endif + + // Rescale previous state + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + // Rescale accumulated output + float* out_row = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + out_row[icol] *= exp_diff; + } + } else { + l[irow] = rowsum; + } + } + + // Step 4: Accumulate O += S_exp * V_block + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasNoTrans, + row_size_q, // M + static_cast(head_size), // N + row_size_kv, // K + 1.0f, // alpha + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (FP32 V block) + static_cast(head_size), // ldb + ir == 0 ? 0.0f : 1.0f, // beta (accumulate after first block) + temp_output, // C (accumulated output) + static_cast(head_size) // ldc + ); + } + + // Final: normalize output by l (softmax denominator) + // Output layout: [batch, sequence_length, num_heads, head_size] + float* output_row = args->output + + (static_cast(batch_idx) * static_cast(sequence_length) + + static_cast(q_idx)) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + const ptrdiff_t output_row_stride = num_heads * head_size; + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float inv_l = (l[irow] > 0.0f) ? (1.0f / l[irow]) : 0.0f; + float* src = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + output_row[icol] = src[icol] * inv_l; + } + output_row += output_row_stride; + } + } +} + +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); +} diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py index 77ac08cf50d6c..7dbcb16a75973 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py @@ -106,6 +106,70 @@ def create_quantized_gqa_graph( return model.SerializeToString() +def create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with a non-quantized FP32 KV cache.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + ] + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + domain="com.microsoft", + ) + + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "past_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "present_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + ] + + graph = helper.make_graph([node], "BenchGQA", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + def benchmark_gqa( batch_size, seq_len, @@ -117,6 +181,7 @@ def benchmark_gqa( past_seq_len=0, warmup=5, repeats=20, + non_quantized=False, ): """Benchmark a single GQA configuration. Returns elapsed time in ms.""" hidden_size = num_heads * head_size @@ -126,54 +191,76 @@ def benchmark_gqa( total_seqlen = past_seq_len + seq_len buffer_seq_len = total_seqlen - onnx_model_str = create_quantized_gqa_graph( - batch_size, - seq_len, - num_heads, - kv_num_heads, - head_size, - quant_type, - bit_width, - buffer_seq_len=buffer_seq_len, - ) - sess_options = SessionOptions() sess_options.intra_op_num_threads = 8 - sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - # Generate inputs np.random.seed(42) query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) key = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) value = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) - - cache_dtype = np.uint8 if bit_width == 4 else np.int8 - past_k = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - past_v = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - seqlens_k = np.array([total_seqlen - 1] * batch_size, dtype=np.int32) total_seq = np.array([total_seqlen], dtype=np.int32) - per_channel = quant_type == "PER_CHANNEL" - scale_size = kv_num_heads * head_size if per_channel else 1 - k_scale = np.full(scale_size, 0.01, dtype=np.float32) - v_scale = np.full(scale_size, 0.01, dtype=np.float32) - - feeds = { - "query": query, - "key": key, - "value": value, - "past_key": past_k, - "past_value": past_v, - "seqlens_k": seqlens_k, - "total_sequence_length": total_seq, - "k_scale": k_scale, - "v_scale": v_scale, - } + if non_quantized: + onnx_model_str = create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + past_k = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + past_v = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + } + else: + onnx_model_str = create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + cache_dtype = np.uint8 if bit_width == 4 else np.int8 + past_k = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + past_v = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + + per_channel = quant_type == "PER_CHANNEL" + scale_size = kv_num_heads * head_size if per_channel else 1 + k_scale = np.full(scale_size, 0.01, dtype=np.float32) + v_scale = np.full(scale_size, 0.01, dtype=np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "k_scale": k_scale, + "v_scale": v_scale, + } # Warmup for _ in range(warmup): @@ -242,20 +329,21 @@ def run_benchmarks(args): "past_seq_len": 2048, } ) - # INT4 prefill - configs.append( - { - "label": "Prefill S=2048 INT4", - "batch_size": 1, - "seq_len": 2048, - "num_heads": 16, - "kv_num_heads": 8, - "head_size": 128, - "quant_type": "PER_TENSOR", - "bit_width": 4, - "past_seq_len": 0, - } - ) + # INT4 prefill (quantized mode only) + if not args.fp32: + configs.append( + { + "label": "Prefill S=2048 INT4", + "batch_size": 1, + "seq_len": 2048, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 4, + "past_seq_len": 0, + } + ) warmup = args.warmup repeats = args.repeats @@ -263,13 +351,15 @@ def run_benchmarks(args): # Save and restore env var to avoid side effects on callers saved_env = os.environ.get("ORT_GQA_DISABLE_FLASH_ATTENTION") + kv_mode = "FP32 (non-quantized)" if args.fp32 else "INT8/INT4 quantized" print("\nBenchmark: CPU GroupQueryAttention — Flash vs Naive") - print(f"Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") + print(f"KV cache: {kv_mode}, Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") print(f"{'Config':<25} {'Naive (ms)':>12} {'Flash (ms)':>12} {'Speedup':>10}") print("-" * 62) for cfg in configs: label = cfg.pop("label") + cfg["non_quantized"] = args.fp32 # Flash path (default) os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) @@ -296,5 +386,6 @@ def run_benchmarks(args): parser.add_argument("--repeats", type=int, default=20, help="Measurement iterations") parser.add_argument("--decode_only", action="store_true", help="Only run decode benchmarks") parser.add_argument("--prompt_only", action="store_true", help="Only run prompt benchmarks") + parser.add_argument("--fp32", action="store_true", help="Use non-quantized FP32 KV cache instead of quantized") args = parser.parse_args() run_benchmarks(args) From 4ed5a4addef306170c9b39c4f094aff0cb16ba7c Mon Sep 17 00:00:00 2001 From: Vineeth Chelur Date: Thu, 25 Jun 2026 05:34:30 +0530 Subject: [PATCH 04/19] Update cpuinfo to include cpuinfo_deinitialize(), fix QNN ETW logging, GQA underflow, and ep_weight_sharing_ctx_gen build (#28245) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This PR contains three commits: **Commit 1: Miscellaneous fixes** - Downgrade QNN ETW profiling mismatch logs from ERROR to VERBOSE to reduce excessive telemetry noise (~1 billion events/week across Windows devices) - Add bounds checking in GQA attention to prevent `size_t` underflow when `seqlens_k` contains invalid data (fixes #27170) - Build `ep_weight_sharing_ctx_gen` for TensorRT, OpenVINO, and VitisAI in addition to QNN **Commit 2: Bump cpuinfo and add `cpuinfo_deinitialize()` integration** Applications that dynamically load and unload the onnxruntime DLL leave orphaned heap allocations from cpuinfo when the library is unloaded mid-process. These are flagged as memory leaks by App Verifier, Valgrind, AddressSanitizer, and LeakSanitizer. This commit bumps `pytorch/cpuinfo` to a version that implements `cpuinfo_deinitialize()` ([pytorch/cpuinfo#387](https://github.com/pytorch/cpuinfo/pull/387)) and adds ORT integration: - `CPUIDInfo::ShutDown()` calls `cpuinfo_deinitialize()` to free heap-allocated globals - `DllMain` calls `ShutdownCpuInfo()` on `DLL_PROCESS_DETACH` - In memleak-check builds, shutdown also runs during process termination - `InstanceCreated` atomic guard prevents singleton creation during DLL unload **Commit 3: Update to official cpuinfo merged fix** After [pytorch/cpuinfo#387](https://github.com/pytorch/cpuinfo/pull/387) merged upstream, updated the dependency to point to `pytorch/cpuinfo` main (`4628dc06`). Patch changes: - **Removed** `win_arm_fp16_detection_fallback.patch` — upstreamed via [pytorch/cpuinfo#348](https://github.com/pytorch/cpuinfo/pull/348) - **Updated** `patch_vcpkg_arm64ec_support.patch` — regenerated for new cpuinfo; still needed ([pytorch/cpuinfo#324](https://github.com/pytorch/cpuinfo/pull/324) not yet merged) - **Updated** `patch_cpuinfo_h_for_arm64ec.patch` — retained, not yet upstream - **Regenerated** `fix_missing_sysfs_fallback.patch` — updated context lines for new cpuinfo code ### Motivation and Context - https://github.com/pytorch/cpuinfo/issues/150 - https://github.com/microsoft/onnxruntime/issues/16117 - https://github.com/microsoft/onnxruntime/issues/23762 --- cmake/deps.txt | 2 +- .../external/onnxruntime_external_deps.cmake | 4 +- cmake/onnxruntime_unittests.cmake | 4 +- .../cpuinfo/fix_missing_sysfs_fallback.patch | 58 ++++++++++++++++--- .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 4 +- .../win_arm_fp16_detection_fallback.patch | 19 ------ .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 4 +- cmake/vcpkg-ports/cpuinfo/portfile.cmake | 7 +-- .../win_arm_fp16_detection_fallback.patch | 19 ------ onnxruntime/core/common/cpuid_info.cc | 9 +++ onnxruntime/core/common/cpuid_info.h | 1 + onnxruntime/core/platform/posix/env.cc | 5 ++ .../qnn/builder/qnn_backend_manager.cc | 4 +- 13 files changed, 78 insertions(+), 62 deletions(-) delete mode 100644 cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch delete mode 100644 cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch diff --git a/cmake/deps.txt b/cmake/deps.txt index e303ccd9f8a98..d6a5b71221dc8 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -50,7 +50,7 @@ protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/downlo psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2 pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v3.0.2.zip;a064e663b4d7a337ac291d1bef7337ef4e60a1ae -pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/403d652dca4c1046e8145950b1c0997a9f748b57.zip;30b2a07fe4bae8574f89176e56274cacdd6d135b +pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/4628dc060ce4e82345dc166bbac875609db4ff69.zip;e58d4b47c16a982111c897e669ae4f1821a393d7 re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 1a1e4921a41e6..ed3b0aa8192a7 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -371,9 +371,7 @@ if (CPUINFO_SUPPORTED) PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch && # https://github.com/pytorch/cpuinfo/pull/324 - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch && - # https://github.com/pytorch/cpuinfo/pull/348 - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/win_arm_fp16_detection_fallback.patch + ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch FIND_PACKAGE_ARGS NAMES cpuinfo ) elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 23eccb22476df..c32aa7f4ae75a 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1609,8 +1609,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() - - if(onnxruntime_USE_QNN) + # Build ep_weight_sharing_ctx_gen for all supported EPs (QNN, TensorRT, OpenVINO, VitisAI) + if(onnxruntime_USE_QNN OR onnxruntime_USE_TENSORRT OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_VITISAI) #qnn ctx generator set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) set(ep_weight_sharing_ctx_gen_src_patterns diff --git a/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch b/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch index 005cd458fdd2b..47a1054e25107 100644 --- a/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch +++ b/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch @@ -1,10 +1,19 @@ diff --git a/src/linux/processors.c b/src/linux/processors.c -index 47bee76..d0c5569 100644 +index fd040a3..2ca8ec4 100644 --- a/src/linux/processors.c +++ b/src/linux/processors.c -@@ -2,0 +3 @@ +@@ -3,6 +3,7 @@ + #include + #include + #include +#include -@@ -291,0 +293,22 @@ + + #if !defined(__ANDROID__) + /* +@@ -289,6 +290,28 @@ static bool max_processor_number_parser(uint32_t processor_list_start, uint32_t + return true; + } + +static uint32_t cpuinfo_linux_get_max_processor_from_sysconf( + uint32_t max_processors_count, + const char* processor_list_name) { @@ -27,13 +36,31 @@ index 47bee76..d0c5569 100644 + return max_processor; +} + -@@ -301 +324 @@ + uint32_t cpuinfo_linux_get_max_possible_processor(uint32_t max_processors_count) { + uint32_t max_possible_processor = 0; + if (!cpuinfo_linux_parse_cpulist( +@@ -298,7 +321,7 @@ uint32_t cpuinfo_linux_get_max_possible_processor(uint32_t max_processors_count) + #else + cpuinfo_log_warning("failed to parse the list of possible processors in %s", POSSIBLE_CPULIST_FILENAME); + #endif - return UINT32_MAX; + return cpuinfo_linux_get_max_processor_from_sysconf(max_processors_count, POSSIBLE_CPULIST_FILENAME); -@@ -323 +346 @@ + } + if (max_possible_processor >= max_processors_count) { + cpuinfo_log_warning( +@@ -320,7 +343,7 @@ uint32_t cpuinfo_linux_get_max_present_processor(uint32_t max_processors_count) + #else + cpuinfo_log_warning("failed to parse the list of present processors in %s", PRESENT_CPULIST_FILENAME); + #endif - return UINT32_MAX; + return cpuinfo_linux_get_max_processor_from_sysconf(max_processors_count, PRESENT_CPULIST_FILENAME); -@@ -357,0 +381,31 @@ + } + if (max_present_processor >= max_processors_count) { + cpuinfo_log_warning( +@@ -355,6 +378,37 @@ static bool detect_processor_parser(uint32_t processor_list_start, uint32_t proc + return true; + } + +static bool cpuinfo_linux_detect_processors_from_sysconf( + uint32_t max_processors_count, + uint32_t* processor0_flags, @@ -65,7 +92,13 @@ index 47bee76..d0c5569 100644 + return true; +} + -@@ -373 +427,6 @@ + bool cpuinfo_linux_detect_possible_processors( + uint32_t max_processors_count, + uint32_t* processor0_flags, +@@ -370,7 +424,12 @@ bool cpuinfo_linux_detect_possible_processors( + return true; + } else { + cpuinfo_log_warning("failed to parse the list of possible processors in %s", POSSIBLE_CPULIST_FILENAME); - return false; + return cpuinfo_linux_detect_processors_from_sysconf( + max_processors_count, @@ -73,7 +106,13 @@ index 47bee76..d0c5569 100644 + processor_struct_size, + possible_flag, + POSSIBLE_CPULIST_FILENAME); -@@ -392 +451,6 @@ + } + } + +@@ -389,7 +448,12 @@ bool cpuinfo_linux_detect_present_processors( + return true; + } else { + cpuinfo_log_warning("failed to parse the list of present processors in %s", PRESENT_CPULIST_FILENAME); - return false; + return cpuinfo_linux_detect_processors_from_sysconf( + max_processors_count, @@ -81,3 +120,6 @@ index 47bee76..d0c5569 100644 + processor_struct_size, + present_flag, + PRESENT_CPULIST_FILENAME); + } + } + diff --git a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch index af0f039b6c2a3..18ed80f7944f8 100644 --- a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 +index 072c987..e43d6ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am @@ -7,7 +7,7 @@ index aedc983..dab589e 100644 IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") +ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID for non-VS generators (e.g. Ninja) with MSVC. + IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") + SET(CPUINFO_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") diff --git a/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch b/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch deleted file mode 100644 index 44ac0f13f5466..0000000000000 --- a/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch +++ /dev/null @@ -1,19 +0,0 @@ -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index 5c0a5f3..a07fbe4 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -249,6 +249,14 @@ static void set_cpuinfo_isa_fields(void) { - // guarantee that, but it holds in practice. - cpuinfo_isa.rdm = dotprod; - -+ // PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE may not be available in older -+ // Windows versions. If fp16arith was not detected with -+ // IsProcessorFeaturePresent(PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE), fall -+ // back to using the value of dotprod. -+ if (!cpuinfo_isa.fp16arith) { -+ cpuinfo_isa.fp16arith = dotprod; -+ } -+ - /* Windows API reports all or nothing for cryptographic instructions. */ - const bool crypto = IsProcessorFeaturePresent(PF_ARM_V8_CRYPTO_INSTRUCTIONS_AVAILABLE) != 0; - cpuinfo_isa.aes = crypto; diff --git a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch index af0f039b6c2a3..18ed80f7944f8 100644 --- a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 +index 072c987..e43d6ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am @@ -7,7 +7,7 @@ index aedc983..dab589e 100644 IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") +ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID for non-VS generators (e.g. Ninja) with MSVC. + IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") + SET(CPUINFO_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index 67bd18e61cc28..9140a233e2ccd 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -6,13 +6,12 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO pytorch/cpuinfo - REF 403d652dca4c1046e8145950b1c0997a9f748b57 - SHA512 f7cd6dc44bd1120af610cae1337ed4c0f557ba78d2de9c73fed350fa3dfe9512643a1619ae55f5a540c6316a87d641856cca27297bb8766e48f39b7b7a59da1f - HEAD_REF master + REF 4628dc060ce4e82345dc166bbac875609db4ff69 + SHA512 db7a93279f2f6daaf825fbd8552935d8ed671d276b65ad614e11f722b6a6848e663850d65180d33b554d67ef1a36aae842feb368699f90be8f21172a1af1924e + HEAD_REF main PATCHES patch_cpuinfo_h_for_arm64ec.patch patch_vcpkg_arm64ec_support.patch # https://github.com/pytorch/cpuinfo/pull/324 - win_arm_fp16_detection_fallback.patch # https://github.com/pytorch/cpuinfo/pull/348 ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch b/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch deleted file mode 100644 index 44ac0f13f5466..0000000000000 --- a/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch +++ /dev/null @@ -1,19 +0,0 @@ -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index 5c0a5f3..a07fbe4 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -249,6 +249,14 @@ static void set_cpuinfo_isa_fields(void) { - // guarantee that, but it holds in practice. - cpuinfo_isa.rdm = dotprod; - -+ // PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE may not be available in older -+ // Windows versions. If fp16arith was not detected with -+ // IsProcessorFeaturePresent(PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE), fall -+ // back to using the value of dotprod. -+ if (!cpuinfo_isa.fp16arith) { -+ cpuinfo_isa.fp16arith = dotprod; -+ } -+ - /* Windows API reports all or nothing for cryptographic instructions. */ - const bool crypto = IsProcessorFeaturePresent(PF_ARM_V8_CRYPTO_INSTRUCTIONS_AVAILABLE) != 0; - cpuinfo_isa.aes = crypto; diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index ebf3cc9f50be6..ec5c1386e8336 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -405,4 +405,13 @@ CPUIDInfo::CPUIDInfo() { #endif #endif // defined(CPUIDINFO_ARCH_RISCV64) } + +CPUIDInfo::~CPUIDInfo() { +#if defined(CPUINFO_SUPPORTED) + if (pytorch_cpuinfo_init_) { + cpuinfo_deinitialize(); + pytorch_cpuinfo_init_ = false; + } +#endif +} } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index bf502c645c9eb..6eed234332f46 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -110,6 +110,7 @@ class CPUIDInfo { static void LogEarlyWarning(std::string_view message); CPUIDInfo(); + ~CPUIDInfo(); void VendorInfoInit(); diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 0270bf9d4d79c..c34d8b3dbf696 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -657,6 +657,11 @@ class PosixEnv : public Env { } } } + ~PosixEnv() { + if (cpuinfo_available_) { + cpuinfo_deinitialize(); + } + } bool cpuinfo_available_{false}; #endif // ORT_USE_CPUINFO }; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5758ff3ad2847..f586fc8e117a6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1871,13 +1871,13 @@ Status QnnBackendManager::ExtractBackendProfilingInfo(qnn::profile::ProfilingInf // ETW disabled previously, but enabled now if (ProfilingLevel::INVALID == profiling_level_etw_ && tracelogging_provider_ep_enabled) { - LOGS(*logger_, ERROR) << "ETW disabled previously, but enabled now. Can't do the switch! Won't output any profiling."; + LOGS(*logger_, WARNING) << "ETW disabled previously, but enabled now. Can't do the switch! Won't output any profiling."; return Status::OK(); } // ETW enabled previously, but disabled now if (ProfilingLevel::INVALID != profiling_level_etw_ && !tracelogging_provider_ep_enabled) { - LOGS(*logger_, ERROR) << "ETW enabled previously, but disabled now. Can't do the switch! Won't output any profiling."; + LOGS(*logger_, WARNING) << "ETW enabled previously, but disabled now. Can't do the switch! Won't output any profiling."; return Status::OK(); } From 3b022ecb71f4756d899e06380afe1c2d2850edc1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 24 Jun 2026 17:20:56 -0700 Subject: [PATCH 05/19] [CUDA] Support user compute stream with CUDA graph in CUDA plugin EP (#29221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description The CUDA plugin EP previously rejected combining a user-provided compute stream (`user_compute_stream`) with CUDA graph capture (`enable_cuda_graph`), returning `ORT_INVALID_ARGUMENT`. This PR removes that restriction so the two options can be used together: when both are set, graph capture and replay run on the user-owned stream (the same stream the kernels are issued to), matching the bundled (non-plugin) CUDA EP behavior. Several supporting fixes make capture on a shared stream stable and Memcpy-free. ## Summary of Changes ### Allow user stream + CUDA graph | File | Change | |------|--------| | [onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc](onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc) | Remove the validation that rejected `user_compute_stream` + `enable_cuda_graph` together. | | [onnxruntime/core/providers/cuda/plugin/cuda_ep.cc](onnxruntime/core/providers/cuda/plugin/cuda_ep.cc) | `PerThreadContext` accepts an optional external graph stream. When both options are set it captures/replays on the user stream and does **not** create or destroy it (the user owns its lifetime); otherwise it owns a dedicated graph stream as before. | ### Stable, Memcpy-free CUDA graph capture | File | Change | |------|--------| | [onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h](onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h) | Route kernel scratch/workspace allocations through the EP allocator (BFC arena) instead of raw `cudaMallocAsync`/`cudaMalloc`. After warmup the arena reaches steady state, so the capture run serves scratch from already-reserved chunks and the device free-memory footprint stays stable — required for correct capture. Matches the built-in CUDA EP. | | [onnxruntime/core/providers/cuda/tensor/shape_op.cc](onnxruntime/core/providers/cuda/tensor/shape_op.cc) | Add an adapter-based `Shape` kernel under `#ifdef BUILD_CUDA_EP_AS_PLUGIN` with identical semantics to the CPU `Shape`. Registering `Shape` on the EP keeps it off the CPU EP and avoids the Memcpy nodes that would otherwise break CUDA graph capture. | | [cmake/onnxruntime_providers_cuda_plugin.cmake](cmake/onnxruntime_providers_cuda_plugin.cmake) | Stop excluding `shape_op.cc` from the plugin build so the adapter-based `Shape` kernel is compiled in. | ### Null-allocator fallback in PrePack (plugin boundary) In the plugin build the `AllocatorPtr` passed to `PrePack` can arrive null across the library boundary. Each kernel now falls back to its own default-memory allocator (`Info().GetAllocator(OrtMemTypeDefault)`), which is always valid. - [onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) - [onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc](onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc) - [onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc](onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc) ### Misc - [onnxruntime/core/framework/session_state.cc](onnxruntime/core/framework/session_state.cc) — wrap a long line (no behavior change). ## Testing - New test: [onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc](onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc) covering: 1. Session creation succeeds with both `user_compute_stream` and `enable_cuda_graph` set (regression for the removed validation). 2. Capture + replay on the user stream produce correct results. 3. Replay after an in-place input update on the user stream is correct. - Tests are gated on `ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP` and skip gracefully when no CUDA device or plugin library is available. ## Motivation and Context Users that drive ORT from their own CUDA stream (e.g. to interleave ORT inference with their own kernels) previously could not also benefit from CUDA graph capture on the plugin EP. This change brings the plugin EP to parity with the bundled CUDA EP for that workflow. ## Checklist - [x] Tests added/updated - [x] No breaking changes (relaxes a previously rejected option combination) - [ ] Documentation updated (if applicable) --- .github/workflows/linux_cuda_plugin_ci.yml | 24 ++ cmake/onnxruntime_providers_cuda_plugin.cmake | 9 +- .../arena_allocator_migration_design.md | 13 +- .../cuda_graph_for_cuda_plugin.md | 50 +++- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 15 +- include/onnxruntime/ep/adapter/allocator.h | 24 +- onnxruntime/core/framework/session_state.cc | 11 +- .../core/providers/cuda/llm/attention.cc | 7 +- .../core/providers/cuda/plugin/cuda_ep.cc | 64 ++++- .../providers/cuda/plugin/cuda_ep_factory.cc | 9 +- .../cuda/plugin/cuda_kernel_adapter.h | 100 ++++---- .../core/providers/cuda/tensor/shape_op.cc | 75 ++++++ .../cuda_plugin_user_stream_graph_test.cc | 240 ++++++++++++++++++ 13 files changed, 542 insertions(+), 99 deletions(-) create mode 100644 onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc diff --git a/.github/workflows/linux_cuda_plugin_ci.yml b/.github/workflows/linux_cuda_plugin_ci.yml index a9197b3732dd8..e88c6beff5280 100644 --- a/.github/workflows/linux_cuda_plugin_ci.yml +++ b/.github/workflows/linux_cuda_plugin_ci.yml @@ -141,3 +141,27 @@ jobs: cd /onnxruntime_src/onnxruntime/test/python/transformers python test_cuda_plugin_ep.py " + + # --- Run the CUDA plugin EP C++ GoogleTest binary --- + # onnxruntime_provider_test is built into the artifact and links the plugin tests + # (gated by ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP). The user-stream + CUDA graph test + # registers the plugin .so via GetSharedLibraryFileName("onnxruntime_providers_cuda_plugin"), + # which returns the platform-specific filename without a directory component. Run from + # /build/Release/Release so that filename resolves to the plugin .so built there. + - name: Run CUDA Plugin EP C++ Tests + run: | + docker run --rm --gpus all \ + -v ${{ github.workspace }}:/onnxruntime_src \ + -v ${{ runner.temp }}/Release:/build/Release \ + -e NVIDIA_VISIBLE_DEVICES=all \ + ${{ steps.build_docker_image_step.outputs.full-image-name }} \ + bash -c " + set -ex + export PATH=/opt/python/cp312-cp312/bin:\$PATH + # Make libcudart.so.13 (and the plugin's CUDA deps) findable; see note above. + export LD_LIBRARY_PATH=/build/Release/Release:/usr/local/cuda-13.0/lib64:\${LD_LIBRARY_PATH:-} + + cd /build/Release/Release + ls -la onnxruntime_provider_test libonnxruntime_providers_cuda_plugin.so + ./onnxruntime_provider_test --gtest_filter='CudaPluginUserStreamGraphTest.*' + " diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 551e877d5f6d8..86e5579eb6761 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -88,10 +88,11 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") # in the CPU provider and is not linked into the plugin. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") -# Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. -# shape_op.cc inherits from onnxruntime::OpKernel (framework) -# which cannot convert to ep::adapter::OpKernel in the plugin build. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$") +# shape_op.cc is INCLUDED in the plugin build. It provides an adapter-based +# Shape kernel under #ifdef BUILD_CUDA_EP_AS_PLUGIN (the CPU onnxruntime::Shape +# class, which derives from the framework OpKernel, is only used in the +# non-plugin build). Registering Shape on the EP keeps it off the CPU EP and +# avoids Memcpy nodes that would otherwise break CUDA Graph capture. # Exclude contrib training ops (shrunken_gather depends on provider_api.h in header). list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") diff --git a/docs/cuda_plugin_ep/arena_allocator_migration_design.md b/docs/cuda_plugin_ep/arena_allocator_migration_design.md index 285aa3e60ed5c..f082b444e10b0 100644 --- a/docs/cuda_plugin_ep/arena_allocator_migration_design.md +++ b/docs/cuda_plugin_ep/arena_allocator_migration_design.md @@ -62,7 +62,18 @@ if (!factory.arena_allocator_) { **Stream-aware allocation.** `ArenaImpl::AllocOnStream(size, stream)` tracks which chunks are assigned to which stream. `ResetChunksUsingStream(stream_impl)` is called from `OrtSyncStreamImpl::OnSessionRunEnd` to release chunk-to-stream assignments when a session run completes. -**Read-only allocator bypasses arena.** The factory creates a plain `CustomAllocator` (no arena) for `OrtReadOnlyAllocator` (initializers), since initializer memory doesn't benefit from arena allocation. +**Kernel-side consumption of the arena.** Migrated CUDA kernels obtain scratch/workspace memory from this arena through `CudaKernel::GetScratchBuffer`, which calls `Info().GetAllocator(OrtMemTypeDefault)`. Inside the plugin build that allocator is exposed to internal code as an `IAllocatorWrappingOrtAllocator` (`include/onnxruntime/ep/adapter/allocator.h`), which implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` (ORT ≥ 1.23), falling back to plain `Alloc` otherwise. The plugin `GetScratchBuffer` deliberately passes a **null stream** to the arena rather than forwarding the kernel's compute stream. A plugin kernel only has the raw `cudaStream_t` (via `KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena persists in each chunk (`chunk->stream`) and later dereferences through the EP stream API (`SyncStream_GetImpl`/`SyncStream_GetSyncId`). Synthesizing a temporary framework `Stream` wrapper over the raw handle would be unsafe: it would dangle once `GetScratchBuffer` returns while the arena still holds the pointer, and it would be type-confused (a framework `Stream*` reinterpreted as an `OrtSyncStream*` that ORT never created for this stream). With a null stream the arena tracks scratch chunks as freely reusable (the same semantics as a plain non-stream-aware BFC arena). This is still what keeps scratch allocations served from already-reserved chunks during CUDA graph capture — capture stability comes from chunk reuse, not from stream tagging — and it is safe for the CUDA graph path, which runs on a single unified stream. + +#### Scratch buffer stream tagging — limitation and future work + +A common review question is: *"Passing a null stream to the scratch allocator looks wrong — won't it cause a synchronization issue? Shouldn't the scratch buffer use the same stream as the kernel?"* The short answer is that, for the path this code targets, it is correct and safe. The longer answer clarifies what the `stream` argument actually does and why forwarding the real stream is not currently possible. + +- **The `stream` argument is bookkeeping, not execution.** The stream passed to a stream-aware arena's `AllocOnStream()` is only metadata the arena uses to decide whether a *freed* chunk may be reused on a *different* stream without an intervening synchronization. It does **not** change where the kernel runs: the returned buffer is always consumed by the kernel on its real compute stream. So a null tag does not move work onto the default stream or skip any required sync — it only relaxes cross-stream chunk reuse. +- **Why null is safe here.** The scratch routing targets serialized runs and the CUDA graph path, which runs on a single **unified stream** when graph capture and a user compute stream are combined. On a single stream, alloc -> use -> free -> reuse are implicitly ordered by the stream itself, so there is never a second stream that could reuse a chunk while the first is still using it. A null-tagged ("freely reusable") chunk behaves exactly like a plain non-stream-aware BFC arena chunk, which is the correct behavior for one stream. Because null-tagged chunks are not safe for overlapping runs on different CUDA streams, the CUDA plugin EP does not advertise concurrent `Session::Run()` support until scratch chunks can be properly stream-tagged. +- **Why we cannot forward the real stream today (C-API limitation).** The stream-aware arena needs the framework `OrtSyncStream*` (`struct OrtSyncStream : public onnxruntime::Stream` in `core/framework/plugin_ep_stream.h`) — the ORT-core wrapper it stored in `chunk->stream`. A plugin kernel only has the raw `cudaStream_t`. `CudaSyncStream::FromCudaStream()` can recover the plugin-side `CudaSyncStream` (an `OrtSyncStreamImpl`), but that is a *different* object from the ORT-core `OrtSyncStream*` the arena expects; passing it (or a stack-allocated shim over the raw handle) would be both dangling and type-confused. +- **Future work.** To properly stream-tag scratch chunks — which only becomes necessary if this path is extended to support concurrent multi-stream runs sharing one arena — ORT needs new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels at dispatch time (e.g. via `KernelContext`). Until then, the null-stream tag is the correct and intentional choice. The matching code comment lives in `CudaKernel::GetScratchBuffer` (`onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h`). + + ### 2.2 How ORT Core Calls the Factory diff --git a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md index 8092a15e26973..1cac9464430dc 100644 --- a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md +++ b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md @@ -34,6 +34,7 @@ Session::Run() **Key design choices:** - Each thread gets its own dedicated graph `cudaStream_t`, `CudaGraphManager`, and capture bookkeeping for the EP instance. `CudaSyncStream::InitHandlesWithExternalStream()` wraps the thread's graph stream so graph capture sees the same stream as kernels. The manager stores captured `cudaGraphExec_t` executables keyed by annotation ID, allowing multiple graphs (e.g., different input shapes) for that thread. +- When a `user_compute_stream` is supplied together with graph capture, the per-thread context adopts that user-owned stream as its graph stream instead of creating one, so capture/replay run on the same stream the caller drives. The context records that it does not own the stream and never destroys it. See [User Compute Stream + CUDA Graph](#user-compute-stream--cuda-graph). - Warm-up runs (default: 2) allow memory allocations to stabilize before capture begins. - Graph annotation IDs are parsed from `OrtRunOptions` key `"gpu_graph_id"`. ID `-1` skips capture; `0` is the default. @@ -53,25 +54,50 @@ Session::Run() Legacy aliases `ep.cuda.enable_cuda_graph` and `enable_cuda_graph` are also supported. For the warm-up count, `ep.cuda.min_num_runs_before_cuda_graph_capture` is also accepted. +The provider option `user_compute_stream` (a `cudaStream_t` passed as a pointer) may be combined with `enable_cuda_graph`. See [User Compute Stream + CUDA Graph](#user-compute-stream--cuda-graph). + --- +## User Compute Stream + CUDA Graph + +A caller can supply its own CUDA stream through the `user_compute_stream` provider option and enable CUDA graph capture at the same time. This combination was previously rejected with `ORT_INVALID_ARGUMENT`; it is now supported and matches the bundled (non-plugin) CUDA EP. + +When both options are set: + +- `CudaEpFactory::CreateEpImpl` no longer rejects the pair. Setting `user_compute_stream` still forces unified-stream mode (matching the bundled EP). +- `CudaEp::CreateSyncStreamForDeviceImpl` wraps the user stream via `InitHandlesWithUserStream()`, attaching full cuBLAS/cuDNN/cuBLASLt handles to it. +- `CudaEp::GetPerThreadContext()` builds the thread's `PerThreadContext` around the user stream (`external_graph_stream`) instead of creating an EP-owned graph stream. Capture and replay therefore run on the same stream the kernels are issued to. +- The context records `owns_graph_stream = false`, so it tears down captured graph execs on destruction but never calls `cudaStreamDestroy` on the user-owned stream. Stream lifetime stays with the caller. + +Because the user supplies one stream, this mode is inherently single-stream; the per-thread graph isolation still applies if the same session is driven from multiple threads, but each thread must drive its own captures on the stream it provides. + +### `user_compute_stream` is not limited to the CUDA graph case + +A natural question when reading `GetPerThreadContext()` is why `use_external_stream` is gated on `has_user_compute_stream && enable_cuda_graph` — does that restrict a user compute stream to graph-enabled runs? It does not. + +- A user compute stream is honored for kernels in **both** graph and non-graph runs. That happens in `CudaEp::CreateSyncStreamForDeviceImpl`, whose first branch wraps `config_.user_compute_stream` via `InitHandlesWithUserStream()` **independently of `enable_cuda_graph`**. +- The `enable_cuda_graph` term in `use_external_stream` only governs the `PerThreadContext`'s *graph stream*. `PerThreadContext` is a graph-capture-only object: `GetPerThreadContext()` is reached exclusively from the graph path (the `enable_cuda_graph` branch of `CreateSyncStreamForDeviceImpl`, `OnRunStart`/`OnRunEnd`, `IsGraphCaptured`, `ReplayGraph`). With graph disabled, no `PerThreadContext` is ever constructed, so its stream-ownership flag is irrelevant. +- The flag therefore answers a narrower question — *"should the per-thread capture/replay graph stream adopt (and not destroy) the user's stream?"* — which is only meaningful when a graph is actually being captured. + ## Implementation Summary ### Files Changed | File | Change | |------|--------| -| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implemented graph capture callbacks (`OnRunStartImpl`, `OnRunEndImpl`, `IsGraphCaptureEnabledImpl`, `IsGraphCapturedImpl`, `ReplayGraphImpl`, `IsConcurrentRunSupportedImpl`), updated `CreateSyncStreamForDeviceImpl` to use the current thread's graph stream when graph capture is enabled, added per-thread graph state, preserved `sync_stream` synchronization, and added a `cudaMemGetInfo` defensive allocation check | +| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implemented graph capture callbacks (`OnRunStartImpl`, `OnRunEndImpl`, `IsGraphCaptureEnabledImpl`, `IsGraphCapturedImpl`, `ReplayGraphImpl`, `IsConcurrentRunSupportedImpl`), updated `CreateSyncStreamForDeviceImpl` to wrap a `user_compute_stream` or otherwise use the current thread's graph stream when graph capture is enabled, made `PerThreadContext` adopt the user stream as its (non-owned) graph stream when `user_compute_stream` + `enable_cuda_graph` are combined, added per-thread graph state, preserved `sync_stream` synchronization, and added a `cudaMemGetInfo` defensive allocation check | | `onnxruntime/core/providers/cuda/plugin/cuda_ep.h` | Added `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` config fields, graph callback declarations, and a per-thread graph context cache | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc` | **NEW** — Complete `CudaGraphSet` and `CudaGraphManager` implementation | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h` | **NEW** — Header for graph manager types and constants | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Added `InitHandlesWithExternalStream()`, updated destructor for `owns_stream_` | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Added `InitHandlesWithExternalStream()` declaration, `owns_stream_` member | -| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` | +| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture`; removed the validation that rejected `user_compute_stream` + `enable_cuda_graph` (the combination is now supported) | +| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | `CudaKernel::GetScratchBuffer` now allocates through `Info().GetAllocator()` (the EP arena) with a null stream, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call, so scratch allocations are served from already-reserved arena chunks during capture | +| `include/onnxruntime/ep/adapter/allocator.h` | Implemented `IAllocatorWrappingOrtAllocator::IsStreamAware`/`AllocOnStream` (previously `ORT_NOT_IMPLEMENTED`) so plugin adapters can forward stream-aware allocations when a framework stream is available; `GetScratchBuffer` still passes a null stream until plugin kernels can receive a stable framework `OrtSyncStream*` | | `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Added `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` callbacks and `OrtGraphCaptureNodeAssignmentPolicy` enum to `OrtEp` | | `include/onnxruntime/core/framework/execution_provider.h` | Added `GetGraphCaptureNodeAssignmentPolicy()` virtual to `IExecutionProvider` | | `onnxruntime/core/session/inference_session.cc` | Replaced hard-coded EP name list with policy-driven graph capture validation loop; added bounded recursion via `RunImpl()` with `kMaxGraphCaptureWarmupRuns`; graph-enabled runs now reacquire stream collections through ORT core's thread-affine pool across internal warm-up/capture recursion | -| `onnxruntime/core/framework/session_state.cc` | Sharded the `DeviceStreamCollection` cache by caller thread using per-thread lifetime tokens, so stream wrappers are only reused on the creating thread | +| `onnxruntime/core/framework/session_state.cc` | Sharded the `DeviceStreamCollection` cache by caller thread using per-thread lifetime tokens, so stream wrappers are only reused on the creating thread; added a fallback in the PrePack loop to resolve the kernel's default-memory allocator (`Info().GetAllocator()`) when the device-keyed initializer-allocator lookup returns null for a separately-registered plugin EP | | `onnxruntime/core/framework/session_state.h` | Added thread-affine stream pool bucket state for `DeviceStreamCollection` reuse | | `onnxruntime/core/session/inference_session.h` | Added `RunImpl()` private method and `kMaxGraphCaptureWarmupRuns` constant | | `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | Added version-gated `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` bridge implementations | @@ -83,7 +109,8 @@ Legacy aliases `ep.cuda.enable_cuda_graph` and `enable_cuda_graph` are also supp - **Thread safety**: Mutable graph state and graph streams are stored per thread. ORT core's `DeviceStreamCollection` cache is also thread-affine, so graph-enabled runs can recycle stream wrappers without exposing them to a different thread. - **Scope**: Capture/replay pipeline plus allocator compatibility. Arena integration is complete — see the [Arena Allocator Integration](#arena-allocator-integration) section. - **Callback assignment**: `IsGraphCaptureEnabled` and `GetGraphCaptureNodeAssignmentPolicy` are always set. `OnRunStart`, `OnRunEnd` are conditional on `enable_cuda_graph`. `IsGraphCaptured` and `ReplayGraph` are always set (return false/error when disabled). -- **Stream management**: `CreateSyncStreamForDevice` remains unconditional — it branches internally to use the current thread's graph stream (via `InitHandlesWithExternalStream`) when graph capture is enabled, or creates an owned stream when disabled. +- **Stream management**: `CreateSyncStreamForDevice` remains unconditional — it branches internally: it wraps a user-provided `user_compute_stream` (via `InitHandlesWithUserStream`) when one is set, otherwise uses the current thread's graph stream (via `InitHandlesWithExternalStream`) when graph capture is enabled, or creates an owned stream when both are disabled. +- **User compute stream + CUDA graph**: These options can now be combined (previously rejected at factory creation). When both are set, `CudaEp::GetPerThreadContext()` builds the `PerThreadContext` around the user's stream (`external_graph_stream`) so capture and replay run on the same stream the kernels use, and the context never destroys the user-owned stream (`owns_graph_stream = false`). - **Run-end synchronization**: `OnRunEndImpl` honors the `sync_stream` flag without double-synchronizing replayed graphs, preserving the normal EP completion contract. - **Stream collection reuse**: ORT core now recycles `DeviceStreamCollection` objects into a thread-affine session pool keyed by a per-thread lifetime token. Warm-up, capture, replay, and later user-visible `Run()` calls on the same thread can reuse the same stream wrappers, while dead-thread buckets are pruned before they can be reused by another thread. - **Per-thread context lifecycle**: Thread-local caches hold the strong `PerThreadContext` references, so CUDA streams and captured graph executables are released when the owning thread exits. The EP tracks weak references to those cache maps to remove stale entries during EP destruction without keeping the contexts alive. @@ -101,6 +128,7 @@ CUDA graph capture requires that all memory allocations happen during warmup, no **Arena integration details (now implemented):** - Default CUDA device allocations come from the plugin-hosted arena (`CudaArenaAllocator`). During warmup runs, the arena grows to accommodate all needed chunks; during capture and replay, the same chunks are reused without `cudaMalloc` calls. +- Kernel scratch/workspace allocations (`CudaKernel::GetScratchBuffer`) also flow through the EP arena via `Info().GetAllocator()`, rather than issuing a fresh `cudaMallocAsync`/`cudaMalloc` per call. After warmup the arena has reached its steady-state working set, so the capture run serves every scratch request from an already-reserved chunk and the device free-memory footprint stays stable across the capture window. This is what makes the `cudaMemGetInfo` allocation-during-capture detector pass for graphs that use scratch buffers, and it matches the bundled CUDA EP (which also obtains scratch from `Info().GetAllocator()`). `GetScratchBuffer` passes a **null stream** to the arena. This is *not* a synchronization bug: the `stream` argument is only bookkeeping metadata the stream-aware arena uses to decide when a freed chunk may be reused on a *different* stream without a sync - it does not change where the kernel runs (the buffer is still consumed on the real compute stream). In a serialized run (and within one graph-capture run), alloc/free/reuse are implicitly ordered on that stream, so a null-tagged ("freely reusable") chunk is correct and safe. It is also currently the only safe option, because a plugin kernel only has the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` the stream-aware arena persists per chunk and later dereferences through the EP stream API; note that the ORT-core `OrtSyncStream` (`struct OrtSyncStream : public onnxruntime::Stream`) is a different object from the plugin's `CudaSyncStream` (an `OrtSyncStreamImpl`). Synthesizing a temporary `Stream*` over the raw handle would dangle after `GetScratchBuffer` returns and be type-confused, so scratch chunks are tracked with a null stream (freely reusable, like a plain BFC arena). Capture stability comes from chunk reuse, not stream tagging. Properly stream-tagging scratch chunks (required before this path can support concurrent multi-stream runs) is **future work** that requires new C-API surface to expose the framework `OrtSyncStream*` to plugin kernels — see [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). - When `arena.use_cuda_mempool=1` is configured, CUDA device allocations come from `CudaMempoolOrtAllocator`, which wraps `cudaMallocFromPoolAsync`/`cudaFreeAsync`. These async allocation/free operations are CUDA-graph-safe since CUDA 11.4+ and become part of the captured graph topology. - Pinned allocations are also arena-backed, but remain non-stream-aware. - The graph stream created by `CudaEp::PerThreadContext` flows through `CudaSyncStream::InitHandlesWithExternalStream()` so stream-aware arena allocation uses the same `cudaStream_t` during warm-up, capture, and replay. @@ -109,14 +137,12 @@ CUDA graph capture requires that all memory allocations happen during warmup, no ### Concurrent Run Support -Concurrent `Session::Run()` is supported with CUDA graph enabled: +Concurrent `Session::Run()` is intentionally **not** advertised by the CUDA plugin EP while migrated kernels route scratch/workspace allocations through the EP arena with a null stream tag. -- `CudaEp::PerThreadContext` owns the graph stream, graph manager, warm-up run counts, and memory watermark for the current thread. -- The current thread's cache owns the `PerThreadContext`; new threads get independent contexts, and exited threads release their contexts automatically. -- `CreateSyncStreamForDeviceImpl()` wraps the current thread's graph stream, so warm-up, capture, and replay all use the same stream for that thread. -- `CudaGraphManager::CaptureBegin()` uses `cudaStreamCaptureModeThreadLocal`, allowing overlapping capture scopes on different threads. -- ORT core recycles graph-enabled `DeviceStreamCollection` objects into a thread-affine session pool, so internal warm-up/capture recursion and later top-level `Run()` calls on the same thread reuse the same stream wrappers without cross-thread leakage. -- `IsGraphCaptured()` and `ReplayGraph()` resolve the current thread's graph context. If a new thread runs a graph-enabled session for the first time, that thread performs its own warm-up and capture before replaying. +- `CudaEp::PerThreadContext` still owns graph stream, graph manager, warm-up run counts, and memory watermark state per thread. This keeps graph bookkeeping thread-local and avoids sharing captured graph executables across threads. +- However, plugin kernels currently receive only the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena stores in each chunk and later uses for safe cross-stream reuse checks. +- Because `CudaKernel::GetScratchBuffer` cannot safely provide that framework stream, it passes a null stream tag. Null-tagged scratch chunks are freely reusable, which is safe for serialized runs and single-unified-stream graph capture but unsafe for overlapping runs on different CUDA streams. +- Therefore `CudaEp::IsConcurrentRunSupportedImpl()` returns false. Re-enabling concurrent multi-stream runs is future work and requires new C-API surface to expose a stable framework stream (or equivalent sync id) to plugin kernels so scratch chunks can be properly stream-tagged. ## Verification @@ -128,7 +154,9 @@ Concurrent `Session::Run()` is supported with CUDA graph enabled: - `test_cuda_graph_with_mempool` — graph capture with `arena.use_cuda_mempool=1` - `test_cuda_graph_annotation_id` — multiple graphs via `gpu_graph_id` run config - `test_cuda_graph_add_model` — graph capture with Add op (arena-backed) +4. `onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc` is a C++ test (gated by `ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP`) covering `user_compute_stream` combined with `enable_cuda_graph`: it verifies session creation succeeds with both options set (regression for the removed validation), capture + replay on the user stream produce correct results, and replay after an in-place input update on the user stream is correct. ## Future Work 1. **Profiling integration**: CUDA graph replay currently bypasses the CUDA plugin EP profiler path because the CUDA plugin EP does not yet implement `OrtEp::CreateProfiler`. Wiring graph replay into that path is future work. +2. **Stream-tagged scratch allocations**: `CudaKernel::GetScratchBuffer` passes a null stream to the EP arena because plugin kernels cannot currently obtain the framework `OrtSyncStream*` the stream-aware arena needs (they only have the raw `cudaStream_t`). This is correct and safe for serialized runs and within one graph-capture run, but it is why the EP does not advertise concurrent `Session::Run()` support. Supporting concurrent multi-stream runs that share one arena would require new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels so scratch chunks can be properly stream-tagged. See [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 2f61da90233b4..438fb8606fc09 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -459,9 +459,18 @@ The NHWC rollout is effectively in a "runtime enabled, cleanup remaining" state: | 2 | Cache the shim provider pointer in the adapter `OpKernelInfo` | Implemented; fixes the observed NHWC runtime crash | | 3 | Consolidate allowlists, improve internal-domain diagnostics, and strengthen structural NHWC assertions | Recommended follow-up work | +#### 5.3.2 Allocator Resolution for Kernels (Scratch and PrePack) + +Migrated kernels need a valid device allocator in two places: scratch/workspace buffers during `Compute()`, and one-time weight conversion or packing during `PrePack()`. Both now resolve the allocator the same way the bundled CUDA EP does, through the kernel's own `OpKernelInfo`. + +- **Scratch buffers.** `CudaKernel::GetScratchBuffer` allocates through `Info().GetAllocator(OrtMemTypeDefault)` (the EP arena) with a null stream tag, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call. The adapter `OpKernelInfo::GetAllocator` resolves the EP's default-memory (device) allocator and is always valid for a migrated kernel, so no plugin-only scratch path is needed. Routing through the arena is also what keeps the device free-memory footprint stable during CUDA graph capture (see [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#arena-allocator-integration)). The null stream tag is intentional: plugin kernels only have the raw `cudaStream_t`, not the framework `OrtSyncStream*` that the stream-aware arena persists in chunks for safe cross-stream reuse. +- **PrePack.** The framework prepack loop (`SessionState::PrepackConstantInitializedTensors`) resolves the allocator with `GetInitializerAllocator(kernel->Info().GetDevice(OrtMemTypeDefault))`, a session map keyed by device. For a plugin EP registered as a separate library, that device-keyed lookup can miss and return null. The loop now falls back to `kernel->Info().GetAllocator(OrtMemTypeDefault)` when the lookup is null, so every `PrePack` implementation receives a valid allocator at the single framework call site. This replaces the earlier approach of adding a per-kernel `if (!alloc) alloc = Info().GetAllocator(...)` guard to each prepacking op (which only covered the few ops that were touched and risked missing future ones). The fallback is behavior-neutral for in-tree EPs, whose device-keyed lookup already succeeds, and it does **not** force `is_packed`/`prepacked_weights` handling \u2014 ops such as `QMoE` and `MatMulNBits` still set `is_packed = true` and populate prepacked weights normally. + +The enabling adapter change is in [`include/onnxruntime/ep/adapter/allocator.h`](../../include/onnxruntime/ep/adapter/allocator.h): `IAllocatorWrappingOrtAllocator` now implements `IsStreamAware()`/`AllocOnStream()` (previously `ORT_NOT_IMPLEMENTED`) by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` when it is available (ORT >= 1.23), falling back to plain `Alloc` otherwise. `GetScratchBuffer` does not use that stream-aware path yet because the plugin kernel layer cannot safely provide the framework `OrtSyncStream*`; stream-tagged scratch allocation is future work and is documented in [arena_allocator_migration_design.md](arena_allocator_migration_design.md#scratch-buffer-stream-tagging--limitation-and-future-work). + ### 5.4 CUDA Graph Support -CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and concurrent `Session::Run()` support. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and concurrent run details — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. +CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and combining a caller-supplied `user_compute_stream` with capture/replay. Concurrent `Session::Run()` is intentionally not advertised while scratch allocations are null-stream-tagged; supporting concurrent multi-stream runs requires future C-API work to expose a stable framework stream or sync id to plugin kernels. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and user-stream mode — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. #### 5.4.1 OrtEp C API Extensions (v1.26) @@ -488,6 +497,10 @@ Session-level changes in `inference_session.cc`: - **Bounded recursion**: After each normal run when graph capture is enabled, the session recursively calls `RunImpl()` (bounded by `kMaxGraphCaptureWarmupRuns = 8`) until the graph is captured. From the user's perspective, a single `Run()` call handles the entire warm-up + capture sequence. - **Stream collection lifetime**: ORT core now caches `DeviceStreamCollection` objects in thread-affine session buckets keyed by a per-thread lifetime token. Graph-enabled runs recycle and reacquire stream wrappers only on the creating thread, which preserves warm-up/capture reuse without cross-thread leakage. +#### 5.4.3 User Compute Stream with CUDA Graph + +A caller-provided `user_compute_stream` may be combined with `enable_cuda_graph` (the factory previously rejected this pair). When both are set, `CudaEp::GetPerThreadContext()` builds the per-thread graph context around the user-owned stream rather than an EP-owned one, so capture and replay run on the same stream the kernels are issued to (matching the bundled CUDA EP). The context marks the stream as not owned and never destroys it. Details are in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#user-compute-stream--cuda-graph). + --- ## 6. EP Adapter Layer (`include/onnxruntime/ep/adapter/`) diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index 0db30f39b3f57..1fb78f81fce19 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -41,21 +41,19 @@ class IAllocatorWrappingOrtAllocator final : public IAllocator { } bool IsStreamAware() const override { - return false; - - // TODO: Enable once AllocOnStream() is implemented. - // static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; - // const OrtAllocator* raw = ort_allocator_; - // return raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr; + static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; + const OrtAllocator* raw = ort_allocator_; + return raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr; } - void* AllocOnStream(size_t /*size*/, Stream* /*stream*/) override { - // TODO: Implement AllocOnStream(). - // The internal `onnxruntime::IAllocator::AllocOnStream` signature takes an internal `onnxruntime::Stream*` - // argument, while the public `::OrtAllocator::AllocOnStream` signature takes an `::OrtSyncStream*` argument. - // We need to properly map from one to the other. - // `::OrtSyncStream*` should be treated as an opaque type from the plugin EP's perspective. - ORT_NOT_IMPLEMENTED("IAllocatorWrappingOrtAllocator::AllocOnStream is not implemented yet."); + void* AllocOnStream(size_t size, Stream* stream) override { + static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; + OrtAllocator* raw = ort_allocator_; + if (raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr) { + return raw->AllocOnStream(raw, size, reinterpret_cast(stream)); + } + + return raw->Alloc(raw, size); } void GetStats(AllocatorStats* stats) override { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 6ef2319c1d3f4..ad92ddd797d3a 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -603,7 +603,16 @@ Status SessionState::PrepackConstantInitializedTensors( // within this session. Or if the weight is not present on disk, // we store the newly minted pre-packed data. - AllocatorPtr session_initializer_alloc = GetInitializerAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + AllocatorPtr session_initializer_alloc = GetInitializerAllocator( + kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + // A plugin EP registered as a separate library may not have an initializer + // allocator registered under the kernel's device key, so the lookup above can + // return null. Fall back to the kernel's own default-memory allocator (resolved + // through the EP), which is always valid. This keeps PrePack implementations from + // each having to special-case a null allocator at the library boundary. + if (!session_initializer_alloc) { + session_initializer_alloc = kernel->Info().GetAllocator(OrtMemType::OrtMemTypeDefault); + } PrePackedWeights weights_to_be_filled_in; // The reason we invoke PrePack() before looking into the container for any pre-packed weight // cached by another instance of the same op_type (for the same constant initializer) is because diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index edafbfb3ede65..dc53b02141207 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1361,9 +1361,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // cross-attention case; MEA handles it via the causal_from_top_left flag and Unified // Unfused uses past_kv_length=0. (When an external cache is present — nonpad_kv_seqlen — // the required frontier IS bottom-right, so Flash is eligible; see below.) - const bool causal_cross_no_past = parameters.is_causal && - parameters.q_sequence_length != parameters.total_sequence_length && - parameters.past_sequence_length == 0; + [[maybe_unused]] const bool causal_cross_no_past = + parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; // is_causal=1 + nonpad_kv_seqlen (external KV cache) without past_key defines a // bottom-right causal frontier per onnx/onnx#8068: query in-block index i attends key j diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index d611366c8ad7e..13f7bfd7a40cf 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -83,20 +83,33 @@ void DestroyCudaStreamForDevice(cudaStream_t stream, int device_id) { } // namespace struct CudaEp::PerThreadContext { - explicit PerThreadContext(int device_id) + // When use_external_stream is true (user_compute_stream combined with CUDA graph), capture and + // replay happen on that user-owned stream so they see the same stream as the kernels; the + // context neither creates nor destroys it. Ownership is derived from the caller's intent rather + // than from external_stream being non-null, because a user may legitimately select the CUDA + // default stream (cudaStream_t(0), i.e. nullptr) as the compute stream — that is still an + // external, user-owned stream and must not be destroyed by the context. When use_external_stream + // is false the context creates and owns a dedicated graph stream. + explicit PerThreadContext(int device_id, bool use_external_stream = false, + cudaStream_t external_stream = nullptr) : device_id(device_id), - graph_stream(CreateCudaStreamForDevice(device_id)), + owns_graph_stream(!use_external_stream), + graph_stream(use_external_stream ? external_stream + : CreateCudaStreamForDevice(device_id)), cuda_graph(graph_stream) { } ~PerThreadContext() { // Destroy captured graph execs before destroying the stream they replay on. cuda_graph.Reset(); - DestroyCudaStreamForDevice(graph_stream, device_id); + if (owns_graph_stream) { + DestroyCudaStreamForDevice(graph_stream, device_id); + } graph_stream = nullptr; } int device_id; + bool owns_graph_stream; cudaStream_t graph_stream = nullptr; CudaGraphManager cuda_graph; size_t pre_capture_free_mem = 0; @@ -391,8 +404,14 @@ OrtStatus* ORT_API_CALL CudaEp::CreateSyncStreamForDeviceImpl( auto cuda_stream = std::make_unique(ep->factory_, device_id, this_ptr); - if (ep->config_.has_user_compute_stream && ep->config_.user_compute_stream != nullptr) { - // Wrap the user-provided external CUDA stream with full cuBLAS/cuDNN handles. + if (ep->config_.has_user_compute_stream) { + // A user-provided compute stream is honored for kernels regardless of whether CUDA graph + // capture is enabled - this branch is taken in both graph and non-graph runs. Use the caller's + // intent flag rather than checking the handle for non-null: cudaStream_t(0) / nullptr is the + // valid CUDA default stream and can be selected explicitly by the user. Wrap the external CUDA + // stream with full cuBLAS/cuDNN handles. When CUDA graph capture is also enabled, + // capture/replay run on this same user stream (see GetPerThreadContext), so kernels and graph + // capture share one stream. RETURN_IF_ERROR(cuda_stream->InitHandlesWithUserStream( static_cast(ep->config_.user_compute_stream))); } else if (ep->config_.enable_cuda_graph) { @@ -439,15 +458,20 @@ OrtStatus* ORT_API_CALL CudaEp::SyncImpl(OrtEp* this_ptr) noexcept { /*static*/ OrtStatus* ORT_API_CALL CudaEp::IsConcurrentRunSupportedImpl( OrtEp* this_ptr, bool* is_supported) noexcept { + ORT_UNUSED_PARAMETER(this_ptr); + if (is_supported == nullptr) { return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "is_supported must not be null."); } - auto* ep = static_cast(this_ptr); - // When a unified stream is in use (either from user_compute_stream, external - // allocator, or explicit use_ep_level_unified_stream), all operations share a - // single stream so concurrent runs are not safe. - *is_supported = !ep->config_.use_ep_level_unified_stream; + // Plugin kernels currently expose only the raw cudaStream_t to GetScratchBuffer(), not the + // framework OrtSyncStream* that the stream-aware arena needs to tag scratch chunks by stream. + // Scratch chunks are therefore allocated with a null stream tag and can be reused freely. That is + // safe when runs are serialized, but it is not safe to advertise concurrent Session::Run(): two + // runs on different CUDA streams could reuse the same scratch chunk while earlier work is still + // in flight. Re-enable concurrent runs only after the plugin kernel layer can pass a stable + // framework stream (or equivalent sync id) to the arena. + *is_supported = false; return nullptr; } @@ -467,7 +491,25 @@ CudaEp::PerThreadContext& CudaEp::GetPerThreadContext() const { return *cached_context_it->second; } - auto context = std::make_shared(config_.device_id); + // NOTE: `enable_cuda_graph` in this condition does NOT restrict using a user compute stream to + // the graph case. A user compute stream is honored for kernels in BOTH graph and non-graph runs + // — that happens in CreateSyncStreamForDeviceImpl(), which wraps config_.user_compute_stream + // independently of enable_cuda_graph. This flag only governs the PerThreadContext's *graph + // stream*, and PerThreadContext is a graph-capture-only object: GetPerThreadContext() is reached + // exclusively from the graph path (CreateSyncStreamForDeviceImpl's enable_cuda_graph branch, + // OnRunStart/OnRunEnd, IsGraphCaptured, ReplayGraph). With graph disabled, no PerThreadContext is + // ever constructed, so its stream ownership is irrelevant. + // + // When a user compute stream IS combined with CUDA graph capture, capture/replay must run on the + // user's stream (the same stream the kernels are issued to) rather than a separate EP-owned + // stream. The user owns the stream's lifetime, so the context must not destroy it. Derive this + // from the caller's intent (has_user_compute_stream && enable_cuda_graph), not from whether the + // handle is null: a user may explicitly choose the CUDA default stream (nullptr), which is still + // an external stream that the context must not own/destroy. + const bool use_external_stream = config_.has_user_compute_stream && config_.enable_cuda_graph; + cudaStream_t external_stream = + use_external_stream ? static_cast(config_.user_compute_stream) : nullptr; + auto context = std::make_shared(config_.device_id, use_external_stream, external_stream); PerThreadContext& context_ref = *context; { std::lock_guard lock(per_thread_contexts_mutex_); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index cb021034662e8..d445d8bab033c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -609,12 +609,9 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( "CUDA plugin EP does not support using both user_compute_stream and external allocator simultaneously."); } - // Validate: user_compute_stream and cuda graph cannot both be active. - if (config.has_user_compute_stream && config.enable_cuda_graph) { - return factory->ort_api_.CreateStatus( - ORT_INVALID_ARGUMENT, - "CUDA plugin EP does not support using both user_compute_stream and enable_cuda_graph simultaneously."); - } + // user_compute_stream and enable_cuda_graph CAN be combined: when both are set, CUDA graph + // capture/replay runs on the user-provided stream (the same stream kernels are issued to), + // matching the bundled CUDA EP behavior. See CudaEp::GetPerThreadContext. // When user_compute_stream is set, force unified stream mode (matches bundled EP behavior). if (config.has_user_compute_stream) { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 021acaf142435..f134c599d5b46 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -1023,56 +1023,60 @@ class CudaKernel : public OpKernel { template using IAllocatorUniquePtr = std::unique_ptr>; template - inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* s) const { + inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* /*stream*/) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); - size_t sz = 0; - if (!detail::TryBytesForCount(cnt, detail::SizeOf::value, sz)) { - ORT_THROW("CUDA scratch buffer allocation size overflow for ", cnt, " elements"); - } - void* p = nullptr; - cudaError_t alloc_result = cudaSuccess; - bool used_async_alloc = false; - if (s) { - // Note: stream-ordered allocations (cudaMallocAsync/cudaFreeAsync) rely on CUDA Memory Pools, - // which are not supported on NVIDIA GPUs with Multi-Instance GPU (MIG) enabled. - // On such instances, this will return cudaErrorNotSupported. - alloc_result = cudaMallocAsync(&p, sz, static_cast(s)); - used_async_alloc = (alloc_result == cudaSuccess); - if (!used_async_alloc && (alloc_result == cudaErrorNotSupported || alloc_result == cudaErrorInvalidValue)) { - cudaGetLastError(); // Clear the thread-local error state - alloc_result = cudaMalloc(&p, sz); - } - } else { - alloc_result = cudaMalloc(&p, sz); - } - - if (alloc_result != cudaSuccess) { - ORT_THROW("CUDA scratch buffer allocation failed for ", sz, " bytes: ", cudaGetErrorString(alloc_result)); - } - - return IAllocatorUniquePtr(static_cast(p), [s, used_async_alloc](T* ptr) { - if (ptr) { - // Guard: only attempt async free if the stream is still registered. - // CudaSyncStream::~CudaSyncStream guarantees UnregisterStream() is - // called before cudaStreamDestroy(), so a non-null lookup here means - // the raw cudaStream_t handle is still valid. - if (used_async_alloc && s && - cuda_plugin::CudaSyncStream::FromCudaStream(static_cast(s)) != nullptr) { - // As noted above, cudaFreeAsync may also return cudaErrorNotSupported on MIG-enabled instances. - cudaError_t free_result = cudaFreeAsync(ptr, static_cast(s)); - if (free_result == cudaSuccess) { - return; - } - cudaGetLastError(); // Clear any error set by cudaFreeAsync - } - - // Fall back to synchronous free if async free is unsupported or if the - // stream is no longer registered. cudaFree is valid for allocations - // returned by cudaMallocAsync and avoids using a stale stream handle. - cudaFree(ptr); - } - }); + // Route kernel scratch/workspace allocations through the EP allocator + // (a BFC arena by default) instead of raw cudaMallocAsync/cudaMalloc. + // + // The arena pre-reserves device memory and reuses freed chunks across runs. + // Once the model has executed `min_num_runs_before_cuda_graph_capture` + // warmup runs, the arena has grown to its steady-state working set, so the + // capture run serves every scratch allocation from an already-reserved chunk + // without issuing a fresh cudaMalloc. This keeps the device free-memory + // footprint stable across the capture window, which is required for correct + // CUDA graph capture/replay. + // + // The previous behavior (cudaMallocAsync/cudaMalloc allocated-and-freed per + // call) allocated new device memory on every run, including the capture run, + // so no amount of warmup could stabilize it and the + // "GPU memory was allocated during CUDA graph capture" detector would trip. + // This now matches the built-in (non-plugin) CUDA EP, which also obtains + // scratch from Info().GetAllocator() (see core/providers/cuda/cuda_kernel.h). + // The overflow check that the previous hand-rolled path performed is still + // enforced inside MakeUniquePtr via ValidatedCalcMemSizeForArray (it throws + // on cnt * sizeof(T) overflow). + // + // The compute stream is intentionally NOT forwarded to the allocator here. This is a + // bookkeeping decision, NOT a synchronization bug: the `stream` argument to a stream-aware + // arena is only metadata used to decide when a freed chunk may be reused on a *different* + // stream without an intervening sync. It does not change where the kernel runs - the returned + // buffer is still consumed by the kernel on the real compute stream. In a serialized run (and + // within one graph-capture run), alloc/free/reuse ordering is implicit on that stream, so there + // is no cross-stream chunk to race on. Tagging chunks with a null stream (freely reusable, the + // same semantics as a plain non-stream-aware BFC arena) is therefore correct and safe as long + // as the EP does not advertise concurrent Session::Run() support. + // + // It is also currently the only safe option, because of a C-API type constraint: a plugin + // kernel only has the raw cudaStream_t (KernelContext::GetGPUComputeStream), not the framework + // OrtSyncStream* that the stream-aware arena persists in each chunk (CudaArena stores + // `chunk->stream` and later dereferences it through the EP stream API, e.g. + // SyncStream_GetImpl/SyncStream_GetSyncId). Note that OrtSyncStream (the ORT-core wrapper, + // `struct OrtSyncStream : public onnxruntime::Stream`) is a DIFFERENT object from the plugin's + // CudaSyncStream (an OrtSyncStreamImpl); CudaSyncStream::FromCudaStream() recovers the latter, + // not the former. Wrapping the raw handle in a temporary framework Stream shim and passing it + // down would be unsafe on two counts: (1) the shim is stack-allocated and would dangle after + // this function returns while the arena still holds the pointer, and (2) it is type-confused — + // the arena would reinterpret a framework Stream* as an OrtSyncStream* that was never created + // by ORT for this stream. + // + // Properly stream-tagging scratch chunks (needed before this path can support concurrent + // multi-stream runs) requires new C-API surface to expose the framework OrtSyncStream* to + // plugin kernels. See docs/cuda_plugin_ep/arena_allocator_migration_design.md ("Scratch buffer + // stream tagging") for the limitation and future work. + return ::onnxruntime::IAllocator::MakeUniquePtr( + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), cnt, /*use_reserve*/ false, + /*stream*/ nullptr); } template inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t cnt) const { diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index 230b0b495bfbf..0202789a8777d 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -2,12 +2,87 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/tensor/shape_op.h" +#endif #include "core/providers/cuda/cuda_fwd.h" +#ifdef BUILD_CUDA_EP_AS_PLUGIN +#include +#endif + namespace onnxruntime { namespace cuda { +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// The bundled CUDA EP registers the CPU `onnxruntime::Shape` kernel (which +// derives from the framework `OpKernel`) and only marks its output as CPU +// memory. That class cannot be registered through the plugin EP's adapter +// kernel machinery, so the plugin build provides an adapter-based Shape kernel +// with identical semantics. Shape only reads the input's shape metadata (never +// its data) and writes the dims to a CPU output, so registering it on the CUDA +// EP keeps the node inside the device partition and avoids the device->host +// Memcpy node that the framework would otherwise insert to feed an isolated CPU +// Shape node -- a Memcpy that would prevent CUDA Graph capture. The output is +// still CPU memory, so a downstream device consumer may still need a copy; this +// removes the graph-breaking input-side Memcpy, it does not eliminate all copies. +class Shape final : public CudaKernel { + public: + explicit Shape(const OpKernelInfo& info) : CudaKernel(info) { + info.GetAttrOrDefault("start", &start_index_, 0); + + if (start_index_ != 0) { + // "start" is provided and is non-default (default is 0) + needs_slicing_ = true; + } + + if (info.GetAttr("end", &end_index_).IsOK()) { + needs_slicing_ = true; + } + } + + // Takes a tensor as input and outputs a 1D int64 tensor (on CPU memory) + // containing the shape of the input tensor. + Status ComputeInternal(OpKernelContext* context) const override { + const auto* input = context->Input(0); + const TensorShape& input_shape = input->Shape(); + + int64_t rank = static_cast(input_shape.NumDimensions()); + + if (!needs_slicing_) { // vanilla use of Shape (no slicing) + Tensor* output = context->Output(0, {rank}); + input_shape.CopyDims(output->MutableData(), static_cast(rank)); + } else { // slicing is needed + int64_t true_start = start_index_; + int64_t true_end = end_index_; + + // Deal with negative(s) and clamp + true_start = true_start < 0 ? true_start + rank : true_start; + true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start); + + true_end = true_end < 0 ? true_end + rank : true_end; + true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end); + + auto slice_length = true_end - true_start; + Tensor* output = context->Output(0, {slice_length < 0 ? 0 : slice_length}); + + if (slice_length > 0) { + input_shape.CopyDims(output->MutableData(), + static_cast(true_start), + static_cast(slice_length)); + } + } + + return Status::OK(); + } + + private: + bool needs_slicing_ = false; + int64_t start_index_ = 0; + int64_t end_index_ = std::numeric_limits::max(); +}; +#endif // BUILD_CUDA_EP_AS_PLUGIN + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Shape, kOnnxDomain, diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc new file mode 100644 index 0000000000000..d49faf3c90ea8 --- /dev/null +++ b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Tests that the CUDA plugin EP supports combining a user-provided compute stream +// (user_compute_stream) with CUDA graph capture/replay (enable_cuda_graph). +// +// Historically the plugin EP rejected this combination with ORT_INVALID_ARGUMENT. +// It now captures and replays the CUDA graph on the user-provided stream (the same +// stream the kernels are issued to), matching the bundled CUDA EP behavior. These +// tests verify: +// 1. Session creation succeeds with both options set (regression for the removed +// validation). +// 2. Capture + replay on the user stream produce correct results. +// 3. Replay after an in-place input update (on the user stream) is correct. + +#if defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/file_util.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { +namespace { + +constexpr const char* kCudaPluginEpRegistrationName = "CudaPluginUserStreamGraphTest"; +constexpr const char* kCudaPluginEpName = "CudaPluginExecutionProvider"; + +// Resolve the CUDA plugin EP shared library path. +std::filesystem::path GetCudaPluginLibraryPath() { + return GetSharedLibraryFileName(ORT_TSTR("onnxruntime_providers_cuda_plugin")); +} + +// RAII handle that registers/unregisters the CUDA plugin EP library. +class ScopedCudaPluginRegistration { + public: + ScopedCudaPluginRegistration(Ort::Env& env, const char* registration_name) + : env_(env), name_(registration_name) { + auto lib_path = GetCudaPluginLibraryPath(); + if (!std::filesystem::exists(lib_path)) { + available_ = false; + return; + } + env_.RegisterExecutionProviderLibrary(name_.c_str(), lib_path.c_str()); + available_ = true; + } + + ~ScopedCudaPluginRegistration() { + if (available_) { + try { + env_.UnregisterExecutionProviderLibrary(name_.c_str()); + } catch (...) { + } + } + } + + bool IsAvailable() const { return available_; } + + ScopedCudaPluginRegistration(const ScopedCudaPluginRegistration&) = delete; + ScopedCudaPluginRegistration& operator=(const ScopedCudaPluginRegistration&) = delete; + + private: + Ort::Env& env_; + std::string name_; + bool available_ = false; +}; + +// Find the CUDA plugin EP device after registration. +Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { + auto ep_devices = env.GetEpDevices(); + for (const auto& device : ep_devices) { + if (strcmp(device.EpName(), kCudaPluginEpName) == 0) { + return device; + } + } + return Ort::ConstEpDevice{nullptr}; +} + +} // namespace + +class CudaPluginUserStreamGraphTest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "No CUDA device available."; + } + + registration_ = std::make_unique( + *ort_env, kCudaPluginEpRegistrationName); + if (!registration_->IsAvailable()) { + GTEST_SKIP() << "CUDA plugin EP library not found."; + } + + cuda_device_ = FindCudaPluginDevice(*ort_env); + if (!cuda_device_) { + GTEST_SKIP() << "No CUDA plugin EP device found after registration."; + } + } + + void TearDown() override { + registration_.reset(); + cudaDeviceSynchronize(); + } + + // Build session options that select the plugin EP with CUDA graph capture enabled + // and the user-provided stream supplied as a pointer-sized address string. + Ort::SessionOptions CreateUserStreamGraphSessionOptions(cudaStream_t user_stream) { + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"enable_cuda_graph", "1"}, + {"user_compute_stream", + std::to_string(reinterpret_cast(user_stream))}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + return so; + } + + std::unique_ptr registration_; + Ort::ConstEpDevice cuda_device_{nullptr}; +}; + +// Regression: creating a session with both user_compute_stream and enable_cuda_graph +// used to fail with ORT_INVALID_ARGUMENT. It must now succeed. +TEST_F(CudaPluginUserStreamGraphTest, SessionCreatesWithUserStreamAndCudaGraph) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + { + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + ASSERT_NO_THROW({ + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + (void)session; + }); + } + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +// Full capture + replay on the user stream, including replay after an in-place input +// update. mul_1.onnx computes Y = X * W with W = [1, 2, 3, 4, 5, 6] (shape 3x2). +TEST_F(CudaPluginUserStreamGraphTest, CaptureAndReplayOnUserStream) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Device allocator backing the plugin EP's default memory. + auto device_memory_info = cuda_device_.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); + ASSERT_NE(allocator, nullptr); + + constexpr size_t kNumElements = 6; + constexpr size_t kBytes = kNumElements * sizeof(float); + const std::array shape = {3, 2}; + const std::array w_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Pre-allocate device input/output buffers (required for CUDA graph IO binding). + void* input_gpu = allocator.Alloc(kBytes); + void* output_gpu = allocator.Alloc(kBytes); + ASSERT_NE(input_gpu, nullptr); + ASSERT_NE(output_gpu, nullptr); + + auto upload_input = [&](const std::array& host_values) { + ASSERT_EQ(cudaSuccess, + cudaMemcpyAsync(input_gpu, host_values.data(), kBytes, + cudaMemcpyHostToDevice, user_stream)); + }; + + auto read_output = [&](std::array& host_values) { + // Kernels run on the user stream; wait for them before copying the result back. + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(user_stream)); + ASSERT_EQ(cudaSuccess, + cudaMemcpy(host_values.data(), output_gpu, kBytes, cudaMemcpyDeviceToHost)); + }; + + Ort::Value input_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(input_gpu), kNumElements, + shape.data(), shape.size()); + Ort::Value output_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(output_gpu), kNumElements, + shape.data(), shape.size()); + + Ort::IoBinding binding(session); + binding.BindInput("X", input_tensor); + binding.BindOutput("Y", output_tensor); + + // First run: warmup + capture + first replay on the user stream. + const std::array x0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + upload_input(x0); + session.Run(Ort::RunOptions{}, binding); + + std::array y{}; + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x0[i] * w_values[i]) << "capture mismatch at " << i; + } + + // Second run: pure graph replay (same inputs) on the user stream. + session.Run(Ort::RunOptions{}, binding); + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x0[i] * w_values[i]) << "replay mismatch at " << i; + } + + // Update the input in place on the user stream and replay again. + const std::array x1 = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; + upload_input(x1); + session.Run(Ort::RunOptions{}, binding); + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x1[i] * w_values[i]) << "updated-input replay mismatch at " << i; + } + + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); + allocator.Free(input_gpu); + allocator.Free(output_gpu); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +} // namespace test +} // namespace onnxruntime + +#endif // ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP From d4b01e437367deba5a0b1bcc5fa457e82c2388a5 Mon Sep 17 00:00:00 2001 From: FuZoe <147491646+FuZoe@users.noreply.github.com> Date: Thu, 25 Jun 2026 10:10:08 +0800 Subject: [PATCH 06/19] Fix CPU Attention causal mask alignment (#29050) ## Summary - align CPU ONNX Attention causal masking with upper-left behavior for q_len=1, kv_len>1, no past - preserve the existing `nonpad_kv_seqlen` / TensorScatter single-query causal behavior - update Python attention reference causal mask to model ONNX upper-left alignment with an explicit past offset - add a regression test for issue #29020 Fixes #29020 ## Validation - `python -m py_compile onnxruntime/test/python/transformers/test_onnx_attention/common.py onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py` - `git diff --check` Notes: - `pytest onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py -k "cpu_fp32 and causal" -q` could not run locally because this Python environment does not have `onnx` / `onnxruntime` installed. - After the latest follow-up commit, an incremental rebuild of `onnxruntime_provider_test` was attempted but failed in MSBuild before compiling this change due to a local environment issue: duplicate `Path` / `PATH` environment keys when launching `CL.exe`. --- .../core/providers/cpu/llm/attention.cc | 10 ++- .../providers/cpu/llm/attention_op_test.cc | 73 ++++++++++++++++++- .../test_onnx_attention/common.py | 11 +-- .../test_onnx_attention/test_gqa.py | 2 + .../test_onnx_attention/test_mha.py | 9 ++- 5 files changed, 97 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 1b0f37feab9fb..150fe7b478c1e 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -347,7 +347,15 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, T* mask_data = nullptr; bool delete_mask_data = false; - bool causal = parameters.is_causal && parameters.q_sequence_length > 1; + // In the nonpad_kv_seqlen path, q_len=1 is external KV-cache decode with + // bottom-right alignment. The single query's causal frontier is the valid + // length, so nonpad masking alone leaves exactly all valid keys visible. + // Keep causal=false for that case to avoid applying the batch-shared + // upper-left overlay used by the no-nonpad path. + bool causal = parameters.is_causal && + (parameters.has_nonpad_kv_seqlen + ? parameters.q_sequence_length > 1 + : !(parameters.q_sequence_length == 1 && parameters.past_sequence_length > 0)); // When nonpad_kv_seqlen is present the causal frontier is offset-aware // (bottom-right) and per-batch, so it cannot be baked into the batch-shared mask // buffer here; it is applied per-batch in the main loop below. Skip the top-left diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index cf76ca0fa00f8..03a47664c632a 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -2332,6 +2332,45 @@ TEST(AttentionTest, Attention_Causal_NonPadKVSeqLen_Decode_BottomRight) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// q_len=1 with nonpad_kv_seqlen keeps bottom-right decode behavior: the single +// query attends every valid key. If the no-nonpad upper-left overlay is applied +// here, this returns 1.0 instead of the expected 1/6. +TEST(AttentionTest, Attention_Causal_NonPadKVSeqLen_SingleQueryKeepsBottomRight_CPU) { + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + test.AddAttribute("is_causal", static_cast(1)); + + constexpr int batch_size = 1; + constexpr int q_num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int q_sequence_length = 1; + constexpr int kv_sequence_length = 6; + constexpr int head_size = 8; + + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::vector v(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::fill_n(v.begin(), head_size, 1.0f); + + test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, q); + test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, k); + test.AddInput("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, v); + test.AddOptionalInputEdge(); // attn_mask + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + test.AddInput("nonpad_kv_seqlen", {batch_size}, {kv_sequence_length}); + + std::vector expected_y(batch_size * q_num_heads * q_sequence_length * head_size, + 1.0f / static_cast(kv_sequence_length)); + test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, head_size}, expected_y, false, 0, + 1e-4f); + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // Continued / chunked prefill (S_q=2) into a partially-filled static cache. // nonpad=[4], S_q=2 -> offset = 4 - 2 = 2: query 0 attends keys {0,1,2}, query 1 // attends {0,1,2,3}. The old top-left alignment would mask everything past the @@ -3087,9 +3126,41 @@ TEST(AttentionTest, Attention4DSoftCapOutputQkRawLogits) { // ============================================================================ // Causal alignment tests: verify upper-left (no past) vs lower-right (with past) -// These are CUDA-only tests that validate the causal masking fix. +// These tests validate causal mask alignment across CPU and CUDA. // ============================================================================ +// Test: Causal + cross-attention (S_q=1, S_kv=6, no past) +// ONNX spec mandates upper-left alignment: q0 attends only to kv[0]. +// This covers GitHub issue #29020, where CPU skipped causal masking for S_q=1. +TEST(AttentionTest, Attention4DCausalSingleQueryCrossAttentionUpperLeft) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 1; + int head_size = 8; + int kv_sequence_length = 6; + int kv_num_heads = 1; + int v_head_size = 8; + int past_sequence_length = 0; + + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::vector v(batch_size * kv_num_heads * kv_sequence_length * v_head_size, 0.0f); + for (int i = 0; i < v_head_size; ++i) { + v[i] = 1.0f; + } + std::vector y(batch_size * q_num_heads * q_sequence_length * v_head_size, 1.0f); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, + v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), + std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, + TensorType::kFloat, + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + // Test: Causal + cross-attention (S_q=3, S_kv=5, no past) // ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. // V is identity-like so output directly reveals which KV positions were attended. diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py index 349f68f0ac5ce..d1a5d833d12a4 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py @@ -666,12 +666,11 @@ def attention_past_func( # ################################################################################################# -def construct_causal_mask(seqlen_q, seqlen_k, device): - """Construct a causal mask for attention.""" +def construct_causal_mask(seqlen_q, seqlen_k, device, past_seqlen=0): + """Construct a causal mask for ONNX Attention upper-left alignment.""" row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - # Causal: positions can only attend to earlier positions - return col_idx > row_idx + seqlen_k - seqlen_q + return col_idx > row_idx + past_seqlen def attention_ref( @@ -682,6 +681,7 @@ def attention_ref( attn_bias=None, causal=False, softcap=0.0, + past_seqlen=0, ): """ Reference implementation of scaled dot-product attention with GQA support. @@ -694,6 +694,7 @@ def attention_ref( attn_bias: Additive attention bias [broadcastable to batch, num_heads, seq_q, seq_k] causal: Whether to apply causal masking softcap: Softcap value for attention scores (0.0 = disabled) + past_seqlen: Number of past K/V tokens before q[0] for causal masking. Returns: output: Attention output [batch, seq_q, num_heads, head_size] @@ -724,7 +725,7 @@ def attention_ref( scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - causal_mask = construct_causal_mask(seqlen_q, seqlen_k, q.device) + causal_mask = construct_causal_mask(seqlen_q, seqlen_k, q.device, past_seqlen=past_seqlen) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index 542094a9ac4ee..eca86e429b597 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -289,6 +289,7 @@ def parity_check_gqa_past( key_padding_mask=key_padding_mask, causal=causal, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -548,6 +549,7 @@ def parity_check_gqa_past_with_padding( key_padding_mask=key_padding_mask, causal=config.is_causal == 1, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) # --- ONNX Runtime Path --- diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index e2d9acbd0c500..16a00085f224f 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -298,6 +298,7 @@ def parity_check_mha_past( attn_bias=attn_bias_ref, causal=causal, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -2393,7 +2394,13 @@ def test_mha_past_asymmetric_v_head_size(self): full_k_bsnh = full_k_bnsh.transpose(1, 2) full_v_bsnh = full_v_bnsh.transpose(1, 2) - out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, causal=True) + out_ref, _ = attention_ref( + q=q, + k=full_k_bsnh, + v=full_v_bsnh, + causal=True, + past_seqlen=config.past_kv_sequence_length, + ) # ORT path — should fall back to unfused (not crash in MEA) out_ort, present_k, present_v = attention_past_func( From 92b4c663496d575c3e22f1e5cfee922b35a6962e Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 26 Jun 2026 13:32:43 +0800 Subject: [PATCH 07/19] webgpu: Enable FlashAttention for batched GQA with right-padded prompts (#29247) ## Summary Lift WebGPU FlashAttention's `batch_size == 1` restriction so batched GQA with right-padded prompts (the common GenAI batched-prefill shape) takes the fused FlashAttention path instead of falling back to `ApplyAttention`. - **Per-batch seqlens in FlashAttention shaders.** Prefill, decode split-reduce, CopyKVCache, and the fused rotary-and-copyKV template now read `seqlens_k[batch_idx]` instead of hardcoding `seqlens_k[0]`. All `past_X = total_X - new_X` subtractions are clamped to avoid u32 underflow when a short batch's per-batch total is less than the batch-wide `sequence_length`. - **Indirect-dispatch sizing uses GQA's `total_sequence_length` input.** `CopyKVCache`, `SplitPackedQKVWithRotaryEmbeddingAndCopyKV`, and `FlashAttentionDecodeQKV` now take a new `total_sequence_length_input` binding (GQA input #6, GPU-resident under graph capture) for the indirect-dispatch grid sizing. This is the global max KV span across the batch by construction, replacing the previous `seqlens_k[0] + 1u` that under-dispatched whenever batch 0 wasn't the longest. Per-batch `seqlens_k[batch] + 1` still drives causal masking and K/V bounds inside the kernels. GQA now enforces `graph_capture_enabled -> past_present_share_buffer_` so the host-side `use_indirect_dispatch` predicate stays simple. - **Decoupled attention_bias stride from per-batch OOB.** `attention_bias` is still allocated to the global max `total_sequence_length`; only the causal-mask / softmax tile loops are gated by the per-batch total. The one-past-end fallback was tightened to clamp inside the same row (`offset_base + stride_total_seq - 1u`). - **Decode workgroup grid stays at global max.** `decode_qkv` keeps a workgroup grid sized to the global max tile count to keep `workgroup_idx` slicing consistent across batches, with neutral `(-inf, 0)` early-exit for tiles beyond a short batch's per-batch total so the `VxReduce` online softmax rescaling is not skewed. - **New `use_seqlen_k` template parameter** (separate from `use_indirect_dispatch` which still requires graph capture). It is enabled whenever `seqlen_k` is provided and (`graph_capture || batch_size_ > 1`). - **Rotary fix prerequisite** (`webgpu: fix GQA batched right-padded prefill with do_rotary`, 591df5b1): clamps `past_seqlen` to 0 in `RotaryEmbeddingProgram`, `FusedQKRotaryEmbeddingProgram`, and `split_packed_qkv_with_rotary_embedding`, which previously produced gibberish for the shorter batches. ## Motivation GenAI's batched prefill right-pads short prompts to the batch max and reports each batch's real length via `seqlens_k[b] = real_len[b] - 1`. The previous FlashAttention gate forced every batched call onto the slower `ApplyAttention` path, and the rotary shaders underflowed `u32` for any batch shorter than the batch-wide `sequence_length`, producing garbage Q/K positions and gibberish output text for the shorter batches. ## Test plan - [x] All `GroupQueryAttentionTest.WebGPU_*` op tests pass, including `BatchedRightPaddedRotaryPrefill` (FlashAttention path) and the new `BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU` covering a `real_lens` spread > tile_size - [x] phi4-prune three-prompt batched generation: coherent outputs on WebGPU matching CPU reference (3 prompts, 384 tokens, 173 tps) - [x] phi4-prune single-prompt generation regression: coherent - [x] phi4-graph-prune (graph capture enabled): `verify_model_correctness.py` 4/4 PASS; `verify_multi_gen.py` sequential + overlapping both PASS - [x] whisper-tiny-int4 transcription regression: 2/2 byte-exact with CPU - [x] Lintrunner clean on all changed files --- .../webgpu/bert/flash_attention.cc | 111 ++++++++++++------ .../contrib_ops/webgpu/bert/flash_attention.h | 24 ++-- .../webgpu/bert/flash_attention.wgsl.template | 79 +++++++------ .../flash_attention_decode_qkv.wgsl.template | 52 ++++++-- ...h_attention_decode_vx_reduce.wgsl.template | 10 +- .../webgpu/bert/group_query_attention.cc | 10 +- ..._rotary_embedding_and_copykv.wgsl.template | 3 +- .../group_query_attention_op_test.cc | 66 ++++++++++- 8 files changed, 249 insertions(+), 106 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index d9d299d4fd5d9..4e926c7efa597 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -55,6 +55,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + if (prepare_indirect_dispatch_) { + sh.AddInput("total_sequence_length_input", ShaderUsage::None); + } const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); @@ -97,8 +100,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (use_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::None); } - // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output + // If prepare_indirect_dispatch is enabled, add total_sequence_length_input + // and indirect_buffer output. total_sequence_length_input is the global max + // total sequence length across the batch (from GQA input #6); using it for + // dispatch sizing covers right-padded batches where batch 0 is not the max. if (prepare_indirect_dispatch_) { + shader.AddInput("total_sequence_length_input", ShaderUsage::None); shader.AddOutput("indirect_buffer", ShaderUsage::None); } @@ -109,11 +116,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { " let num_head_id = output_indices[1];\n" " let batch = output_indices[0];\n"; if (use_seqlen_k_) { - shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n"; + shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[batch]) + 1u;\n"; } else { shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; } - shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; + // Right-padded batches with prompt shorter than kv_sequence_length would underflow u32; clamp to 0. + shader.MainFunctionBody() << " let past_sequence_length = select(total_seq_length - uniforms.kv_sequence_length, 0u, total_seq_length <= uniforms.kv_sequence_length);\n"; if (past_present_share_buffer_) { shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"; } else { @@ -124,7 +132,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (prepare_indirect_dispatch_) { shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; shader.MainFunctionBody() << " if (global_idx == 0u) {\n" - << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" + << " let global_total_seq_length = u32(total_sequence_length_input[0]);\n" + << " let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -152,7 +161,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles, + const Tensor* total_seqlen) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. @@ -188,6 +198,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_past) { program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, @@ -262,9 +275,15 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } + if (use_indirect_dispatch_) { + // Global max total sequence length across batches (from GQA input #6). + // Used in indirect-dispatch mode for the workgroup_idx slicing so that + // batch 0's per-batch length cannot undersize the dispatch grid. + shader.AddInput("total_sequence_length_input", ShaderUsage::None); + } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } @@ -282,6 +301,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), @@ -293,7 +313,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k, const Tensor* total_seqlen) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -303,13 +323,16 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool is_unidirectional = parameters.is_unidirectional_; - FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (use_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -320,10 +343,12 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } if (use_indirect_dispatch) { @@ -332,7 +357,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -343,6 +368,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, {attn_bias_dim1}, + {attn_bias_dim3}, {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); @@ -351,7 +377,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { @@ -364,7 +390,7 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_VARIABLE(input, input), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(output, output)); @@ -379,17 +405,17 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch, const Tensor* head_sink, - uint32_t m_tile) { + uint32_t m_tile, + bool use_seqlen_k) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; bool has_head_sink = head_sink != nullptr; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k}; program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_head_sink) { @@ -399,7 +425,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) + .CacheHint(tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, @@ -415,7 +441,8 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, - const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink) { + const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink, + const Tensor* total_seqlen) { constexpr uint32_t tile_size = 64; // Create present_key and present_value tensors if they are nullptr. @@ -437,7 +464,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co present_value = &internal_present_value; } - const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Read seqlens_k per batch_idx in the shader whenever seqlens_k is supplied. + // This covers both graph-capture (total_sequence_length_ is 0 on the host) and + // right-padded batches (batch_size > 1 with distinct per-batch totals), and lets + // batch=1 share the same path. When seqlens_k is null, kernels fall back to + // uniforms.total_sequence_length. + const bool use_seqlen_k = seqlen_k != nullptr; // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; @@ -453,8 +485,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Prepare indirect dispatch buffer for split-reduce path with static KV cache. // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && + // Static KV cache (past_present_share_buffer_) is guaranteed by GQA's + // ORT_ENFORCE when graph capture is enabled. + const bool use_indirect_dispatch = seqlen_k != nullptr && + total_seqlen != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions @@ -492,10 +526,11 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size, num_q_tiles)); + indirect_buffer_ptr, tile_size, num_q_tiles, + total_seqlen)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles, total_seqlen)); } // Extract present_sequence_length directly from present_key tensor shape @@ -555,10 +590,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) @@ -572,7 +609,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {alpha}, {num_seq_tile}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {attn_bias_dim3}}); return context.RunProgram(program); } @@ -596,27 +634,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co &metadata, seqlen_k, parameters, indirect_buffer_ptr, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length, m_tile)); + present_sequence_length, m_tile, use_seqlen_k, total_seqlen)); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - head_sink, m_tile)); + num_present_sequence_length_tile, tile_size, + head_sink, m_tile, use_seqlen_k)); return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { - const bool kv_empty = parameters.kv_sequence_length_ == 0; - // FlashAttention here does not implement right-padded per-batch prefill, so the - // first disjunction restricts it to inputs where padding cannot occur: - // - batch_size_ == 1: single sequence, no padding possible. - // - seqlen_k == nullptr: no per-batch lengths, padding inexpressible. - // - kv_empty (shared-KV layer): FA is mandatory; that path takes a different shader. - // The remaining conjuncts exclude packed-QKV (handled by a separate rotary kernel), - // mismatched head/value sizes, and head_size alignments unsupported by the kernel. - return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && - !parameters.is_packed_qkv_ && +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } @@ -631,7 +660,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size, uint32_t num_q_tiles) { + uint32_t tile_size, uint32_t num_q_tiles, + const Tensor* total_seqlen) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -669,6 +699,9 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {cos_cache, ProgramTensorMetadataDependency::Rank, components}, {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }); + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, {present_key, ProgramTensorMetadataDependency::None, components}, {present_value, ProgramTensorMetadataDependency::None, components}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 218baf926173f..85ba61c1d20b5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -125,7 +125,8 @@ class FlashAttentionProgram final : public Program { {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}, {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32}, - {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}); + {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}, + {"attn_bias_dim3", ProgramUniformVariableDataType::Uint32}); private: bool has_attention_bias_; @@ -148,8 +149,9 @@ class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -195,17 +199,18 @@ class FlashAttentionDecodeVxReduceProgram final : public Program u32 { - return u32(seqlens_k[0]) + 1u; +// When seqlens_k is provided, total_sequence_length is read per batch from the GPU buffer. +fn get_total_sequence_length(batch_idx: u32) -> u32 { + return u32(seqlens_k[batch_idx]) + 1u; } #else -// When graph capture is disabled, total_sequence_length comes from uniforms -fn get_total_sequence_length() -> u32 { +// Without seqlens_k, total_sequence_length comes from uniforms (max across batches). +fn get_total_sequence_length(batch_idx: u32) -> u32 { return uniforms.total_sequence_length; } #endif @@ -65,20 +65,18 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ var qk_scores : array; -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); v_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); @@ -95,15 +93,19 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { } #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> q_element_t { - if (k_idx_global >= get_total_sequence_length()) { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> q_element_t { + if (k_idx_global >= total_seq) { return q_element_t(0); } let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); - return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + get_total_sequence_length())]); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; + return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + stride_total_seq - 1u)]); } #endif @@ -111,24 +113,24 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he // For max performance max_k_step should be the same as sg_size, however we might run out of registers // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); k_tile[slot][idx % head_size_vec] = val; } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); v_tile[slot][idx % head_size_vec] = val; } } @@ -160,18 +162,22 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { #endif #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (k_idx_global >= get_total_sequence_length()) { + if (k_idx_global >= total_seq) { return vec4(0); } // Handle broadcasting: if dimension size is 1, use index 0 let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; let offset = offset_base + k_idx_global; - let offset_max = offset_base + get_total_sequence_length(); + let offset_max = offset_base + stride_total_seq - 1u; let c1 = q_element_t(attention_bias[min(offset, offset_max)]); let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]); let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]); @@ -179,7 +185,7 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he return vec4(c1, c2, c3, c4); } #else -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { return vec4(0); } #endif @@ -226,11 +232,14 @@ $MAIN { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; #endif - let total_sequence_length = get_total_sequence_length(); + let total_sequence_length = get_total_sequence_length(batch_idx); #if is_unidirectional // If attention is unidirectional, set the loop bound to enforce causal masking. - let past_sequence_length = total_sequence_length - uniforms.new_sequence_length; + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); let max_causal_len_for_workgroup = past_sequence_length + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup); @@ -244,8 +253,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx); - loadv(k_start, batch_head_idx, local_idx); + loadk(k_start, batch_head_idx, local_idx, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, total_sequence_length); workgroupBarrier(); for (var k = 0u; k < max_k_step; k++) { @@ -254,7 +263,7 @@ $MAIN { score += dot(q_tile[i], k_tile[k][i]); } #if has_attention_bias - score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx); + score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx, total_sequence_length); #endif qk_scores[k] = select(min_value, score, k_start + k < seq_causal_length); } @@ -302,8 +311,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx, capped_sg_size); - loadv(k_start, batch_head_idx, local_idx, capped_sg_size); + loadk(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); workgroupBarrier(); // Compute QKt @@ -361,11 +370,11 @@ $MAIN { qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx); - qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx); + qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx, total_sequence_length); + qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx, total_sequence_length); if (sg_size > 8) { - qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx); - qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx); + qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx, total_sequence_length); + qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx, total_sequence_length); } // Neuter qk values where K is out of bounds. diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template index 524a18ca43245..778e07fbf63ff 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -10,6 +10,7 @@ #param tile_size #param tile_size_k_vec #param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -34,18 +35,22 @@ var tile_max: array; var tile_sum: array; #if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + - bias_head_idx * uniforms.new_sequence_length * total_seq_length + - q_idx * total_seq_length + + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + + q_idx * stride_total_seq + k_idx; return attention_bias[offset]; } #else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { return q_element_t(0); } @@ -54,12 +59,14 @@ var tile_sum: array; $MAIN { let local_row = u32(local_idx / tile_size_k_vec); let local_col = local_idx % tile_size_k_vec; + // total_sequence_length used for workgroup_idx slicing must match the host-side dispatch + // grid, i.e. the global maximum across batches. Per-batch total is derived separately below. #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + let global_total_sequence_length = u32(total_sequence_length_input[0]); #else - let total_sequence_length = uniforms.total_sequence_length; + let global_total_sequence_length = uniforms.total_sequence_length; #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_total_seq_length_tile = (global_total_sequence_length + tile_size - 1) / tile_size; let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; @@ -71,9 +78,28 @@ $MAIN { if (batch_idx >= uniforms.batch_size) { return; } + // Per-batch total_sequence_length used for K/V bounds, causal mask, and softmax range. + #if use_seqlen_k + let total_sequence_length = u32(seqlens_k[batch_idx]) + 1u; + #else + let total_sequence_length = global_total_sequence_length; + #endif let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + // If this workgroup's tile lies entirely beyond this batch's per-batch total_sequence_length, + // write neutral metadata so VxReduce contributes nothing for these tiles, then exit early. + if (total_seq_offset >= total_sequence_length) { + if (local_idx == 0u) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx_local = q_base + m; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx_local) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(-3.4028234663852886e+38f, 0.0f)); + } + } + return; + } + // ============================================================ // Phase 1: QK^T computation // ============================================================ @@ -109,6 +135,12 @@ $MAIN { } // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask +#if is_unidirectional + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); +#endif for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { let q_idx = q_base + m; if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { @@ -117,9 +149,9 @@ $MAIN { sum += inner_qk_values[m][local_idx][i]; } - sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx); #if is_unidirectional - if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + if (total_seq_offset + local_idx > past_sequence_length + q_idx) { sum = q_element_t(-65504.0f); } #endif diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index a3ce0b68cb659..628ad835a9d4c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -5,7 +5,7 @@ #param m_tile #param seq_tile_size #param tile_size -#param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -32,8 +32,12 @@ $MAIN { } let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + // Per-batch total_sequence_length: short batches contributed neutral metadata + // (-inf, 0) for tiles beyond their per-batch total, so reading only this batch's + // tiles ensures softmax rescaling is not skewed by garbage tiles. + #if use_seqlen_k + let batch_idx_for_seqlen = batch_head_idx / uniforms.num_heads; + let total_sequence_length = u32(seqlens_k[batch_idx_for_seqlen]) + 1u; let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; #else let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 36d688c9723fd..24ace3487a4c5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -327,6 +327,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& past_value->DataRaw() == present_value->DataRaw(); ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length."); + ORT_ENFORCE(!context.IsGraphCaptureEnabled() || parameters.past_present_share_buffer_, + "Graph capture requires past/present KV cache to share the same buffer (static KV cache)."); Tensor qSplit; Tensor kSplit; @@ -350,7 +352,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context); } if (kv_empty) { @@ -381,7 +383,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled // query points to packed QKV, K and V are nullptr since they're not needed return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink); + present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink, + total_seqlen_tensor); } // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); @@ -472,7 +475,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink); + present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink, + total_seqlen_tensor); } // Non-flash attention path does not support kv_sequence_length==0 (shared KV layers). diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 97c610fb90024..e3d92c036d2c1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -43,7 +43,8 @@ $MAIN { #if prepare_indirect_dispatch if (global_idx == 0u) { - let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + let global_total_seq_length = u32(total_sequence_length_input[0]); + let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index c6f6e0affc163..28c41b5bf5ed4 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2613,7 +2613,8 @@ static std::vector RunGQAPackedQKVRotaryPrefill( int head_size, const std::vector& seqlens_k_data, const std::vector& packed_qkv_data, - GqaTargetEp target_ep = GqaTargetEp::kCpu) { + GqaTargetEp target_ep = GqaTargetEp::kCpu, + bool smooth_softmax = false) { const int hidden_size = num_heads * head_size; const int kv_hidden_size = kv_num_heads * head_size; const int qkv_hidden = hidden_size + 2 * kv_hidden_size; @@ -2632,6 +2633,12 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); tester.AddAttribute("do_rotary", static_cast(1)); + if (smooth_softmax) { + // smooth_softmax disqualifies the WebGPU FlashAttention path via the outer + // gating in GroupQueryAttention::ComputeInternal, routing this case through + // ApplyAttention instead. + tester.AddAttribute("smooth_softmax", static_cast(1)); + } // Packed QKV: pass through `query` input, leave key/value as optional edges. if (use_fp16) { @@ -2722,7 +2729,9 @@ static std::vector RunGQAPackedQKVRotaryPrefill( // output matches its single-prompt reference. Both reference and batched runs // go through the same EP, so this validates per-batch consistency within each // EP rather than cross-EP equivalence. -static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { +static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, + const std::vector& real_lens = {4, 2, 6}, + bool smooth_softmax = false) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2731,10 +2740,10 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { constexpr int kv_hidden_size = kv_num_heads * head_size; constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; - // Real prompt lengths per batch; max = sequence_length (right-padding extends + // Per-batch real prompt lengths; max = sequence_length (right-padding extends // shorter batches up to this length). The bug only manifests when at least // one batch is shorter than sequence_length. - const std::vector real_lens = {4, 2, 6}; + ASSERT_EQ(static_cast(real_lens.size()), batch_size); const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); std::vector packed_batched; @@ -2755,7 +2764,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { /*batch_size=*/1, /*sequence_length=*/real_len, num_heads, kv_num_heads, head_size, /*seqlens_k_data=*/{static_cast(real_len - 1)}, - packed_single, target_ep); + packed_single, target_ep, smooth_softmax); } // Now run all batches together with right-padding. @@ -2765,7 +2774,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { } const auto batched_output = RunGQAPackedQKVRotaryPrefill( batch_size, sequence_length, num_heads, kv_num_heads, head_size, - seqlens_k_data, packed_batched, target_ep); + seqlens_k_data, packed_batched, target_ep, smooth_softmax); // Guard the regression deterministically: every element of the batched output // (including padding rows) must be finite. The CPU root cause is uninitialized @@ -2818,5 +2827,50 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu); } +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with per-batch +// real_lens whose max crosses the prefill threshold (sequence_length >= 32) so +// the WebGPU EP picks FlashAttentionProgram (single-kernel prefill path with +// subgroup shuffles) instead of the split-reduce decode path. This exercises +// the prefill flash-attention kernel under right-padded batches with do_rotary, +// which is the path used by Phi-style models during batched prefill. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // sequence_length = max(real_lens) = 33 > 32 -> FlashAttentionProgram path. + // Mixed shorter batches (12, 20) ensure right-padding is non-trivial. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 33}); +} + +// Stress the FlashAttention prefill path with a per-batch spread that exceeds +// the indirect-dispatch tile size (64). batch 0 has the SHORTEST real length; +// batch 2 has the LONGEST. This is the data pattern that would surface the +// indirect-dispatch undersizing bug when graph capture is enabled (where the +// dispatch grid is sized from a GPU buffer rather than the host scalar). +// OpTester does not toggle graph capture, so this test exercises the new +// total_sequence_length_input shader plumbing on the non-graph-capture path; +// the graph-capture path is covered end-to-end by phi4-graph-prune verification. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // spread = 96 - 20 = 76 > tile_size(64), batch 0 is not the max. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 96}); +} + +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with +// smooth_softmax=1 so the WebGPU EP bypasses CanApplyFlashAttention and routes +// through ApplyAttention (non-flash path). Covers right-padded batched prefill +// on the non-flash attention path (used by e.g. Phi-4 attention variants). +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {4, 2, 6}, /*smooth_softmax=*/true); +} + } // namespace test } // namespace onnxruntime From a203dfafc94b5446a59ddd67e92e6b4b66b01d7e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jun 2026 23:29:40 -0700 Subject: [PATCH 08/19] [CPU] Add FP32 GEMV decode kernel for GroupQueryAttention (#29216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description PR1 https://github.com/microsoft/onnxruntime/pull/28962 adds flash attention for **prefill**, and removed flash decoding. This PR will add optimized kernel for **single-token decode**, which will be faster than other kernels including flash decoding. This PR builds on the prefill-only flash attention change and additionally introduces a dedicated decode kernel. #### What's included - **Decode (GEMV) kernel** — A dedicated single-token decode kernel (`MlasGQADecodeGQAThreaded`) for `sequence_length == 1`, parallelized over (batch, head) with a two-pass softmax, using GEMV (`acc[8]`-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression. - The FP32 flash gate (`group_query_attention.cc`) is enabled for `total_sequence_length > 1`, routing prefill to the tiled kernel and decode to the GEMV kernel. - The quantized KV-cache path is unchanged (FP32-only scope). #### Results (AMD EPYC 7763, AVX2, 8 threads) - **Decode:** correctness ~1e-8 vs naive; long-context decode ~1.0–1.5x (T = 4097 ~1.3–1.5x). ### Motivation and Context The naive GQA path materializes the full score matrix, which is memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case. ### Testing - Built with `--compile_no_warning_as_error`. - Correctness verified against the naive path for both prefill and decode (max abs diff ~1e-8). - Benchmarked via `benchmark_gqa_cpu_flash.py`. --- docs/contrib_ops/cpu/gqa.md | 79 ++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 51 +- .../cpu/bert/group_query_attention.cc | 10 +- onnxruntime/core/mlas/inc/mlas.h | 14 +- onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 528 +++++++++++++++++- .../test/python/transformers/test_gqa_cpu.py | 64 +++ 6 files changed, 707 insertions(+), 39 deletions(-) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index 8b81fdba8f1a6..ffce29682ca42 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -58,7 +58,7 @@ Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, and online-softmax structure. The quantized path additionally provides a two-phase flash-decoding strategy for single-token decode; the non-quantized FP32 path is limited to prefill (`sequence_length > 1`) and uses the naive path for decode. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, online-softmax, and flash-decoding structure. The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). @@ -213,7 +213,7 @@ The partials buffer is allocated alongside the per-thread scratch in a single al ## Non-Quantized Flash Attention Path -The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, and masking structure. Unlike the quantized path, it is limited to prefill / chunked-prefill (`sequence_length > 1`); single-token decode (`sequence_length == 1`) uses the naive path, which is why there is no flash-decoding variant here. +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, masking, and flash-decoding structure. ### Differences from the Quantized Path @@ -242,7 +242,8 @@ computes a per-q_block bound `ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the main reason the FP32 flash path is faster than naive even at short sequence lengths -(see the benchmark results below). +(see the benchmark results below). Decode (q_block of size 1 at the cache tail) attends to all +KV positions, so the bound equals `total_seqlen` and nothing is skipped. ### Activation Conditions @@ -250,18 +251,52 @@ The non-quantized flash path is selected when ALL of the following hold: - The kernel specialization is `float` (FP16 uses the naive path) - `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) -- `sequence_length > 1` (prefill / chunked-prefill; single-token decode uses the naive path) +- `total_sequence_length > 1` - No softcap - No smooth softmax - No head sink - No output QK capture - `present_key` and `present_value` are provided -Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Sizes, Threading, and Flash Decoding + +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. + +#### Decode uses a dedicated GEMV kernel (`sequence_length == 1`) + +The tiled online-softmax SGEMM kernel (`MlasFlashAttentionGQAThreaded`) is used **only for +prefill** (`sequence_length > 1`), where each KV tile is reused across the `q_block_size` +query rows and tiling delivers real cache-locality and SGEMM packing benefits. + +For single-token decode the query tile has `M = 1`, so every K/V element is streamed +exactly once with no reuse across query rows. Tiling provides **no** cache-locality +benefit, and routing the `1 × T × H` work through `MlasSgemmOperation` pays the SGEMM +B-packing/setup cost on every call — which previously made the flash decode path *slower* +than the naive path (≈0.4–0.6x) for short-to-medium total sequence lengths. -### Block Sizes and Threading +Decode is therefore handled by a dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`), +dispatched whenever `sequence_length == 1` and flash decoding is not active. It +parallelizes over `(batch, head)` and, per head, computes the attention directly with two +matrix-vector products and a two-pass softmax: + +- **QK GEMV** — `scores[t] = scale · dot(q, K[t])` for `t ∈ [0, total_seqlen)`. +- two-pass softmax over `scores` using the dispatched `ReduceMaximumF32Kernel` / + `ComputeSumExpF32Kernel` helpers. +- **SV GEMV** — `out[h] = Σ_t probs[t] · V[t][h]`, then normalize by `1/Σ probs`. + +Both GEMV helpers (`MlasGQADecodeQK`, `MlasGQADecodeSV`) live in the baseline-ISA MLAS +translation unit, so their inner loops use independent accumulator lanes / map-style +updates that vectorize under SSE2 without `-ffast-math`. Decode needs no causal mask (the +single new token is the most recent position and attends to every cached token); only +optional local-window masking and additive attention bias are applied. The kernel streams +K and V exactly once each, so it is memory-bandwidth bound. + +The two-phase flash-decoding path (active when `batch × heads < threads`, KV partitioned +across idle threads) now also uses these GEMV helpers for its per-chunk QK and SV products +instead of `M = 1` SGEMM calls, removing the same packing overhead. -Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. Because this path is prefill-only, it does not include the quantized path's two-phase flash-decoding strategy for single-token decode. ## MLAS Dispatch Paths @@ -513,13 +548,31 @@ offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/outp The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is algorithmic rather than purely from threading. -#### Decode (S = 1, token generation) +#### Latency — Decode (S = 1, token generation) + +For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so +flash decoding KV-partitioning is not active), the workload per `Run` is tiny (a `1 × T × H` +GEMV pair per head) and operator-level latency is dominated by fixed per-`Run` overhead +(session dispatch, KV-cache concatenation), so operator-level measurements on the EPYC dev +box are extremely noisy. The numbers below come from a min-of-many-repeats MLAS-path harness +to suppress that jitter. -Single-token decode (`sequence_length == 1`) is **not** handled by the FP32 flash path; it falls -back to the naive path. Decode produces only a `[1, total_sequence_length]` score row per head, -so there is nothing to tile away, and the extra online-softmax bookkeeping made the flash kernel -slower and noisier in practice. Restricting the flash path to prefill (`sequence_length > 1`) keeps -the consistent prefill win without regressing decode. +| Total Seqlen | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 513 | 0.50 | 0.42 | ~1.0\u20131.2x (noisy) | +| 1025 | 0.78 | 0.69 | ~1.0\u20131.1x (noisy) | +| 2049 | 1.89 | 1.73 | ~1.0\u20131.1x (noisy) | +| 4097 | 6.1 | 4.5 | 1.35\u20131.5x | + +Decode is now handled by the dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`) instead of +the prefill tiling kernel; see *Decode uses a dedicated GEMV kernel* above. Replacing the +per-head `M = 1` `MlasSgemmOperation` QK/SV calls with direct GEMVs removes the SGEMM +B-packing overhead that previously made flash decode noticeably **slower** than naive +(measured ≈0.4\u20130.6x across all lengths before the change). Flash decode is now at parity +for short/medium sequences (where the work is memory-bandwidth bound and overhead-dominated) +and consistently ahead for long contexts (T≥4097, ~1.4\u20131.5x) where the streamed +single-pass KV access wins. Short decode remains overhead-bound rather than algorithm-bound, +so it is not the target of the prefill-oriented causal early-termination optimization. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 413483756ed5c..60fa4f0c4ada1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -1091,16 +1091,49 @@ class GQAAttentionBase { int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); thread_count = std::max(thread_count, 1); - // Per-thread scratch: l + m + scores[q_block_size * kv_block_size] + temp_output[q_block_size * head_size] - const size_t buffer_size_per_thread = - (SafeInt(q_block_size) * 2 + // l + m - SafeInt(q_block_size) * kv_block_size + // scores - SafeInt(q_block_size) * head_size) * // temp_output - sizeof(float); - size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count; + // Flash decoding: for decode (sequence_length==1), partition KV across threads + // to improve parallelism when batch*heads < thread_count. This KV-split is only + // wired into the unified kernel (common_past_seqlen >= 0); the ragged/per-batch + // fallback runs the single-pass decode kernel instead, which needs a larger + // per-thread scratch (scores[total_seqlen] + temp_output[head_size]). Gating on + // common_past_seqlen >= 0 keeps the per-thread buffer sizing below consistent + // with the kernel that actually runs. + const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; + const bool use_flash_decoding = (sequence_length == 1 && + common_past_seqlen >= 0 && + batch_size * num_heads_ < thread_count && + kv_chunk_count > 1); + + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + // Flash decoding: per-thread scratch only needs scores[kv_block_size] + buffer_size_per_thread = SafeInt(kv_block_size) * sizeof(float); + // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats + partials_buffer_bytes = SafeInt(batch_size) * num_heads_ * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else if (sequence_length == 1) { + // Decode (GEMV kernel, no Q/KV tiling): per-thread scratch holds the full + // score row scores[total_seqlen] plus a temp output accumulator[head_size]. + buffer_size_per_thread = + (SafeInt(max_total_seqlen) + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output + sizeof(float); + } + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count + partials_buffer_bytes; auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + // Partials buffer is placed after per-thread scratch + float* partials_ptr = use_flash_decoding + ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + + buffer_size_per_thread * thread_count) + : nullptr; + const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; // If all batch items share the same past_seqlen, use the unified flash kernel. @@ -1133,6 +1166,8 @@ class GQAAttentionBase { args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; MlasFlashAttentionGQA(&args, tp); } else { @@ -1185,6 +1220,8 @@ class GQAAttentionBase { args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; args.attention_bias_broadcast_batch = true; // batch offset handled above args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding + args.kv_chunk_count = 0; MlasFlashAttentionGQA(&args, tp); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 692a5759d1adc..debda282eb4f1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -348,13 +348,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // kernel to avoid materializing the full attention score matrix. Falls back to the // naive path when an unsupported feature is requested (softcap, smooth softmax, // head sink, or QK output). + // + // Prefill (sequence_length > 1) uses the tiled kernel; single-token decode + // (sequence_length == 1 with total_sequence_length > 1) uses the dedicated GEMV + // decode kernel. Both are reached when total_sequence_length > 1. if constexpr (std::is_same_v) { - // Restrict the flash path to prefill / chunked-prefill (query length > 1). Single-token - // decode (sequence_length == 1) has no flash benefit: the naive score matrix is only - // [1, total_sequence_length] per head, so there is nothing to tile away, and the extra - // online-softmax bookkeeping makes it slower in practice. const bool use_flash = !disable_gqa_flash_ && - parameters.sequence_length > 1 && + parameters.total_sequence_length > 1 && softcap_ == 0.0f && !use_smooth_softmax_ && head_sink_data == nullptr && diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 2410fcc83e7cd..cbca2d85a97a4 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2302,9 +2302,9 @@ MlasFlashAttention( // // Adapts the online-softmax tiled algorithm to operate on an FP32 present // K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). -// Supports GQA head grouping, causal masking, local window attention, and -// additive attention bias. Intended for prefill / chunked-prefill -// (sequence_length > 1). +// Supports GQA head grouping, causal masking, local window attention, +// additive attention bias, and an optional flash-decoding split over the KV +// sequence dimension for the single-token decode case. // struct MlasFlashAttentionGQAArgs { int batch_size; @@ -2320,7 +2320,7 @@ struct MlasFlashAttentionGQAArgs { int kv_block_size; // key/value tile size (Bc) float scale; // QK scaling factor int thread_count; // number of partitions / threads - float* buffer; // per-thread scratch + float* buffer; // per-thread scratch (+ optional flash-decoding partials) size_t buffer_size_per_thread; const float* query; // [batch, num_heads, sequence_length, head_size] BNSH @@ -2334,6 +2334,12 @@ struct MlasFlashAttentionGQAArgs { int attention_bias_seqlen_stride; bool attention_bias_broadcast_batch; bool attention_bias_broadcast_head; + + // Flash decoding (sequence_length == 1): partition KV across threads. + // Set flash_decoding_partials != nullptr to enable; otherwise the standard + // per-(batch, head, q_block) partitioning is used. + float* flash_decoding_partials; + int kv_chunk_count; }; /** diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp index 4d0ff65733a44..0f1210ca1e3c5 100644 --- a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -17,11 +17,14 @@ Module Name: an FP32 present K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head grouping (num_heads % kv_num_heads == 0), causal masking, local window - attention, and additive attention bias. Intended for prefill / - chunked-prefill (sequence_length > 1). + attention, additive attention bias, and an optional flash-decoding split + over the KV sequence dimension for single-token decode. - QK^T and S*V use the single-threaded SGEMM primitive MlasSgemmOperation; - the outer parallelism is provided by MlasExecuteThreaded. + For multi-token prefill (sequence_length > 1) QK^T and S*V use the + single-threaded SGEMM primitive MlasSgemmOperation. For single-token decode + (sequence_length == 1, including the flash-decoding KV split) the M == 1 + GEMVs use the local MlasGQADecodeQK / MlasGQADecodeSV helpers to avoid SGEMM + packing overhead. The outer parallelism is provided by MlasExecuteThreaded. --*/ @@ -32,6 +35,71 @@ Module Name: #include "mlasi.h" +// +// Decode (sequence_length == 1) GEMV helpers. +// +// With a single query token the QK^T and S*V products degenerate into +// matrix-vector products. Computing them directly streams the K and V cache +// exactly once and avoids the SGEMM B-packing overhead that otherwise dominates +// the tiny M = 1 GEMMs. These helpers live in the baseline-ISA MLAS translation +// unit, so the inner loops are written with independent accumulator lanes and a +// map-style update so the compiler can vectorize them without -ffast-math +// (which would be required to reassociate a plain scalar float reduction). +// + +// QK^T GEMV: scores[t] = scale * dot(q[0..H), K[t*H .. t*H+H)) for t in [0, n_kv). +static void +MlasGQADecodeQK( + const float* q, + const float* k_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float scale, + float* scores +) +{ + constexpr std::ptrdiff_t kLanes = 8; + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float* krow = k_cache + t * head_size; + float acc[kLanes] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + std::ptrdiff_t h = 0; + for (; h + kLanes <= head_size; h += kLanes) { + for (std::ptrdiff_t j = 0; j < kLanes; ++j) { + acc[j] += q[h + j] * krow[h + j]; + } + } + float sum = ((acc[0] + acc[1]) + (acc[2] + acc[3])) + + ((acc[4] + acc[5]) + (acc[6] + acc[7])); + for (; h < head_size; ++h) { + sum += q[h] * krow[h]; + } + scores[t] = sum * scale; + } +} + +// S*V GEMV (accumulate): out[h] = sum_t probs[t] * V[t*H + h] for h in [0, head_size). +// `out` is overwritten (initialized to zero) before accumulation. +static void +MlasGQADecodeSV( + const float* probs, + const float* v_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float* out +) +{ + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] = 0.0f; + } + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float p = probs[t]; + const float* vrow = v_cache + t * head_size; + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] += p * vrow[h]; + } + } +} + void MlasFlashAttentionGQAThreaded( void* argptr, @@ -294,6 +362,415 @@ MlasFlashAttentionGQAThreaded( } } +// +// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). +// Each task computes attention for one KV chunk and stores (m, l, partial_output) +// into the partials buffer. +// +void +MlasFlashDecodingGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + // Partials layout per entry: [m, l, output[head_size]] + const ptrdiff_t partial_stride = 2 + head_size; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: (batch, head, kv_chunk) + const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) + ptrdiff_t tmp = task_index; + ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; + tmp /= kv_chunk_count; + ptrdiff_t head_idx = tmp % num_heads; + ptrdiff_t batch_idx = tmp / num_heads; + + // Per-thread scratch buffer: just scores[kv_block_size] + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + + // KV block range for this chunk + const ptrdiff_t ir = kv_chunk_idx * kv_block_size; + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1). + // The batch stride is supplied separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMV for this KV chunk (M = 1) + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasGQADecodeQK(q_ptr, k_block, static_cast(row_size_kv), head_size, scale, scores); + + // Step 1b: Apply attention bias if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch stride + // uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset + ir; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + scores[jcol] += bias_row[jcol]; + } + } + + // Step 2: Apply causal mask + const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 + const ptrdiff_t causal_limit = global_q_pos + 1; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Step 3: Compute local softmax statistics (m, l) and exp scores +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); +#endif + + // Pointer to this task's partial in the partials buffer + const ptrdiff_t partial_index = + (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; + float* partial = args->flash_decoding_partials + partial_index * partial_stride; + float* partial_m = partial; + float* partial_l = partial + 1; + float* partial_output = partial + 2; + + if (rowmax == std::numeric_limits::lowest()) { + // Entire chunk is masked: store sentinel + *partial_m = std::numeric_limits::lowest(); + *partial_l = 0.0f; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + *partial_m = rowmax; + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#endif + *partial_l = rowsum; + + // Step 4: S_exp * V_block -> partial_output (M = 1) + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasGQADecodeSV(scores, v_block, static_cast(row_size_kv), head_size, partial_output); + } +} + +// +// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. +// +void +MlasFlashDecodingGQAReduceThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + const ptrdiff_t thread_count = static_cast(args->thread_count); + const ptrdiff_t partial_stride = 2 + head_size; + + // Total reduction tasks: one per (batch, head) + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t head_idx = task_index % num_heads; + ptrdiff_t batch_idx = task_index / num_heads; + + // Pointer to this (batch, head)'s partials: kv_chunk_count entries + const float* partials_base = args->flash_decoding_partials + + task_index * kv_chunk_count * partial_stride; + + // Find global max across all chunks + float global_m = std::numeric_limits::lowest(); + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + float chunk_m = partials_base[c * partial_stride]; + global_m = std::max(global_m, chunk_m); + } + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + + // If all chunks are masked, output zeros + if (global_m == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + // Accumulate rescaled outputs and l values + float global_l = 0.0f; + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + const float* partial = partials_base + c * partial_stride; + float chunk_m = partial[0]; + float chunk_l = partial[1]; + const float* chunk_output = partial + 2; + + if (chunk_l <= 0.0f) { + continue; // masked chunk contributes nothing + } + + float rescale = std::exp(chunk_m - global_m); + global_l += rescale * chunk_l; + + // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). + // Rescale by exp(m_c - global_m) to align all chunks to the same max. + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] += rescale * chunk_output[i]; + } + } + + // output = sum_c(rescale_c * partial_output_c) / global_l + float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] *= inv_l; + } + } +} + +// +// Decode kernel for sequence_length == 1 without KV-split (batch * heads >= +// thread_count). Parallelizes over (batch, head); each task attends the single +// query token to the whole KV cache with a pair of GEMVs and a two-pass softmax. +// Decode needs no causal masking (the single new token is the most recent +// position and attends to every cached token); only optional local-window +// masking and additive bias are applied. +// +void +MlasGQADecodeGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // One task per (batch, head). + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + // Local-window low bound: decode can attend to KV positions [window_start, total_seqlen). + // causal_limit == past_seqlen + 1 == total_seqlen for the single new token. + const ptrdiff_t window_start = + (local_window_size >= 0 && total_seqlen > local_window_size) ? (total_seqlen - local_window_size) : 0; + + // Per-thread scratch: scores[total_seqlen] followed by temp_output[head_size]. + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + float* temp_output = scores + total_seqlen; + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + const ptrdiff_t head_idx = task_index % num_heads; + const ptrdiff_t batch_idx = task_index / num_heads; + + // KV head index for GQA head sharing. + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size]; batch stride supplied + // separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMV -> scores[0..T) + MlasGQADecodeQK(q_ptr, k_cache_head, total_seqlen, head_size, scale, scores); + + // Step 1b: additive attention bias (shape [batch|1, num_heads|1, S=1, T]). + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_matrix_size = + static_cast(args->attention_bias_seqlen_stride); // S == 1 + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset; + for (ptrdiff_t t = 0; t < total_seqlen; ++t) { + scores[t] += bias_row[t]; + } + } + + // Step 2: local-window masking (no causal mask needed for decode). + if (window_start > 0) { + for (ptrdiff_t t = 0; t < window_start; ++t) { + scores[t] = std::numeric_limits::lowest(); + } + } + + // Step 3: softmax over scores[0..T). +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, total_seqlen); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, total_seqlen); +#endif + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(head_size); + + if (rowmax == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#endif + + // Step 4: S_exp * V GEMV -> temp_output, then normalize by 1/l. + MlasGQADecodeSV(scores, v_cache_head, total_seqlen, head_size, temp_output); + + const float inv_l = (rowsum > 0.0f) ? (1.0f / rowsum) : 0.0f; + for (ptrdiff_t h = 0; h < head_size; ++h) { + output_ptr[h] = temp_output[h] * inv_l; + } + } +} + void MLASCALL MlasFlashAttentionGQA( @@ -301,10 +778,41 @@ MlasFlashAttentionGQA( MLAS_THREADPOOL* ThreadPool ) { - MlasExecuteThreaded( - MlasFlashAttentionGQAThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool - ); + if (args->sequence_length == 1) { + // Decode: M = 1, use the GEMV kernels (no SGEMM packing overhead). + if (args->flash_decoding_partials != nullptr) { + // Flash decoding: two-phase approach when KV is partitioned across threads. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingGQAReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + // Single-pass decode parallelized over (batch, head). + MlasExecuteThreaded( + MlasGQADecodeGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } + } else { + // Prefill (sequence_length > 1): tiled online-softmax SGEMM kernel. + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } } + diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 19968db98edd7..c438790ab5950 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -19,6 +19,7 @@ import torch from bert_padding import pad_input, unpad_input from einops import rearrange, repeat +from env_var_helper import scoped_env_var from onnx import TensorProto, helper from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -2645,6 +2646,69 @@ def test_gqa_past(self): qk_output, ) + def test_gqa_decode_flash_vs_naive_parity(self): + # The FP32 flash gate enables the dedicated GEMV decode kernel (and the + # flash-decoding KV-split reduction) for sequence_length == 1 with + # total_sequence_length > 1. Run the same decode configs against the + # reference twice: once with the flash path enabled (default) and once + # with it disabled via ORT_GQA_DISABLE_FLASH_ATTENTION=1 (naive path). + # If both paths match the reference, the decode kernel and KV-split + # reduction are correct -- including the bias and local-window cases. + print("-------- TEST GQA DECODE FLASH VS NAIVE PARITY ---------") + + # FP32 only: the GEMV decode kernel and flash gate are float-only. + torch_type = torch.float32 + numpy_type = numpy.float32 + ort_type = TensorProto.FLOAT + rtol = 1e-3 + atol = 1e-3 + + batches = [1, 3] + # (sequence_length == 1) decode. Include a long KV length so that the + # flash-decoding KV-split path (kv_chunk_count > 1) is exercised. + seqs = [(1, 128), (1, 2048)] + num_h = [(9, 3)] + h_sizes = [64, 128] + + # "0" keeps the flash path enabled; "1" forces the naive path. Reseed per + # phase so both paths are validated against the reference on identical + # inputs, independent of test execution order. + for env_value in ["0", "1"]: + with scoped_env_var("ORT_GQA_DISABLE_FLASH_ATTENTION", env_value): + print(f" flash {'disabled (naive path)' if env_value == '1' else 'enabled'}") + random.seed(69) + torch.manual_seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for has_attn in [False, True]: + config = Config( + b, + s, + s2, + 0, + n, + n2, + h, + False, + has_attn, + False, + QKOutputType.NO_OUTPUT, + ) + all_close = parity_check_gqa_past( + config, + torch_type=torch_type, + numpy_type=numpy_type, + ort_type=ort_type, + local=local, + past_format=Formats.BNSH, + rtol=rtol, + atol=atol, + ) + self.assertTrue(all_close) + def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") batches = [1] From 0271287d570c633c477288364d1b4aaca4ac9778 Mon Sep 17 00:00:00 2001 From: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:46:23 -0700 Subject: [PATCH 09/19] Fix unbounded lifetime on WithOutputTensor in Rust bindings (#29251) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Fix unbounded lifetime on WithOutputTensor in Rust bindings ## Description The `WithOutputTensor<'a, T>` struct had a free lifetime parameter `'a` on its `TryFrom` impl that was unconstrained by any input. Combined with the `Deref` impl (whose `Target = ArrayView<'a, T, IxDyn>` exposed a `Clone`-able view), it was possible for the `ArrayView` to outlive the underlying `OrtOutputTensor` buffer owner. This change restructures `WithOutputTensor` to eliminate the unbounded lifetime: - Removes the `'a` lifetime parameter from `WithOutputTensor`, `OrtOutput`, and `Session::run` - Removes the `Deref` impl (the escape hatch) - Replaces the stored `ArrayView<'a, T>` with a raw pointer + shape - Adds a `view(&self)` method returning `ArrayView<'_, T, IxDyn>` — the view lifetime is now tied to `&self` - Updates all call sites (examples, integration tests) to use `.view()` ## Motivation The C API contract (`onnxruntime_c_api.h`) explicitly bounds the data pointer lifetime to the `OrtValue`: the pointer is only valid until the value is destroyed. The Rust type system must enforce this invariant. Previously it did not — the `ArrayView` could be cloned out and observed after the `OrtValue` was freed. ## API Change ```rust // Before: Deref-based access let output = outputs[0].float_array().unwrap(); let sum: f32 = output.iter().sum(); // After: explicit view() call let output = outputs[0].float_array().unwrap(); let sum: f32 = output.view().iter().sum(); ``` ## Testing Existing integration tests updated to use the new `view()` API. The fix is enforced at compile time by the borrow checker — the previously problematic pattern now produces a lifetime error. Co-authored-by: Sayan Shaw --- rust/onnxruntime/examples/issue22.rs | 2 +- rust/onnxruntime/examples/sample.rs | 5 +- rust/onnxruntime/src/session.rs | 6 +- .../src/tensor/ort_output_tensor.rs | 117 ++++++++++-------- rust/onnxruntime/tests/integration_tests.rs | 6 +- 5 files changed, 80 insertions(+), 56 deletions(-) diff --git a/rust/onnxruntime/examples/issue22.rs b/rust/onnxruntime/examples/issue22.rs index 6c96e899fa774..1fb7fe28ff123 100644 --- a/rust/onnxruntime/examples/issue22.rs +++ b/rust/onnxruntime/examples/issue22.rs @@ -51,5 +51,5 @@ fn main() { let outputs = session.run(inputs).unwrap(); - print!("outputs: {:#?}", outputs[0].float_array().unwrap()); + print!("outputs: {:#?}", outputs[0].float_array().unwrap().view()); } diff --git a/rust/onnxruntime/examples/sample.rs b/rust/onnxruntime/examples/sample.rs index 9af5cf733ccae..b6f351b2082ed 100644 --- a/rust/onnxruntime/examples/sample.rs +++ b/rust/onnxruntime/examples/sample.rs @@ -73,10 +73,11 @@ fn run() -> Result<(), Error> { let outputs = session.run(input_tensor_values)?; let output = outputs[0].float_array().unwrap(); + let view = output.view(); - assert_eq!(output.shape(), output0_shape.as_slice()); + assert_eq!(view.shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, view[[0, i, 0, 0]]); } Ok(()) diff --git a/rust/onnxruntime/src/session.rs b/rust/onnxruntime/src/session.rs index 326426e35982c..d475d1b724111 100644 --- a/rust/onnxruntime/src/session.rs +++ b/rust/onnxruntime/src/session.rs @@ -410,10 +410,10 @@ impl Session { /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus /// used for the input data here. - pub fn run<'input, 'output>( - &'output self, + pub fn run<'input>( + &self, mut input_arrays: impl AsMut<[Box]> + 'input, - ) -> Result>> { + ) -> Result> { let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; diff --git a/rust/onnxruntime/src/tensor/ort_output_tensor.rs b/rust/onnxruntime/src/tensor/ort_output_tensor.rs index 83663c0d303f8..727cae1db0ef4 100644 --- a/rust/onnxruntime/src/tensor/ort_output_tensor.rs +++ b/rust/onnxruntime/src/tensor/ort_output_tensor.rs @@ -71,22 +71,41 @@ impl Drop for OrtOutputTensor { } /// An Output tensor with the ptr and the item that will copy from the ptr. -#[derive(Debug)] -pub struct WithOutputTensor<'a, T> { - #[allow(dead_code)] +/// +/// The view is materialized on each access via [`view()`](Self::view) to ensure the +/// borrowed lifetime is tied to `&self`, preventing the view from outliving the +/// underlying buffer owned by the `OrtOutputTensor`. +pub struct WithOutputTensor { pub(crate) tensor: OrtOutputTensor, - item: ArrayView<'a, T, ndarray::IxDyn>, + data_ptr: *const T, + shape: Vec, } -impl<'a, T> std::ops::Deref for WithOutputTensor<'a, T> { - type Target = ArrayView<'a, T, ndarray::IxDyn>; +impl Debug for WithOutputTensor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WithOutputTensor") + .field("tensor", &self.tensor) + .field("data_ptr", &self.data_ptr) + .field("shape", &self.shape) + .finish() + } +} - fn deref(&self) -> &Self::Target { - &self.item +// SAFETY: The data pointer is derived from OrtOutputTensor which owns the allocation. +// Access is only possible through &self (via view()), so Send/Sync follow from T: Send/Sync. +unsafe impl Send for WithOutputTensor {} +unsafe impl Sync for WithOutputTensor {} + +impl WithOutputTensor { + /// Returns an [`ArrayView`] over the output tensor data. + /// + /// The returned view borrows `self`, so it cannot outlive the tensor owner. + pub fn view(&self) -> ArrayView<'_, T, ndarray::IxDyn> { + unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&self.shape), self.data_ptr) } } } -impl<'a, T> TryFrom for WithOutputTensor<'a, T> +impl TryFrom for WithOutputTensor where T: TypeToTensorElementDataType, { @@ -110,45 +129,45 @@ where status_to_result(status).map_err(OrtError::IsTensor)?; assert_ne!(output_array_ptr, std::ptr::null_mut()); - let array_view = - unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&value.shape), output_array_ptr) }; + let shape = value.shape.clone(); Ok(WithOutputTensor { tensor: value, - item: array_view, + data_ptr: output_array_ptr, + shape, }) } } /// The onnxruntime Run output type. -pub enum OrtOutput<'a> { +pub enum OrtOutput { /// Tensor of f32s - Float(WithOutputTensor<'a, f32>), + Float(WithOutputTensor), /// Tensor of f64s - Double(WithOutputTensor<'a, f64>), + Double(WithOutputTensor), /// Tensor of u8s - UInt8(WithOutputTensor<'a, u8>), + UInt8(WithOutputTensor), /// Tensor of u16s - UInt16(WithOutputTensor<'a, u16>), + UInt16(WithOutputTensor), /// Tensor of u32s - UInt32(WithOutputTensor<'a, u32>), + UInt32(WithOutputTensor), /// Tensor of u64s - UInt64(WithOutputTensor<'a, u64>), + UInt64(WithOutputTensor), /// Tensor of i8s - Int8(WithOutputTensor<'a, i8>), + Int8(WithOutputTensor), /// Tensor of i16s - Int16(WithOutputTensor<'a, i16>), + Int16(WithOutputTensor), /// Tensor of i32s - Int32(WithOutputTensor<'a, i32>), + Int32(WithOutputTensor), /// Tensor of i64s - Int64(WithOutputTensor<'a, i64>), + Int64(WithOutputTensor), /// Tensor of Strings - String(WithOutputTensor<'a, String>), + String(WithOutputTensor), } -impl<'a> OrtOutput<'a> { - /// Return `WithOutputTensor<'a, f32>` which derefs into an `ArrayView`. - pub fn float_array(&self) -> Option<&WithOutputTensor<'a, f32>> { +impl OrtOutput { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn float_array(&self) -> Option<&WithOutputTensor> { if let Self::Float(item) = self { Some(item) } else { @@ -156,8 +175,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, f64>` which derefs into an `ArrayView`. - pub fn double_array(&self) -> Option<&WithOutputTensor<'a, f64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn double_array(&self) -> Option<&WithOutputTensor> { if let Self::Double(item) = self { Some(item) } else { @@ -165,8 +184,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u8>` which derefs into an `ArrayView`. - pub fn uint8_array(&self) -> Option<&WithOutputTensor<'a, u8>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint8_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt8(item) = self { Some(item) } else { @@ -174,8 +193,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u16>` which derefs into an `ArrayView`. - pub fn uint16_array(&self) -> Option<&WithOutputTensor<'a, u16>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint16_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt16(item) = self { Some(item) } else { @@ -183,8 +202,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u32>` which derefs into an `ArrayView`. - pub fn uint32_array(&self) -> Option<&WithOutputTensor<'a, u32>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint32_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt32(item) = self { Some(item) } else { @@ -192,8 +211,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u64>` which derefs into an `ArrayView`. - pub fn uint64_array(&self) -> Option<&WithOutputTensor<'a, u64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint64_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt64(item) = self { Some(item) } else { @@ -201,8 +220,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i8>` which derefs into an `ArrayView`. - pub fn int8_array(&self) -> Option<&WithOutputTensor<'a, i8>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int8_array(&self) -> Option<&WithOutputTensor> { if let Self::Int8(item) = self { Some(item) } else { @@ -210,8 +229,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i16>` which derefs into an `ArrayView`. - pub fn int16_array(&self) -> Option<&WithOutputTensor<'a, i16>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int16_array(&self) -> Option<&WithOutputTensor> { if let Self::Int16(item) = self { Some(item) } else { @@ -219,8 +238,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i32>` which derefs into an `ArrayView`. - pub fn int32_array(&self) -> Option<&WithOutputTensor<'a, i32>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int32_array(&self) -> Option<&WithOutputTensor> { if let Self::Int32(item) = self { Some(item) } else { @@ -228,8 +247,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i64>` which derefs into an `ArrayView`. - pub fn int64_array(&self) -> Option<&WithOutputTensor<'a, i64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int64_array(&self) -> Option<&WithOutputTensor> { if let Self::Int64(item) = self { Some(item) } else { @@ -237,8 +256,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, String>` which derefs into an `ArrayView`. - pub fn string_array(&self) -> Option<&WithOutputTensor<'a, String>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn string_array(&self) -> Option<&WithOutputTensor> { if let Self::String(item) = self { Some(item) } else { @@ -247,10 +266,10 @@ impl<'a> OrtOutput<'a> { } } -impl<'a> TryFrom for OrtOutput<'a> { +impl TryFrom for OrtOutput { type Error = OrtError; - fn try_from(value: OrtOutputTensor) -> Result> { + fn try_from(value: OrtOutputTensor) -> Result { unsafe { let mut shape_info = std::ptr::null_mut(); diff --git a/rust/onnxruntime/tests/integration_tests.rs b/rust/onnxruntime/tests/integration_tests.rs index 7843fe269e5e4..1c096400eccf7 100644 --- a/rust/onnxruntime/tests/integration_tests.rs +++ b/rust/onnxruntime/tests/integration_tests.rs @@ -112,6 +112,7 @@ mod download { // and iterate on resulting probabilities, creating an index to later access labels. let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -209,6 +210,7 @@ mod download { let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -301,6 +303,7 @@ mod download { let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -398,6 +401,7 @@ mod download { let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -515,7 +519,7 @@ mod download { let output = outputs[0].float_array().unwrap(); // The image should have doubled in size - assert_eq!(output.shape(), [1, 448, 448, 3]); + assert_eq!(output.view().shape(), [1, 448, 448, 3]); } } From 5f49a372d6aa990b23d516e2ff13dde5849f1895 Mon Sep 17 00:00:00 2001 From: Sanaa Hamel Date: Fri, 26 Jun 2026 12:58:09 -0400 Subject: [PATCH 10/19] fix(ci): incorrect identity for azcopy (#29274) ### Description Use new GitHub CI identity for azcopy. ### Motivation and Context GitHub CI pools have been assigned a new identity. --- .github/workflows/windows_cuda.yml | 4 ++-- .github/workflows/windows_cuda_plugin.yml | 4 ++-- .github/workflows/windows_gpu_doc_gen.yml | 2 +- .github/workflows/windows_openvino.yml | 2 +- .github/workflows/windows_qnn_x64.yml | 2 +- .github/workflows/windows_tensorrt.yml | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index c6346d9c4e932..9776c22fedabb 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -148,7 +148,7 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows GPU CUDA CI Pipeline Test Job @@ -260,4 +260,4 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e diff --git a/.github/workflows/windows_cuda_plugin.yml b/.github/workflows/windows_cuda_plugin.yml index 538c1d783cd68..c2d84d7be482a 100644 --- a/.github/workflows/windows_cuda_plugin.yml +++ b/.github/workflows/windows_cuda_plugin.yml @@ -118,7 +118,7 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows CUDA Plugin EP Test @@ -214,4 +214,4 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e diff --git a/.github/workflows/windows_gpu_doc_gen.yml b/.github/workflows/windows_gpu_doc_gen.yml index 5e50a970875fc..b41de5542b0c6 100644 --- a/.github/workflows/windows_gpu_doc_gen.yml +++ b/.github/workflows/windows_gpu_doc_gen.yml @@ -44,7 +44,7 @@ jobs: setVcvars: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e runs-on: [ "self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index 52581c7d0a5f5..d87a4919a86fd 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -26,7 +26,7 @@ jobs: timeout-minutes: 240 env: AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e OnnxRuntimeBuildDirectory: ${{ github.workspace }} DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index 35c620ca6f650..5560de89040de 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -29,7 +29,7 @@ jobs: QnnLibKind: [shared_lib, static_lib] env: AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index d5710795942d1..3ad0076de6d52 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -154,7 +154,7 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows GPU TensorRT CI Pipeline Test Job @@ -265,4 +265,4 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e From 126ea8d3810391adbf81579200c12f538f71396e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Jun 2026 14:52:31 -0700 Subject: [PATCH 11/19] [CUDA Plugin EP] Expose kernel sync stream for scratch allocation (#29244) ## Description This PR adds a kernel-context C API accessor for the framework `OrtSyncStream*` and uses it in the CUDA plugin EP so scratch allocations can be tagged with the actual compute stream selected for the kernel. It is stacked on #29221 and turns the previously documented concurrent multi-stream limitation into a gated capability: older runtimes keep the conservative fallback, while runtimes with the new API can safely advertise concurrent runs when EP-level unified stream mode is not forced. ## Summary of Changes ### Public API and Adapters | File | Change | |------|--------| | `include/onnxruntime/core/session/onnxruntime_c_api.h` | Adds `KernelContext_GetSyncStream` to expose the borrowed framework stream for stream-aware allocation and synchronization bookkeeping. | | `onnxruntime/core/session/custom_ops.cc` | Implements the API by retrieving the kernel's `OpKernelContext::GetComputeStream()` inside ORT core. | | `onnxruntime/core/session/ort_apis.h` and `onnxruntime/core/session/onnxruntime_c_api.cc` | Declares and wires the new API entry. | | `include/onnxruntime/core/session/onnxruntime_cxx_api.h` and `include/onnxruntime/core/session/onnxruntime_cxx_inline.h` | Adds the C++ `Ort::KernelContext::GetSyncStream()` wrapper. | | `include/onnxruntime/ep/adapter/op_kernel.h` | Adds a version-gated EP adapter accessor so plugins can use the API when available and fall back safely otherwise. | ### CUDA Plugin EP - Tracks the framework stream corresponding to both raw CUDA stream handles and `OrtStreamAdapter` stream arguments. - Passes the framework stream to scratch allocation so arena chunks are stream-tagged instead of using a null stream tag. - Re-enables concurrent run support only when `KernelContext_GetSyncStream` is available and EP-level unified stream mode is not forced. ### Tests and Docs - Extends the shared-lib custom-op test helper to exercise `Ort::KernelContext::GetSyncStream()`. - Updates CUDA plugin EP docs to describe stream-tagged scratch allocation, compatibility fallback, and the new API audit entry. ## Why a C API is needed The implementation of `KernelContext_GetSyncStream` is intentionally small, but the API boundary is the important part. ORT core can safely cast `OrtKernelContext*` back to `onnxruntime::OpKernelContext*` because it owns both the opaque C handle and the private C++ implementation. A plugin kernel should not perform that cast directly: it would make the plugin depend on ORT-core private C++ layout, vtables, and exact build compatibility. The new API keeps that private cast inside ORT core and gives plugin kernels a stable ABI entry point: ```text plugin kernel -> opaque OrtKernelContext* -> OrtApi::KernelContext_GetSyncStream -> ORT core retrieves the actual framework stream ``` This also lets the plugin use runtime version gating. When loaded by an older ORT runtime that does not expose the API, the adapter returns null, scratch allocation uses the conservative fallback, and concurrent runs are not advertised. ## Testing - `lintrunner -a` - `ninja -C build/cu130_plugin/Debug onnxruntime_providers_cuda_plugin` - `ninja -C build/cu130_plugin/Debug onnxruntime_shared_lib_test` - `cd build/cu130_plugin/Debug && ./onnxruntime_shared_lib_test --gtest_filter=CApiTest.custom_op_handler --gtest_color=no` - VS Code diagnostics on touched C++ and header files ## Checklist - [x] Tests added/updated - [x] Documentation updated - [x] Backward compatibility guarded by runtime API-version checks - [ ] CI passes --- .github/workflows/linux_cuda_plugin_ci.yml | 14 +- .../arena_allocator_migration_design.md | 14 +- .../cuda_graph_for_cuda_plugin.md | 16 +- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 13 +- .../core/session/onnxruntime_c_api.h | 21 +++ .../core/session/onnxruntime_cxx_api.h | 1 + .../core/session/onnxruntime_cxx_inline.h | 6 + include/onnxruntime/ep/adapter/op_kernel.h | 8 + .../core/providers/cuda/plugin/cuda_ep.cc | 14 +- .../cuda/plugin/cuda_kernel_adapter.h | 146 ++++++++++++++---- onnxruntime/core/session/custom_ops.cc | 9 ++ onnxruntime/core/session/onnxruntime_c_api.cc | 2 + onnxruntime/core/session/ort_apis.h | 1 + .../cuda_plugin_user_stream_graph_test.cc | 133 ++++++++++++++++ .../test/shared_lib/custom_op_utils.cc | 4 + 15 files changed, 337 insertions(+), 65 deletions(-) diff --git a/.github/workflows/linux_cuda_plugin_ci.yml b/.github/workflows/linux_cuda_plugin_ci.yml index e88c6beff5280..2369af53621b2 100644 --- a/.github/workflows/linux_cuda_plugin_ci.yml +++ b/.github/workflows/linux_cuda_plugin_ci.yml @@ -144,10 +144,14 @@ jobs: # --- Run the CUDA plugin EP C++ GoogleTest binary --- # onnxruntime_provider_test is built into the artifact and links the plugin tests - # (gated by ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP). The user-stream + CUDA graph test - # registers the plugin .so via GetSharedLibraryFileName("onnxruntime_providers_cuda_plugin"), - # which returns the platform-specific filename without a directory component. Run from - # /build/Release/Release so that filename resolves to the plugin .so built there. + # (gated by ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP). These tests register the plugin .so via + # GetSharedLibraryFileName("onnxruntime_providers_cuda_plugin"), which returns the + # platform-specific filename without a directory component. Run from /build/Release/Release + # so that filename resolves to the plugin .so built there. + # The filter covers every CUDA plugin EP suite linked into this binary: + # CudaPlugin* -> CudaPluginUserStreamGraphTest, CudaPluginArenaTest, + # CudaPluginPartitioningTest, CudaPluginProfilingTest + # CudaResourcePartitioning* -> CudaResourcePartitioningTest - name: Run CUDA Plugin EP C++ Tests run: | docker run --rm --gpus all \ @@ -163,5 +167,5 @@ jobs: cd /build/Release/Release ls -la onnxruntime_provider_test libonnxruntime_providers_cuda_plugin.so - ./onnxruntime_provider_test --gtest_filter='CudaPluginUserStreamGraphTest.*' + ./onnxruntime_provider_test --gtest_filter='CudaPlugin*:CudaResourcePartitioning*' " diff --git a/docs/cuda_plugin_ep/arena_allocator_migration_design.md b/docs/cuda_plugin_ep/arena_allocator_migration_design.md index f082b444e10b0..d4ff21e713f85 100644 --- a/docs/cuda_plugin_ep/arena_allocator_migration_design.md +++ b/docs/cuda_plugin_ep/arena_allocator_migration_design.md @@ -62,16 +62,16 @@ if (!factory.arena_allocator_) { **Stream-aware allocation.** `ArenaImpl::AllocOnStream(size, stream)` tracks which chunks are assigned to which stream. `ResetChunksUsingStream(stream_impl)` is called from `OrtSyncStreamImpl::OnSessionRunEnd` to release chunk-to-stream assignments when a session run completes. -**Kernel-side consumption of the arena.** Migrated CUDA kernels obtain scratch/workspace memory from this arena through `CudaKernel::GetScratchBuffer`, which calls `Info().GetAllocator(OrtMemTypeDefault)`. Inside the plugin build that allocator is exposed to internal code as an `IAllocatorWrappingOrtAllocator` (`include/onnxruntime/ep/adapter/allocator.h`), which implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` (ORT ≥ 1.23), falling back to plain `Alloc` otherwise. The plugin `GetScratchBuffer` deliberately passes a **null stream** to the arena rather than forwarding the kernel's compute stream. A plugin kernel only has the raw `cudaStream_t` (via `KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena persists in each chunk (`chunk->stream`) and later dereferences through the EP stream API (`SyncStream_GetImpl`/`SyncStream_GetSyncId`). Synthesizing a temporary framework `Stream` wrapper over the raw handle would be unsafe: it would dangle once `GetScratchBuffer` returns while the arena still holds the pointer, and it would be type-confused (a framework `Stream*` reinterpreted as an `OrtSyncStream*` that ORT never created for this stream). With a null stream the arena tracks scratch chunks as freely reusable (the same semantics as a plain non-stream-aware BFC arena). This is still what keeps scratch allocations served from already-reserved chunks during CUDA graph capture — capture stability comes from chunk reuse, not from stream tagging — and it is safe for the CUDA graph path, which runs on a single unified stream. +**Kernel-side consumption of the arena.** Migrated CUDA kernels obtain scratch/workspace memory from this arena through `CudaKernel::GetScratchBuffer`, which calls `Info().GetAllocator(OrtMemTypeDefault)`. Inside the plugin build that allocator is exposed to internal code as an `IAllocatorWrappingOrtAllocator` (`include/onnxruntime/ep/adapter/allocator.h`), which implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` (ORT >= 1.23), falling back to plain `Alloc` otherwise. `GetScratchBuffer` uses the framework `OrtSyncStream*` exposed through `KernelContext_GetSyncStream` to stream-tag scratch chunks, while kernels continue to use the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream` for launches and library handles. This keeps allocation bookkeeping on the same framework stream wrapper that the arena stores in `chunk->stream` and later queries through the EP stream API (`SyncStream_GetImpl`/`SyncStream_GetSyncId`). If the negotiated ORT API version does not include `KernelContext_GetSyncStream`, the adapter falls back to a null stream tag and the EP does not advertise concurrent run support. -#### Scratch buffer stream tagging — limitation and future work +#### Scratch buffer stream tagging -A common review question is: *"Passing a null stream to the scratch allocator looks wrong — won't it cause a synchronization issue? Shouldn't the scratch buffer use the same stream as the kernel?"* The short answer is that, for the path this code targets, it is correct and safe. The longer answer clarifies what the `stream` argument actually does and why forwarding the real stream is not currently possible. +A common review question is: *"Shouldn't the scratch buffer use the same stream as the kernel?"* The short answer is yes for concurrent multi-stream runs, but the allocator must receive the framework stream wrapper, not the raw CUDA handle. -- **The `stream` argument is bookkeeping, not execution.** The stream passed to a stream-aware arena's `AllocOnStream()` is only metadata the arena uses to decide whether a *freed* chunk may be reused on a *different* stream without an intervening synchronization. It does **not** change where the kernel runs: the returned buffer is always consumed by the kernel on its real compute stream. So a null tag does not move work onto the default stream or skip any required sync — it only relaxes cross-stream chunk reuse. -- **Why null is safe here.** The scratch routing targets serialized runs and the CUDA graph path, which runs on a single **unified stream** when graph capture and a user compute stream are combined. On a single stream, alloc -> use -> free -> reuse are implicitly ordered by the stream itself, so there is never a second stream that could reuse a chunk while the first is still using it. A null-tagged ("freely reusable") chunk behaves exactly like a plain non-stream-aware BFC arena chunk, which is the correct behavior for one stream. Because null-tagged chunks are not safe for overlapping runs on different CUDA streams, the CUDA plugin EP does not advertise concurrent `Session::Run()` support until scratch chunks can be properly stream-tagged. -- **Why we cannot forward the real stream today (C-API limitation).** The stream-aware arena needs the framework `OrtSyncStream*` (`struct OrtSyncStream : public onnxruntime::Stream` in `core/framework/plugin_ep_stream.h`) — the ORT-core wrapper it stored in `chunk->stream`. A plugin kernel only has the raw `cudaStream_t`. `CudaSyncStream::FromCudaStream()` can recover the plugin-side `CudaSyncStream` (an `OrtSyncStreamImpl`), but that is a *different* object from the ORT-core `OrtSyncStream*` the arena expects; passing it (or a stack-allocated shim over the raw handle) would be both dangling and type-confused. -- **Future work.** To properly stream-tag scratch chunks — which only becomes necessary if this path is extended to support concurrent multi-stream runs sharing one arena — ORT needs new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels at dispatch time (e.g. via `KernelContext`). Until then, the null-stream tag is the correct and intentional choice. The matching code comment lives in `CudaKernel::GetScratchBuffer` (`onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h`). +- **The `stream` argument is bookkeeping, not execution.** The stream passed to a stream-aware arena's `AllocOnStream()` is only metadata the arena uses to decide whether a *freed* chunk may be reused on a *different* stream without an intervening synchronization. It does **not** change where the kernel runs: the returned buffer is consumed by the kernel on its raw CUDA compute stream. +- **Raw CUDA stream and framework stream are different objects.** `KernelContext_GetGPUComputeStream` returns the raw `cudaStream_t` used for CUDA calls. The stream-aware arena needs the framework `OrtSyncStream*` (`struct OrtSyncStream : public onnxruntime::Stream` in `core/framework/plugin_ep_stream.h`) because that stable wrapper is what it persists in each chunk. `CudaSyncStream::FromCudaStream()` can recover the plugin-side `CudaSyncStream` (`OrtSyncStreamImpl`), but that is not the ORT-core `OrtSyncStream*` the arena expects. +- **How the plugin bridges them.** `KernelContext_GetSyncStream` exposes the framework stream for the current kernel dispatch. The CUDA plugin adapter records the mapping from raw `cudaStream_t` to framework stream when migrated kernels call `GetComputeStream(ctx)`, and `GetScratchBuffer` uses the framework stream for `AllocOnStream`. This preserves the existing migrated-kernel pattern while making scratch chunks safe for cross-stream reuse decisions. +- **Compatibility fallback.** When the negotiated ORT API version does not include `KernelContext_GetSyncStream`, scratch allocations use a null stream tag. A null tag is correct for serialized runs and single-unified-stream CUDA graph capture, but it is not safe for overlapping runs on different CUDA streams, so the plugin EP only advertises concurrent `Session::Run()` when `KernelContext_GetSyncStream` is available. diff --git a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md index 1cac9464430dc..9035ce91bb3bb 100644 --- a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md +++ b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md @@ -92,8 +92,9 @@ A natural question when reading `GetPerThreadContext()` is why `use_external_str | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Added `InitHandlesWithExternalStream()`, updated destructor for `owns_stream_` | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Added `InitHandlesWithExternalStream()` declaration, `owns_stream_` member | | `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture`; removed the validation that rejected `user_compute_stream` + `enable_cuda_graph` (the combination is now supported) | -| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | `CudaKernel::GetScratchBuffer` now allocates through `Info().GetAllocator()` (the EP arena) with a null stream, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call, so scratch allocations are served from already-reserved arena chunks during capture | -| `include/onnxruntime/ep/adapter/allocator.h` | Implemented `IAllocatorWrappingOrtAllocator::IsStreamAware`/`AllocOnStream` (previously `ORT_NOT_IMPLEMENTED`) so plugin adapters can forward stream-aware allocations when a framework stream is available; `GetScratchBuffer` still passes a null stream until plugin kernels can receive a stable framework `OrtSyncStream*` | +| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | `CudaKernel::GetScratchBuffer` now allocates through `Info().GetAllocator()` (the EP arena) and stream-tags scratch chunks with the framework stream exposed by `KernelContext_GetSyncStream`, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call | +| `include/onnxruntime/ep/adapter/allocator.h` | Implemented `IAllocatorWrappingOrtAllocator::IsStreamAware`/`AllocOnStream` (previously `ORT_NOT_IMPLEMENTED`) so plugin adapters can forward stream-aware allocations when a framework stream is available | +| `include/onnxruntime/core/session/onnxruntime_c_api.h` | Added `KernelContext_GetSyncStream` so plugin kernels can obtain the framework `OrtSyncStream*` for stream-aware allocation bookkeeping while still using `KernelContext_GetGPUComputeStream` for raw CUDA work | | `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Added `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` callbacks and `OrtGraphCaptureNodeAssignmentPolicy` enum to `OrtEp` | | `include/onnxruntime/core/framework/execution_provider.h` | Added `GetGraphCaptureNodeAssignmentPolicy()` virtual to `IExecutionProvider` | | `onnxruntime/core/session/inference_session.cc` | Replaced hard-coded EP name list with policy-driven graph capture validation loop; added bounded recursion via `RunImpl()` with `kMaxGraphCaptureWarmupRuns`; graph-enabled runs now reacquire stream collections through ORT core's thread-affine pool across internal warm-up/capture recursion | @@ -128,7 +129,7 @@ CUDA graph capture requires that all memory allocations happen during warmup, no **Arena integration details (now implemented):** - Default CUDA device allocations come from the plugin-hosted arena (`CudaArenaAllocator`). During warmup runs, the arena grows to accommodate all needed chunks; during capture and replay, the same chunks are reused without `cudaMalloc` calls. -- Kernel scratch/workspace allocations (`CudaKernel::GetScratchBuffer`) also flow through the EP arena via `Info().GetAllocator()`, rather than issuing a fresh `cudaMallocAsync`/`cudaMalloc` per call. After warmup the arena has reached its steady-state working set, so the capture run serves every scratch request from an already-reserved chunk and the device free-memory footprint stays stable across the capture window. This is what makes the `cudaMemGetInfo` allocation-during-capture detector pass for graphs that use scratch buffers, and it matches the bundled CUDA EP (which also obtains scratch from `Info().GetAllocator()`). `GetScratchBuffer` passes a **null stream** to the arena. This is *not* a synchronization bug: the `stream` argument is only bookkeeping metadata the stream-aware arena uses to decide when a freed chunk may be reused on a *different* stream without a sync - it does not change where the kernel runs (the buffer is still consumed on the real compute stream). In a serialized run (and within one graph-capture run), alloc/free/reuse are implicitly ordered on that stream, so a null-tagged ("freely reusable") chunk is correct and safe. It is also currently the only safe option, because a plugin kernel only has the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` the stream-aware arena persists per chunk and later dereferences through the EP stream API; note that the ORT-core `OrtSyncStream` (`struct OrtSyncStream : public onnxruntime::Stream`) is a different object from the plugin's `CudaSyncStream` (an `OrtSyncStreamImpl`). Synthesizing a temporary `Stream*` over the raw handle would dangle after `GetScratchBuffer` returns and be type-confused, so scratch chunks are tracked with a null stream (freely reusable, like a plain BFC arena). Capture stability comes from chunk reuse, not stream tagging. Properly stream-tagging scratch chunks (required before this path can support concurrent multi-stream runs) is **future work** that requires new C-API surface to expose the framework `OrtSyncStream*` to plugin kernels — see [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). +- Kernel scratch/workspace allocations (`CudaKernel::GetScratchBuffer`) also flow through the EP arena via `Info().GetAllocator()`, rather than issuing a fresh `cudaMallocAsync`/`cudaMalloc` per call. After warmup the arena has reached its steady-state working set, so the capture run serves every scratch request from an already-reserved chunk and the device free-memory footprint stays stable across the capture window. This is what makes the `cudaMemGetInfo` allocation-during-capture detector pass for graphs that use scratch buffers, and it matches the bundled CUDA EP (which also obtains scratch from `Info().GetAllocator()`). `GetScratchBuffer` stream-tags scratch chunks with the framework `OrtSyncStream*` exposed by `KernelContext_GetSyncStream`. The raw `cudaStream_t` from `KernelContext_GetGPUComputeStream` is still used for CUDA launches and library calls; the framework stream is used only for the arena's cross-stream reuse bookkeeping. - When `arena.use_cuda_mempool=1` is configured, CUDA device allocations come from `CudaMempoolOrtAllocator`, which wraps `cudaMallocFromPoolAsync`/`cudaFreeAsync`. These async allocation/free operations are CUDA-graph-safe since CUDA 11.4+ and become part of the captured graph topology. - Pinned allocations are also arena-backed, but remain non-stream-aware. - The graph stream created by `CudaEp::PerThreadContext` flows through `CudaSyncStream::InitHandlesWithExternalStream()` so stream-aware arena allocation uses the same `cudaStream_t` during warm-up, capture, and replay. @@ -137,12 +138,12 @@ CUDA graph capture requires that all memory allocations happen during warmup, no ### Concurrent Run Support -Concurrent `Session::Run()` is intentionally **not** advertised by the CUDA plugin EP while migrated kernels route scratch/workspace allocations through the EP arena with a null stream tag. +Concurrent `Session::Run()` is advertised by the CUDA plugin EP when the host ORT runtime exposes `KernelContext_GetSyncStream` and the session is not forced into EP-level unified-stream mode. - `CudaEp::PerThreadContext` still owns graph stream, graph manager, warm-up run counts, and memory watermark state per thread. This keeps graph bookkeeping thread-local and avoids sharing captured graph executables across threads. -- However, plugin kernels currently receive only the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena stores in each chunk and later uses for safe cross-stream reuse checks. -- Because `CudaKernel::GetScratchBuffer` cannot safely provide that framework stream, it passes a null stream tag. Null-tagged scratch chunks are freely reusable, which is safe for serialized runs and single-unified-stream graph capture but unsafe for overlapping runs on different CUDA streams. -- Therefore `CudaEp::IsConcurrentRunSupportedImpl()` returns false. Re-enabling concurrent multi-stream runs is future work and requires new C-API surface to expose a stable framework stream (or equivalent sync id) to plugin kernels so scratch chunks can be properly stream-tagged. +- Plugin kernels now obtain the framework `OrtSyncStream*` through `KernelContext_GetSyncStream` and use it only for scratch/workspace allocation bookkeeping. CUDA work still launches on the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream`. +- Stream-tagged scratch chunks let the shared arena apply its normal cross-stream reuse rules for overlapping runs on different CUDA streams. +- When the negotiated ORT API version does not include `KernelContext_GetSyncStream`, `CudaKernel::GetScratchBuffer` falls back to a null stream tag and `CudaEp::IsConcurrentRunSupportedImpl()` returns false. ## Verification @@ -159,4 +160,3 @@ Concurrent `Session::Run()` is intentionally **not** advertised by the CUDA plug ## Future Work 1. **Profiling integration**: CUDA graph replay currently bypasses the CUDA plugin EP profiler path because the CUDA plugin EP does not yet implement `OrtEp::CreateProfiler`. Wiring graph replay into that path is future work. -2. **Stream-tagged scratch allocations**: `CudaKernel::GetScratchBuffer` passes a null stream to the EP arena because plugin kernels cannot currently obtain the framework `OrtSyncStream*` the stream-aware arena needs (they only have the raw `cudaStream_t`). This is correct and safe for serialized runs and within one graph-capture run, but it is why the EP does not advertise concurrent `Session::Run()` support. Supporting concurrent multi-stream runs that share one arena would require new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels so scratch chunks can be properly stream-tagged. See [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 438fb8606fc09..8f9a0d388d1af 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -97,10 +97,11 @@ Because the plugin binary may load into an older runtime, every `OrtApi`/`OrtEpA | API surface | Newest `\since` used | Representative functions | | --- | --- | --- | | `OrtApi` — direct calls (`ort_api_.*`, `Ort::GetApi().*`) | **1.23** | `SyncStream_GetHandle`, `GetTensorSizeInBytes`, `GetRunConfigEntry`, `CreateMemoryInfo_V2`, `Graph_GetNumNodes`/`Graph_GetNodes` (older: `CreateStatus`, `Logger_LogMessage`, `*KeyValuePairs`, `HardwareDevice_*`, `MemoryInfoGet*`, `GetSessionConfigEntry`) | +| `OrtApi` — optional gated kernel-context capability | **1.28** | `KernelContext_GetSyncStream` (called from the adapter only when `CurrentOrtApiVersion() >= 28`; otherwise scratch allocation uses a null stream tag and concurrent run support is not advertised) | | `OrtEpApi` — direct calls (`ep_api_.*`, `Ort::GetEpApi().*`) | **1.24** | `CreateKernelRegistry`, `KernelRegistry_AddKernel`, `ReleaseKernelRegistry`, `CreateIfKernel`/`CreateLoopKernel`/`CreateScanKernel`, `EpGraphSupportInfo_LookUpKernel` (older: `MemoryDevice_*`, `MemoryInfo_GetMemoryDevice`, `SyncStream_*`, `EpDevice_AddAllocatorInfo`, `EpGraphSupportInfo_AddSingleNode`, `CreateEpDevice`/`ReleaseEpDevice`) | | EP profiler API (only when built with `ENABLE_CUDA_PROFILING`) | **1.25** | `CreateProfilingEvent`, `ProfilingEventsContainer_AddEvents`, `ReleaseProfilingEvent` (called from `cuda_profiler_plugin.cc` via the `Ort::ProfilingEvent` / `Ort::UnownedProfilingEventsContainer` wrappers) | -`provider_api_shims.cc` uses only internal helpers (`GetEnvironmentVar`, `MLFloat16` conversions), and the plugin uses no Model Editor, Model Package, or Compile API. **Apart from the optional EP profiler, every API the plugin calls is `\since 1.24` or older**, so the true compatibility floor is `1.24.4`. +`provider_api_shims.cc` uses only internal helpers (`GetEnvironmentVar`, `MLFloat16` conversions), and the plugin uses no Model Editor, Model Package, or Compile API. **Apart from optional gated capabilities such as EP profiling and stream-tagged scratch allocation, every API the plugin calls is `\since 1.24` or older**, so the true compatibility floor is `1.24.4`. **Defensive capability gating.** Reading a struct field is safe because the field is append-only and ORT only reads fields it knows about. The real hazard is *calling* an `OrtApi`/`OrtEpApi` function that the (possibly older) runtime does not provide. The correct guard for that is the runtime API version, `onnxruntime::ep::CurrentOrtApiVersion()`, not `ort_version_supported`. The `CudaEp` constructor (`cuda_ep.cc`) therefore reads `const uint32_t ort_version = onnxruntime::ep::CurrentOrtApiVersion();` and only installs an `OrtEp` callback when that runtime version is new enough to provide both the callback field and every API its implementation calls: @@ -113,7 +114,9 @@ Because the plugin binary may load into an older runtime, every `OrtApi`/`OrtEpA All other `OrtEp` and `OrtEpFactory` callbacks are `\since 1.24` or older and are installed unconditionally. Gating `CreateProfiler` is what makes the three `\since 1.25` profiler functions unreachable on an older runtime: when the profiler is never created, ORT never drives the `OrtEpProfilerImpl` callbacks that call them. -The gates use **graceful degradation rather than throwing**: the gated callbacks are all optional capabilities (per-run sync, EP-level GPU profiling, CUDA-graph capture/replay, device-memory budgeting), so disabling them on an older runtime still yields a fully functional EP — inference runs, just without that specific feature. This was validated by loading the plugin (built against the latest headers) into both the latest runtime (full test suite passes) and an `onnxruntime==1.24.4` runtime (the EP registers, enumerates devices, and runs inference correctly with the newer callbacks left null). +`KernelContext_GetSyncStream` is guarded at the adapter call site rather than through an `OrtEp` callback field: `OpKernelContext::GetSyncStream()` returns null when `CurrentOrtApiVersion() < 28`, and `CudaEp::IsConcurrentRunSupportedImpl()` only advertises concurrent runs when that API is available. Older runtimes therefore keep the previous serialized-run behavior while still using the same plugin binary. + +The gates use **graceful degradation rather than throwing**: the gated callbacks and adapter capabilities are optional features (per-run sync, EP-level GPU profiling, CUDA-graph capture/replay, device-memory budgeting, stream-tagged scratch for concurrent runs), so disabling them on an older runtime still yields a fully functional EP — inference runs, just without that specific feature. This was validated by loading the plugin (built against the latest headers) into both the latest runtime (full test suite passes) and an `onnxruntime==1.24.4` runtime (the EP registers, enumerates devices, and runs inference correctly with the newer callbacks left null). --- @@ -463,14 +466,14 @@ The NHWC rollout is effectively in a "runtime enabled, cleanup remaining" state: Migrated kernels need a valid device allocator in two places: scratch/workspace buffers during `Compute()`, and one-time weight conversion or packing during `PrePack()`. Both now resolve the allocator the same way the bundled CUDA EP does, through the kernel's own `OpKernelInfo`. -- **Scratch buffers.** `CudaKernel::GetScratchBuffer` allocates through `Info().GetAllocator(OrtMemTypeDefault)` (the EP arena) with a null stream tag, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call. The adapter `OpKernelInfo::GetAllocator` resolves the EP's default-memory (device) allocator and is always valid for a migrated kernel, so no plugin-only scratch path is needed. Routing through the arena is also what keeps the device free-memory footprint stable during CUDA graph capture (see [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#arena-allocator-integration)). The null stream tag is intentional: plugin kernels only have the raw `cudaStream_t`, not the framework `OrtSyncStream*` that the stream-aware arena persists in chunks for safe cross-stream reuse. +- **Scratch buffers.** `CudaKernel::GetScratchBuffer` allocates through `Info().GetAllocator(OrtMemTypeDefault)` (the EP arena) and stream-tags scratch chunks with the framework `OrtSyncStream*` from `KernelContext_GetSyncStream`, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call. The adapter `OpKernelInfo::GetAllocator` resolves the EP's default-memory (device) allocator and is always valid for a migrated kernel, so no plugin-only scratch path is needed. Routing through the arena is also what keeps the device free-memory footprint stable during CUDA graph capture (see [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#arena-allocator-integration)). CUDA launches still use the raw `cudaStream_t` from `KernelContext_GetGPUComputeStream`; the framework stream is used only for stream-aware arena bookkeeping. - **PrePack.** The framework prepack loop (`SessionState::PrepackConstantInitializedTensors`) resolves the allocator with `GetInitializerAllocator(kernel->Info().GetDevice(OrtMemTypeDefault))`, a session map keyed by device. For a plugin EP registered as a separate library, that device-keyed lookup can miss and return null. The loop now falls back to `kernel->Info().GetAllocator(OrtMemTypeDefault)` when the lookup is null, so every `PrePack` implementation receives a valid allocator at the single framework call site. This replaces the earlier approach of adding a per-kernel `if (!alloc) alloc = Info().GetAllocator(...)` guard to each prepacking op (which only covered the few ops that were touched and risked missing future ones). The fallback is behavior-neutral for in-tree EPs, whose device-keyed lookup already succeeds, and it does **not** force `is_packed`/`prepacked_weights` handling \u2014 ops such as `QMoE` and `MatMulNBits` still set `is_packed = true` and populate prepacked weights normally. -The enabling adapter change is in [`include/onnxruntime/ep/adapter/allocator.h`](../../include/onnxruntime/ep/adapter/allocator.h): `IAllocatorWrappingOrtAllocator` now implements `IsStreamAware()`/`AllocOnStream()` (previously `ORT_NOT_IMPLEMENTED`) by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` when it is available (ORT >= 1.23), falling back to plain `Alloc` otherwise. `GetScratchBuffer` does not use that stream-aware path yet because the plugin kernel layer cannot safely provide the framework `OrtSyncStream*`; stream-tagged scratch allocation is future work and is documented in [arena_allocator_migration_design.md](arena_allocator_migration_design.md#scratch-buffer-stream-tagging--limitation-and-future-work). +The enabling adapter changes are in [`include/onnxruntime/ep/adapter/allocator.h`](../../include/onnxruntime/ep/adapter/allocator.h) and [`include/onnxruntime/ep/adapter/op_kernel.h`](../../include/onnxruntime/ep/adapter/op_kernel.h): `IAllocatorWrappingOrtAllocator` implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` when it is available (ORT >= 1.23), and `OpKernelContext::GetSyncStream()` exposes the framework stream when the negotiated ORT API version includes `KernelContext_GetSyncStream`. The CUDA plugin uses that framework stream for `GetScratchBuffer`; if it is unavailable, allocation falls back to a null stream tag and concurrent `Session::Run()` is not advertised. ### 5.4 CUDA Graph Support -CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and combining a caller-supplied `user_compute_stream` with capture/replay. Concurrent `Session::Run()` is intentionally not advertised while scratch allocations are null-stream-tagged; supporting concurrent multi-stream runs requires future C-API work to expose a stable framework stream or sync id to plugin kernels. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and user-stream mode — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. +CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and combining a caller-supplied `user_compute_stream` with capture/replay. Concurrent `Session::Run()` is supported when the host runtime exposes `KernelContext_GetSyncStream` and the session is not forced into EP-level unified-stream mode. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and user-stream mode — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. #### 5.4.1 OrtEp C API Extensions (v1.26) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5dd53a8cf45c0..a73868527771a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7496,6 +7496,27 @@ struct OrtApi { * \since Version 1.28. */ ORT_API_T(OrtExperimentalFnPtr, GetExperimentalFunction, _In_ const char* name); + + /** \brief Get the framework synchronization stream associated with a kernel context. + * + * This returns the framework stream wrapper for the execution provider stream used by this kernel invocation. + * It is intended for APIs that need a stable framework stream object for stream-aware allocation and + * synchronization bookkeeping. Use KernelContext_GetGPUComputeStream when launching native GPU work. + * + * \param[in] context OrtKernelContext instance. + * \param[out] out Returns the framework synchronization stream, or nullptr if the kernel has no stream. + * Do not free or mutate the returned pointer. It is owned by the underlying session. + * The pointer may be stored and used for stream-aware allocation and synchronization + * bookkeeping beyond the Compute call (e.g. an allocator may persist it in arena + * chunks); it remains valid until the owning Session::Run() completes its teardown. + * Do not retain or dereference it after the run that produced this kernel context ends. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.28. + */ + ORT_API2_STATUS(KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, + _Outptr_result_maybenull_ OrtSyncStream** out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 4798d3d4ad1b8..55a4e36167e86 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3020,6 +3020,7 @@ struct KernelContext { UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; void* GetGPUComputeStream() const; + OrtSyncStream* GetSyncStream() const; Logger GetLogger() const; Ort::Allocator GetAllocator(const OrtMemoryInfo& memory_info) const; OrtKernelContext* GetOrtKernelContext() const { return ctx_; } diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d7439e7b356c6..ed3abc0961be6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2876,6 +2876,12 @@ inline void* KernelContext::GetGPUComputeStream() const { return out; } +inline OrtSyncStream* KernelContext::GetSyncStream() const { + OrtSyncStream* out = nullptr; + Ort::ThrowOnError(GetApi().KernelContext_GetSyncStream(ctx_, &out)); + return out; +} + inline Ort::Allocator KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const { OrtAllocator* out = nullptr; Ort::ThrowOnError(GetApi().KernelContext_GetAllocator(ctx_, &memory_info, &out)); diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h index 27a46cc10e306..1f103b64a443e 100644 --- a/include/onnxruntime/ep/adapter/op_kernel.h +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -164,6 +164,14 @@ struct OpKernelContext { void* GetGPUComputeStream() const { return context_.GetGPUComputeStream(); } + OrtSyncStream* GetSyncStream() const { + static constexpr uint32_t kOrtKernelContextGetSyncStreamMinVersion = 28; + if (CurrentOrtApiVersion() < kOrtKernelContextGetSyncStreamMinVersion) { + return nullptr; + } + + return context_.GetSyncStream(); + } private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 13f7bfd7a40cf..73fa92d19cd1f 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -464,14 +464,12 @@ OrtStatus* ORT_API_CALL CudaEp::IsConcurrentRunSupportedImpl( return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "is_supported must not be null."); } - // Plugin kernels currently expose only the raw cudaStream_t to GetScratchBuffer(), not the - // framework OrtSyncStream* that the stream-aware arena needs to tag scratch chunks by stream. - // Scratch chunks are therefore allocated with a null stream tag and can be reused freely. That is - // safe when runs are serialized, but it is not safe to advertise concurrent Session::Run(): two - // runs on different CUDA streams could reuse the same scratch chunk while earlier work is still - // in flight. Re-enable concurrent runs only after the plugin kernel layer can pass a stable - // framework stream (or equivalent sync id) to the arena. - *is_supported = false; + auto* ep = static_cast(this_ptr); + // Concurrent runs require stream-tagged scratch allocations. The plugin kernel adapter can tag + // scratch chunks only when the hosting ORT runtime exposes KernelContext_GetSyncStream. + static constexpr uint32_t kOrtKernelContextGetSyncStreamMinVersion = 28; + *is_supported = !ep->config_.use_ep_level_unified_stream && + ::onnxruntime::ep::CurrentOrtApiVersion() >= kOrtKernelContextGetSyncStreamMinVersion; return nullptr; } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index f134c599d5b46..06fe635e35716 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -16,6 +16,7 @@ #pragma once #include +#include #include "core/common/status.h" #include "core/common/narrow.h" @@ -53,6 +54,81 @@ namespace onnxruntime { struct CudaStream; +namespace cuda_plugin { +namespace detail { +inline thread_local std::unordered_map stream_to_framework_stream; +inline thread_local void* current_cuda_stream = nullptr; +inline thread_local onnxruntime::Stream* current_framework_stream = nullptr; + +inline void RegisterFrameworkStreamForCudaStream(void* cuda_stream, OrtSyncStream* framework_stream) { + current_cuda_stream = cuda_stream; + current_framework_stream = reinterpret_cast(framework_stream); + + if (current_framework_stream == nullptr) { + return; + } + + // Map only from the raw cudaStream_t handle to the current framework stream. The framework + // stream is already handled directly by GetFrameworkStreamForStreamArg, so we deliberately do + // not insert a framework_stream -> framework_stream entry: it would be unused and would grow the + // thread-local map without bound while retaining framework stream pointers past the + // Session::Run() teardown lifetime documented for KernelContext_GetSyncStream. + if (cuda_stream != nullptr) { + stream_to_framework_stream[cuda_stream] = current_framework_stream; + } +} + +inline onnxruntime::Stream* GetFrameworkStreamForStreamArg(void* stream) { + // A null stream argument means "the compute stream of the current Compute call". This is the + // form used by GetTransientScratchBuffer and legacy GetScratchBuffer(..., nullptr). Map it to + // the framework stream registered for this call so scratch chunks are still stream-tagged even + // when the kernel runs on a non-default CUDA stream (where current_cuda_stream is non-null and a + // nullptr arg would otherwise miss the map lookup and fall back to a null stream tag). + // + // current_framework_stream is scoped to a single CudaKernel::Compute invocation by + // ComputeStreamScope (see below). Outside any Compute call it is nullptr, so allocations made + // from kernel constructors (which also call GetScratchBuffer(..., nullptr)) fall back to the + // non-stream-tagged path instead of inheriting a stale framework stream pointer whose lifetime + // ended with a previous Session::Run(). + if (stream == nullptr || stream == current_cuda_stream || stream == current_framework_stream) { + return current_framework_stream; + } + + auto it = stream_to_framework_stream.find(stream); + return it == stream_to_framework_stream.end() ? nullptr : it->second; +} + +// RAII guard that scopes the thread-local "current Compute call" framework stream to the lifetime +// of a single CudaKernel::Compute invocation on a worker thread. +// +// On entry it clears current_cuda_stream/current_framework_stream so that scratch allocated before +// the kernel registers its stream (via Stream(ctx)/GetComputeStream(ctx)/GetOrtStream(ctx)), or via +// a nullptr stream argument, does not inherit a stale framework stream left over from a previous +// Compute call on this worker thread. On exit it restores the previous values, which keeps nested +// Compute calls (a kernel that invokes another kernel's Compute) correct and leaves the per-thread +// "current" stream cleared once the outermost Compute returns. The borrowed framework stream is +// only valid until its owning Session::Run() completes teardown, so it must not outlive the call. +struct ComputeStreamScope { + ComputeStreamScope() + : saved_cuda_stream_(current_cuda_stream), + saved_framework_stream_(current_framework_stream) { + current_cuda_stream = nullptr; + current_framework_stream = nullptr; + } + ~ComputeStreamScope() { + current_cuda_stream = saved_cuda_stream_; + current_framework_stream = saved_framework_stream_; + } + + private: + void* saved_cuda_stream_; + onnxruntime::Stream* saved_framework_stream_; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ComputeStreamScope); +}; +} // namespace detail +} // namespace cuda_plugin + // Lightweight Stream shim for plugin build: wraps a raw cudaStream_t as a // framework-compatible Stream* that can be passed to _impl.cu functions which // call stream->GetHandle(). Stack-allocated; does NOT own the stream. @@ -70,6 +146,11 @@ class OrtStreamAdapter { explicit OrtStreamAdapter(void* cuda_stream_handle) : plugin_stream_shim_(cuda_stream_handle), stream_(&plugin_stream_shim_) {} + OrtStreamAdapter(void* cuda_stream_handle, OrtSyncStream* framework_stream) + : plugin_stream_shim_(cuda_stream_handle), + stream_(framework_stream == nullptr ? static_cast(&plugin_stream_shim_) + : reinterpret_cast(framework_stream)) {} + onnxruntime::Stream* get() const { return stream_; } operator onnxruntime::Stream*() const { return stream_; } @@ -83,6 +164,10 @@ class OrtStreamAdapter { explicit OrtStreamAdapter(void* cuda_stream_handle) : stream_(static_cast(cuda_stream_handle)) {} + OrtStreamAdapter(void* cuda_stream_handle, OrtSyncStream* framework_stream) + : stream_(framework_stream == nullptr ? static_cast(cuda_stream_handle) + : reinterpret_cast(framework_stream)) {} + onnxruntime::Stream* get() const { return stream_; } operator onnxruntime::Stream*() const { return stream_; } @@ -868,6 +953,11 @@ class CudaKernel : public OpKernel { } virtual ~CudaKernel() = default; Status Compute(OpKernelContext* ctx) const { + // Scope the thread-local "current Compute call" framework stream to this invocation so that + // scratch tagged via a nullptr stream argument never inherits a stale framework stream from a + // previous Compute call (or leaks one to a later kernel constructor) on this worker thread. + cuda_plugin::detail::ComputeStreamScope compute_stream_scope; + // Ensure the correct CUDA device is active for this kernel. // Worker threads default to device 0; sessions on device > 0 need an // explicit cudaSetDevice. Skip during CUDA graph capture because @@ -903,17 +993,27 @@ class CudaKernel : public OpKernel { cudaStream_t Stream(OpKernelContext* ctx) const { if (!ctx) return nullptr; - return static_cast(ctx->GetGPUComputeStream()); + // Register the framework sync stream for this Compute call so that scratch allocated via + // GetTransientScratchBuffer()/GetScratchBuffer(..., nullptr) is still stream-tagged for kernels + // that call Stream(ctx) before GetComputeStream()/GetOrtStream() (e.g. conv algo search). + void* cuda_stream = ctx->GetGPUComputeStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()); + return static_cast(cuda_stream); } // Returns an opaque stream pointer for passing to GetScratchBuffer/AddDeferredReleaseCPUPtr/CopyToGpu. // Returns void* for dual-build compatibility: framework wraps Stream*, plugin wraps cudaStream_t. inline void* GetComputeStream(OpKernelContext* ctx) const { - return ctx->GetGPUComputeStream(); + void* cuda_stream = ctx->GetGPUComputeStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, ctx->GetSyncStream()); + return cuda_stream; } inline onnxruntime::OrtStreamAdapter GetOrtStream(OpKernelContext* ctx) const { - return onnxruntime::OrtStreamAdapter(GetComputeStream(ctx)); + void* cuda_stream = ctx->GetGPUComputeStream(); + OrtSyncStream* framework_stream = ctx->GetSyncStream(); + cuda_plugin::detail::RegisterFrameworkStreamForCudaStream(cuda_stream, framework_stream); + return onnxruntime::OrtStreamAdapter(cuda_stream, framework_stream); } static cudnnHandle_t GetCudnnHandle(cudaStream_t s) { @@ -1023,7 +1123,7 @@ class CudaKernel : public OpKernel { template using IAllocatorUniquePtr = std::unique_ptr>; template - inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* /*stream*/) const { + inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* stream) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); // Route kernel scratch/workspace allocations through the EP allocator @@ -1047,36 +1147,18 @@ class CudaKernel : public OpKernel { // enforced inside MakeUniquePtr via ValidatedCalcMemSizeForArray (it throws // on cnt * sizeof(T) overflow). // - // The compute stream is intentionally NOT forwarded to the allocator here. This is a - // bookkeeping decision, NOT a synchronization bug: the `stream` argument to a stream-aware - // arena is only metadata used to decide when a freed chunk may be reused on a *different* - // stream without an intervening sync. It does not change where the kernel runs - the returned - // buffer is still consumed by the kernel on the real compute stream. In a serialized run (and - // within one graph-capture run), alloc/free/reuse ordering is implicit on that stream, so there - // is no cross-stream chunk to race on. Tagging chunks with a null stream (freely reusable, the - // same semantics as a plain non-stream-aware BFC arena) is therefore correct and safe as long - // as the EP does not advertise concurrent Session::Run() support. - // - // It is also currently the only safe option, because of a C-API type constraint: a plugin - // kernel only has the raw cudaStream_t (KernelContext::GetGPUComputeStream), not the framework - // OrtSyncStream* that the stream-aware arena persists in each chunk (CudaArena stores - // `chunk->stream` and later dereferences it through the EP stream API, e.g. - // SyncStream_GetImpl/SyncStream_GetSyncId). Note that OrtSyncStream (the ORT-core wrapper, - // `struct OrtSyncStream : public onnxruntime::Stream`) is a DIFFERENT object from the plugin's - // CudaSyncStream (an OrtSyncStreamImpl); CudaSyncStream::FromCudaStream() recovers the latter, - // not the former. Wrapping the raw handle in a temporary framework Stream shim and passing it - // down would be unsafe on two counts: (1) the shim is stack-allocated and would dangle after - // this function returns while the arena still holds the pointer, and (2) it is type-confused — - // the arena would reinterpret a framework Stream* as an OrtSyncStream* that was never created - // by ORT for this stream. - // - // Properly stream-tagging scratch chunks (needed before this path can support concurrent - // multi-stream runs) requires new C-API surface to expose the framework OrtSyncStream* to - // plugin kernels. See docs/cuda_plugin_ep/arena_allocator_migration_design.md ("Scratch buffer - // stream tagging") for the limitation and future work. + // The `stream` argument is the raw cudaStream_t used by migrated CUDA kernels, or a Stream* + // from OrtStreamAdapter in code paths that need stream->GetHandle(). Stream-aware arena + // allocation needs the stable framework Stream* wrapper instead, because the arena stores it + // in each chunk and later queries sync ids through the EP stream API. Stream(ctx), + // GetComputeStream(ctx) and GetOrtStream(ctx) record the mapping from both argument forms to + // the framework stream for the current Compute call. + // If the negotiated ORT API version does not include KernelContext_GetSyncStream, the lookup + // returns null and allocation falls back to the non-stream-tagged path. + auto* framework_stream = cuda_plugin::detail::GetFrameworkStreamForStreamArg(stream); return ::onnxruntime::IAllocator::MakeUniquePtr( Info().GetAllocator(OrtMemType::OrtMemTypeDefault), cnt, /*use_reserve*/ false, - /*stream*/ nullptr); + framework_stream); } template inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t cnt) const { diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 2c8b81e4ffefe..89969172c1bdc 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -206,6 +206,15 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKe }); }; +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, + _Outptr_result_maybenull_ OrtSyncStream** out) { + return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { + auto* stream = reinterpret_cast(context)->GetComputeStream(); + *out = reinterpret_cast(stream); + return nullptr; + }); +}; + ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { return ExecuteIfKernelApiEnabled([&]() -> OrtStatusPtr { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a663d209cfa53..22df898ca3227 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4916,6 +4916,8 @@ static constexpr OrtApi ort_api_1_to_28 = { // End of Version 27 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::GetExperimentalFunction, + + &OrtApis::KernelContext_GetSyncStream, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 61ece2dd9a682..e747d0d0ab2d8 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -196,6 +196,7 @@ ORT_API_STATUS_IMPL(KernelContext_GetInputCount, _In_ const OrtKernelContext* co ORT_API_STATUS_IMPL(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); ORT_API_STATUS_IMPL(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out); +ORT_API_STATUS_IMPL(KernelContext_GetSyncStream, _In_ const OrtKernelContext* context, _Outptr_result_maybenull_ OrtSyncStream** out); // OrtTypeInfo methods ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len); diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc index d49faf3c90ea8..b75de767bb7f6 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc @@ -88,6 +88,12 @@ Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { return Ort::ConstEpDevice{nullptr}; } +// Dummy external allocator callbacks. They are only used to make the external-allocator +// configuration non-null; the plugin EP rejects the combination with user_compute_stream +// before either is ever invoked. +void* DummyExternalAlloc(size_t /*size*/) { return nullptr; } +void DummyExternalFree(void* /*ptr*/) {} + } // namespace class CudaPluginUserStreamGraphTest : public ::testing::Test { @@ -129,6 +135,70 @@ class CudaPluginUserStreamGraphTest : public ::testing::Test { return so; } + // Allocate device input/output, bind them, and run `iterations` times on `stream`, verifying + // Y = X * W each run. The input is uploaded once up front and then left constant: when CUDA graph + // capture is enabled, issuing host->device work on the stream immediately before the capture run + // would interfere with cudaStreamBeginCapture, so the buffers are populated and synchronized + // before any capture happens. When `graph_ids` is non-empty, run i sets gpu_graph_id to + // graph_ids[i % size] to exercise CUDA graph annotation-id switching. mul_1.onnx computes + // Y = X * W with W = [1..6] (shape 3x2). + void RunAndVerifyOnStream(Ort::Session& session, cudaStream_t stream, int iterations, + const std::vector& graph_ids = {}) { + auto device_memory_info = cuda_device_.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); + ASSERT_NE(allocator, nullptr); + + constexpr size_t kNumElements = 6; + constexpr size_t kBytes = kNumElements * sizeof(float); + const std::array shape = {3, 2}; + const std::array w_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + const std::array x_values = {2.0f, 3.0f, 5.0f, 7.0f, 11.0f, 13.0f}; + + // Fixed device buffers so captured CUDA graphs keep valid IO addresses across replays. + void* input_gpu = allocator.Alloc(kBytes); + void* output_gpu = allocator.Alloc(kBytes); + ASSERT_NE(input_gpu, nullptr); + ASSERT_NE(output_gpu, nullptr); + + // Populate the input once and synchronize, so no host-issued work is pending on `stream` + // when graph capture begins on a later run. + ASSERT_EQ(cudaSuccess, + cudaMemcpyAsync(input_gpu, x_values.data(), kBytes, cudaMemcpyHostToDevice, stream)); + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(stream)); + + Ort::Value input_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(input_gpu), kNumElements, + shape.data(), shape.size()); + Ort::Value output_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(output_gpu), kNumElements, + shape.data(), shape.size()); + + Ort::IoBinding binding(session); + binding.BindInput("X", input_tensor); + binding.BindOutput("Y", output_tensor); + + for (int i = 0; i < iterations; ++i) { + Ort::RunOptions run_options; + if (!graph_ids.empty()) { + run_options.AddConfigEntry("gpu_graph_id", graph_ids[i % graph_ids.size()].c_str()); + } + session.Run(run_options, binding); + + // Kernels run on `stream`; wait for them before copying the result back. + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(stream)); + std::array y{}; + ASSERT_EQ(cudaSuccess, cudaMemcpy(y.data(), output_gpu, kBytes, cudaMemcpyDeviceToHost)); + for (size_t j = 0; j < kNumElements; ++j) { + EXPECT_FLOAT_EQ(y[j], x_values[j] * w_values[j]) << "mismatch at iteration " << i << " index " << j; + } + } + + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); + allocator.Free(input_gpu); + allocator.Free(output_gpu); + } + std::unique_ptr registration_; Ort::ConstEpDevice cuda_device_{nullptr}; }; @@ -234,6 +304,69 @@ TEST_F(CudaPluginUserStreamGraphTest, CaptureAndReplayOnUserStream) { ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); } +// Negative: a user_compute_stream combined with an external GPU allocator +// (gpu_external_alloc/gpu_external_free) is not supported and must be rejected at session +// creation with an error rather than silently ignored. +TEST_F(CudaPluginUserStreamGraphTest, RejectsUserStreamWithExternalAllocator) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"user_compute_stream", std::to_string(reinterpret_cast(user_stream))}, + {"gpu_external_alloc", std::to_string(reinterpret_cast(&DummyExternalAlloc))}, + {"gpu_external_free", std::to_string(reinterpret_cast(&DummyExternalFree))}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + + EXPECT_THROW( + { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + (void)session; + }, + Ort::Exception); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +// Edge case: cudaStream_t(0) (the CUDA default stream) is a valid user-provided stream. Because +// user_compute_stream parses to nullptr, the caller must set has_user_compute_stream explicitly, +// otherwise the stream would be treated as "not provided". Session creation must succeed and +// inference must run correctly on the default stream. +// +// Note: CUDA graph capture is intentionally NOT enabled here. The legacy default stream (stream 0) +// cannot be captured (cudaStreamBeginCapture returns cudaErrorStreamCaptureUnsupported), so this +// test exercises only that stream 0 is honored as the compute stream for non-graph execution. +TEST_F(CudaPluginUserStreamGraphTest, DefaultStreamAsUserStream) { + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"has_user_compute_stream", "1"}, + {"user_compute_stream", "0"}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Run several iterations on the default stream (stream 0) and verify correctness. + RunAndVerifyOnStream(session, /*stream=*/nullptr, /*iterations=*/4); +} + +// Switching the CUDA graph annotation id (gpu_graph_id) between runs while using a user stream +// must capture/replay a distinct graph per id without crashing and keep producing correct results. +TEST_F(CudaPluginUserStreamGraphTest, GraphAnnotationIdSwitchingWithUserStream) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Alternate between annotation ids "1" and "2". With min_num_runs_before_cuda_graph_capture == 2, + // 8 iterations let each id accumulate warmup runs, capture, and then replay on the user stream. + RunAndVerifyOnStream(session, user_stream, /*iterations=*/8, /*graph_ids=*/{"1", "2"}); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index 53745cae9d803..a58f82deab4ff 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -54,8 +54,11 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { EXPECT_NE(allocated, nullptr) << "KernelContext_GetAllocator() can successfully allocate some memory"; allocator.Free(allocated); + OrtSyncStream* sync_stream = ctx.GetSyncStream(); + // Do computation #ifdef USE_CUDA + EXPECT_NE(sync_stream, nullptr) << "KernelContext_GetSyncStream() returns the kernel compute stream"; // Launch on stream 0 or user provided stream void* stream; Ort::ThrowOnError(ort_.KernelContext_GetGPUComputeStream(context, &stream)); @@ -70,6 +73,7 @@ void MyCustomKernel::Compute(OrtKernelContext* context) { // and use the same compute stream to launch the custom op. // Here, an example for (1) is shown (See test_inference.cc to see how this custom op is used.) #else + EXPECT_EQ(sync_stream, nullptr) << "CPU custom ops do not have a compute stream"; ORT_UNUSED_PARAMETER(ort_); for (int64_t i = 0; i < size; i++) { out[i] = X[i] + Y[i]; From 578658826319426ff18e90ee7fa5f2f30ed0f903 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 26 Jun 2026 14:56:31 -0700 Subject: [PATCH 12/19] Validate SparseAttention CSR indices and key lengths element values (#29242) This pull request introduces device-side and host-side validation for sparse attention input indices and key sequence lengths, improving error detection and robustness for both CPU and CUDA implementations. It also adds an environment variable to optionally disable the new device-side validation for performance reasons. Some related test code has been updated to match these changes. **Validation improvements:** * Added host-side validation functions (`ValidateCSRIndices`, `ValidateKeyLengths`) for CSR indices and key sequence lengths in the CPU kernel, and integrated them into the compute path. This ensures invalid inputs are caught early on the CPU. [[1]](diffhunk://#diff-44554dbe530c593f0f0b85a591859bb6b0a21e62992c61f9622e5456a144cb45R39-R106) [[2]](diffhunk://#diff-44554dbe530c593f0f0b85a591859bb6b0a21e62992c61f9622e5456a144cb45R146-R150) * Implemented a CUDA kernel (`ValidateCSRIndicesKernel`) and supporting function (`ValidateCSRIndicesOnDevice`) to check CSR row-pointer monotonicity, column-index range, and key lengths on device, with detailed error codes and messages. The CUDA kernel is invoked in the compute path unless disabled by environment variable. [[1]](diffhunk://#diff-0c8a322cd4611e589f38a67876f309f9d83869a6d2239cadf86970ed2005ebd0R336-R449) [[2]](diffhunk://#diff-fff06841efe15d5f95c02bb38daa1d5aa0775de0e1777d9d418222e44ebc88feR71-R95) [[3]](diffhunk://#diff-08ea97fecd6c161add2607d5b7d406c0f3d2b0f1280ebb49d801f085d941770aR220-R236) **Configurability:** * Introduced the environment variable `ORT_DISABLE_SPARSE_ATTENTION_INPUT_VALIDATION` to allow users to skip device-side validation for performance when inputs are known to be valid. This is parsed and used in the CUDA kernel. [[1]](diffhunk://#diff-56cfb57cdd5f9134a8fea24bd006c691860c64a1f78f4b2c69a861d847dee9ddR94-R99) [[2]](diffhunk://#diff-08ea97fecd6c161add2607d5b7d406c0f3d2b0f1280ebb49d801f085d941770aR59-R60) [[3]](diffhunk://#diff-8e47232d826ceae90a93aa6cde8f2869dcb1998f37804055fcfb584d95e21a96R29) **Test updates:** * Refactored test helper function names and test input shapes to match the new validation logic and error messages. [[1]](diffhunk://#diff-21bdb8e3ad8b50a72c5da77349f280c15a6938ece2bea4e987bd12ff8bcb2a0eL27-R27) [[2]](diffhunk://#diff-21bdb8e3ad8b50a72c5da77349f280c15a6938ece2bea4e987bd12ff8bcb2a0eL45-R45) [[3]](diffhunk://#diff-21bdb8e3ad8b50a72c5da77349f280c15a6938ece2bea4e987bd12ff8bcb2a0eL212-R220) **Other:** * Removed redundant key length value checks from `CheckInputs` in favor of the new validation routines. --- .../contrib_ops/cpu/bert/attention_common.h | 6 + .../cpu/sparse/sparse_attention.cc | 73 +++++ .../cpu/sparse/sparse_attention_helper.h | 15 +- .../cuda/sparse/sparse_attention.cc | 23 +- .../cuda/sparse/sparse_attention.h | 17 +- .../cuda/sparse/sparse_attention_impl.cu | 114 +++++++ .../cuda/sparse/sparse_attention_impl.h | 25 ++ .../contrib_ops/sparse_attention_op_test.cc | 293 +++++++++++++++++- 8 files changed, 533 insertions(+), 33 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index ac3531e39eb53..e46075a86f811 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -91,6 +91,12 @@ constexpr bool LAYOUT_BNSH = true; namespace sparse_attention { // Environment variable to enable or disable sparse attention v1 kernel. Default is 0 (enabled). constexpr const char* kDisableSparseAttentionV1 = "ORT_DISABLE_SPARSE_ATTENTION_V1"; + +// Environment variable to disable device-side validation of CSR indices and key sequence lengths. +// Default is 0 (validation enabled). Set to 1 to skip the validation kernel launch and stream +// synchronization, which may improve latency when inputs are known to be well-formed. +// Usage: export ORT_DISABLE_SPARSE_ATTENTION_INPUT_VALIDATION=1 +constexpr const char* kDisableInputValidation = "ORT_DISABLE_SPARSE_ATTENTION_INPUT_VALIDATION"; } // namespace sparse_attention namespace attention { diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc index bab2dfd13e046..b027772971540 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc @@ -36,6 +36,74 @@ namespace contrib { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +namespace { + +// Validate CSR row-pointer monotonicity and column-index range. +// Must be called after CheckInputs has populated the parameters struct. +Status ValidateCSRIndices(const SparseAttentionParameters& parameters, + const Tensor& block_row_indices, + const Tensor& block_col_indices) { + const int num_layout = parameters.num_sparse_layout; + const int max_blocks = parameters.stride_row_indices - 1; + const int col_count = parameters.stride_col_indices; + + const int32_t* row_data = block_row_indices.Data(); + const int32_t* col_data = block_col_indices.Data(); + for (int l = 0; l < num_layout; ++l) { + const int32_t* r = row_data + l * (max_blocks + 1); + if (r[0] != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_row_indices[", l, "][0] must be 0, got ", r[0]); + } + for (int i = 0; i < max_blocks; ++i) { + if (r[i] < 0 || r[i] > r[i + 1] || r[i + 1] > col_count) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_row_indices values are not monotonically non-decreasing or exceed " + "block_col_indices columns at layout ", + l, " row ", i, + ": r[", i, "]=", r[i], ", r[", i + 1, "]=", r[i + 1], + ", col_count=", col_count); + } + } + const int32_t* c = col_data + l * col_count; + const int nnz = r[max_blocks]; + for (int k = 0; k < nnz; ++k) { + if (c[k] < 0 || c[k] >= max_blocks) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "block_col_indices[", l, "][", k, "]=", c[k], + " is out of valid range [0, ", max_blocks, ")"); + } + } + } + + return Status::OK(); +} + +// Validate total_key_lengths element values. +Status ValidateKeyLengths(const SparseAttentionParameters& parameters, + const Tensor& total_key_lengths) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + + const auto* key_len_data = total_key_lengths.Data(); + const bool is_prompt = (sequence_length == total_sequence_length); + const int min_key_length = is_prompt ? 1 : sequence_length; + for (int i = 0; i < batch_size; ++i) { + const int key_length = key_len_data[i]; + if (key_length < min_key_length || key_length > total_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "key_total_sequence_lengths value ", key_length, + " at batch index ", i, + " is out of range [", min_key_length, ", ", total_sequence_length, "]."); + } + } + + return Status::OK(); +} + +} // namespace + template SparseAttention::SparseAttention(const OpKernelInfo& info) : OpKernel(info), SparseAttentionBase(info) { } @@ -75,6 +143,11 @@ Status SparseAttention::Compute(OpKernelContext* context) const { block_col_indices, total_key_lengths, total_seq_len)); + ORT_RETURN_IF_ERROR(ValidateCSRIndices(parameters, + *block_row_indices, + *block_col_indices)); + ORT_RETURN_IF_ERROR(ValidateKeyLengths(parameters, + *total_key_lengths)); const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h index 2804f30a9611d..af320b250abdb 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_helper.h @@ -198,26 +198,13 @@ Status CheckInputs(void* params, past_key_dims[3]); } - // Check the shape and values of total_key_sequence_lengths. + // Check the shape of total_key_sequence_lengths. const auto& k_len_dim = total_key_lengths->Shape().GetDims(); if (k_len_dim.size() != 1 || k_len_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "key_total_sequence_lengths must have shape (batch_size)."); } - const auto* key_len_data = total_key_lengths->Data(); - const bool is_prompt = (sequence_length == total_sequence_length); - const int min_key_length = is_prompt ? 1 : sequence_length; - for (int i = 0; i < batch_size; ++i) { - const int key_length = key_len_data[i]; - if (key_length < min_key_length || key_length > total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "key_total_sequence_lengths value ", key_length, - " at batch index ", i, - " is out of range [", min_key_length, ", ", total_sequence_length, "]."); - } - } - int rotary_dim = 0; int max_rotary_sequence_length = 0; if (do_rotary) { diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 656fde2f46ab8..355319e84e534 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -56,6 +56,8 @@ SparseAttention::SparseAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); disable_v1_kernel_ = ParseEnvironmentVariableWithDefault(sparse_attention::kDisableSparseAttentionV1, false); + disable_input_validation_ = ParseEnvironmentVariableWithDefault( + sparse_attention::kDisableInputValidation, false); } template @@ -105,6 +107,26 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { block_col_indices, seqlens_k_total, total_seq_len)); + + // Validate CSR indices and key lengths on device to prevent out-of-bounds access. + // This must run before the shared-buffer check so OpTester-based tests can exercise it. + cudaStream_t cuda_stream = Stream(context); + if (!disable_input_validation_) { + auto csr_error_buffer = GetScratchBuffer(1, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(ValidateCSRIndicesOnDevice( + cuda_stream, + block_row_indices->Data(), + block_col_indices->Data(), + seqlens_k_total->Data(), + parameters.num_sparse_layout, + parameters.stride_row_indices - 1, // max_blocks + parameters.stride_col_indices, // col_count + parameters.batch_size, + parameters.sequence_length, + parameters.total_sequence_length, + csr_error_buffer.get())); + } + // Some limitations of CUDA kernels // The v1 and v2 kernels have same coverage, so only check one of them to see whether it is supported. int sm = device_prop.major * 10 + device_prop.minor; @@ -137,7 +159,6 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { int32_t* total_k_seq_len_pinned = nullptr; AutoDestoryCudaEvent new_event; cudaEvent_t& isCopyDone = new_event.Get(); - cudaStream_t cuda_stream = Stream(context); if (use_v2_kernel) { pinned_buffer = AllocateBufferOnCPUPinned(parameters.batch_size); diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h index 1df3affe17ea3..06fc07f088c60 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h @@ -18,14 +18,15 @@ class SparseAttention final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; protected: - int num_heads_; // number of attention heads for q - int kv_num_heads_; // number of attention heads for k and v - float scale_; // Scaling factor applied prior to softmax. - bool is_causal_; // unidirectional attention or not - int sparse_block_size_; // block size for sparsity - bool do_rotary_; // Has rotary positional embedding - bool rotary_interleaved_; // Interleaved rotary positional embedding - bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. + int num_heads_; // number of attention heads for q + int kv_num_heads_; // number of attention heads for k and v + float scale_; // Scaling factor applied prior to softmax. + bool is_causal_; // unidirectional attention or not + int sparse_block_size_; // block size for sparsity + bool do_rotary_; // Has rotary positional embedding + bool rotary_interleaved_; // Interleaved rotary positional embedding + bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. + bool disable_input_validation_; // Whether to skip device-side CSR and key-length validation. }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu index b2a6eb89d4d23..1fecb91b4f578 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu @@ -333,6 +333,120 @@ template Status QkvToContext( contrib::SparseAttentionParameters& parameters, SparseAttentionData& data); +// Validation kernel for CSR sparse layout indices and key sequence lengths. +// Each block handles one layout (blocks [0, num_layout)) or key lengths (block num_layout). +// All threads in a warp cooperate via strided iteration over elements. +// Writes a CSRValidationError code to *error_flag if any check fails. +__global__ void ValidateCSRIndicesKernel( + const int32_t* csr_row_indices, + const int32_t* csr_col_indices, + const int32_t* seqlens_k_total, + int max_blocks, + int col_count, + int num_layout, + int batch_size, + int sequence_length, + int total_sequence_length, + int32_t* error_flag) { + int block_id = blockIdx.x; + int tid = threadIdx.x; + int num_threads = blockDim.x; + + if (block_id < num_layout) { + // Validate CSR indices for this layout. + const int stride_row = max_blocks + 1; + const int32_t* r = csr_row_indices + block_id * stride_row; + + // Phase 1: thread 0 validates all row pointers sequentially. + // Row arrays are small (max_blocks+1 elements), so single-thread scan is sufficient. + // All threads must reach __syncthreads before proceeding to col validation. + __shared__ int row_valid; + if (tid == 0) { + row_valid = 1; + if (r[0] != 0) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationRowFirstNotZero); + row_valid = 0; + } else { + for (int i = 0; i < max_blocks; ++i) { + if (r[i] < 0 || r[i] > r[i + 1] || r[i + 1] > col_count) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationRowNonMonotonic); + row_valid = 0; + break; + } + } + } + } + __syncthreads(); + if (!row_valid) return; + + // Phase 2: row pointers are validated, r[max_blocks] is safe to use as NNZ bound. + // All threads cooperate on the potentially larger col-index array. + const int nnz = r[max_blocks]; + const int32_t* c = csr_col_indices + block_id * col_count; + for (int i = tid; i < nnz; i += num_threads) { + if (c[i] < 0 || c[i] >= max_blocks) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationColOutOfRange); + return; + } + } + } else if (block_id == num_layout) { + // Validate key lengths. All threads cooperate in strided fashion. + bool is_prompt = (sequence_length == total_sequence_length); + int min_key_length = is_prompt ? 1 : sequence_length; + for (int i = tid; i < batch_size; i += num_threads) { + int key_length = seqlens_k_total[i]; + if (key_length < min_key_length || key_length > total_sequence_length) { + atomicCAS(error_flag, kCSRValidationOk, kCSRValidationKeyLengthOutOfRange); + return; + } + } + } +} + +Status ValidateCSRIndicesOnDevice( + cudaStream_t stream, + const int32_t* csr_row_indices, + const int32_t* csr_col_indices, + const int32_t* seqlens_k_total, + int num_layout, + int max_blocks, + int col_count, + int batch_size, + int sequence_length, + int total_sequence_length, + int32_t* d_error_flag) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(d_error_flag, 0, sizeof(int32_t), stream)); + + // Launch num_layout blocks for CSR validation + 1 block for key-length validation. + // Each block uses a full warp (32 threads) with strided iteration over elements. + ValidateCSRIndicesKernel<<>>( + csr_row_indices, csr_col_indices, seqlens_k_total, + max_blocks, col_count, num_layout, + batch_size, sequence_length, total_sequence_length, + d_error_flag); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + // Copy error flag back to host. + int32_t h_error_flag = 0; + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(&h_error_flag, d_error_flag, sizeof(int32_t), + cudaMemcpyDeviceToHost, stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + if (h_error_flag != kCSRValidationOk) { + const char* msg = (h_error_flag == kCSRValidationRowFirstNotZero) + ? "block_row_indices first element must be 0 for all layouts" + : (h_error_flag == kCSRValidationRowNonMonotonic) + ? "block_row_indices values are not monotonically non-decreasing or exceed " + "block_col_indices columns" + : (h_error_flag == kCSRValidationColOutOfRange) + ? "block_col_indices value is out of valid range" + : "key_total_sequence_lengths value is out of valid range"; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, msg); + } + + return Status::OK(); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h index d4f686afe5db0..f5b24d99b10a9 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.h @@ -68,6 +68,31 @@ Status QkvToContext( contrib::SparseAttentionParameters& parameters, SparseAttentionData& data); +// Error codes returned by the CSR validation kernel via the device error flag. +enum CSRValidationError : int32_t { + kCSRValidationOk = 0, + kCSRValidationRowFirstNotZero = 1, + kCSRValidationRowNonMonotonic = 2, + kCSRValidationColOutOfRange = 3, + kCSRValidationKeyLengthOutOfRange = 4, +}; + +// Validate CSR row-pointer monotonicity, column-index range, and key lengths on device. +// Returns Status::OK() if valid, or INVALID_ARGUMENT with a description of the failure. +// d_error_flag must point to a device-allocated int32_t scratch buffer (1 element). +Status ValidateCSRIndicesOnDevice( + cudaStream_t stream, + const int32_t* csr_row_indices, // device pointer, shape [num_layout, max_blocks + 1] + const int32_t* csr_col_indices, // device pointer, shape [num_layout, col_count] + const int32_t* seqlens_k_total, // device pointer, shape [batch_size] + int num_layout, + int max_blocks, + int col_count, + int batch_size, + int sequence_length, + int total_sequence_length, + int32_t* d_error_flag); // device scratch buffer (1 int32_t) + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc b/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc index 71d8c34353f02..d7953442d738e 100644 --- a/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/sparse_attention_op_test.cc @@ -24,10 +24,10 @@ namespace test { namespace { -void RunSparseAttentionInvalidInputTest(const std::vector& total_key_lengths_data, - const std::vector& total_key_lengths_dims, - const std::string& expected_error, - int32_t total_sequence_length = 4) { +void RunSparseAttentionInvalidKeyLengthsTest(const std::vector& total_key_lengths_data, + const std::vector& total_key_lengths_dims, + const std::string& expected_error, + int32_t total_sequence_length = 4) { OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); test.AddAttribute("num_heads", 2); test.AddAttribute("kv_num_heads", 2); @@ -42,7 +42,7 @@ void RunSparseAttentionInvalidInputTest(const std::vector& total_key_le test.AddInput("past_key", {1, 2, 4, 8}, std::vector(64, 0.0f)); test.AddInput("past_value", {1, 2, 4, 8}, std::vector(64, 0.0f)); test.AddInput("block_row_indices", {1, 5}, {0, 1, 2, 3, 4}); - test.AddInput("block_col_indices", {1, 1}, {0}); + test.AddInput("block_col_indices", {1, 4}, {0, 1, 2, 3}); test.AddInput("total_sequence_length", {1}, {total_sequence_length}); test.AddInput("key_total_sequence_lengths", total_key_lengths_dims, total_key_lengths_data); test.AddOptionalInputEdge(); @@ -209,17 +209,17 @@ void RunSparseAttentionPromptInputTest(const std::vector& total_key_len } // namespace TEST(SparseAttentionTest, RejectsOutOfRangeKeyTotalSequenceLengths) { - RunSparseAttentionInvalidInputTest({-5}, {1}, "key_total_sequence_lengths value -5 at batch index 0 is out of range [1, 4]"); + RunSparseAttentionInvalidKeyLengthsTest({-5}, {1}, "key_total_sequence_lengths value -5 at batch index 0 is out of range [1, 4]"); } TEST(SparseAttentionTest, RejectsKeyTotalSequenceLengthsShapeMismatch) { - RunSparseAttentionInvalidInputTest({4, 4}, {2}, "key_total_sequence_lengths must have shape (batch_size)"); + RunSparseAttentionInvalidKeyLengthsTest({4, 4}, {2}, "key_total_sequence_lengths must have shape (batch_size)"); } TEST(SparseAttentionTest, RejectsPromptKeyTotalSequenceLengthsShorterThanSequenceLength) { - RunSparseAttentionInvalidInputTest({0}, {1}, - "key_total_sequence_lengths value 0 at batch index 0 is out of range [1, 1]", - 1); + RunSparseAttentionInvalidKeyLengthsTest({0}, {1}, + "key_total_sequence_lengths value 0 at batch index 0 is out of range [1, 1]", + 1); } TEST(SparseAttentionTest, AcceptsPromptKeyTotalSequenceLengthsForPaddedBatch) { @@ -258,5 +258,278 @@ TEST(SparseAttentionTest, RejectsZeroDimBlockRowIndices) { {}, nullptr, &execution_providers); } +// Helper for CSR value-validation tests. +// Uses: num_heads=2, kv_num_heads=2, sparse_block_size=16, head_size=8. +// block_row_indices shape: (1, max_blocks+1), block_col_indices shape: (1, col_count). +// max_sequence_length = max_blocks * 16 must be >= total_sequence_length. +// These tests validate that element values in block_row_indices and block_col_indices are checked. +// Note: these tests expect failure via a returned Status (ORT_MAKE_STATUS), so they are safe in +// both exceptions-enabled and no-exceptions builds. +static void RunSparseAttentionCSRValidationTest( + const std::vector& block_row_indices_data, + const std::vector& block_row_indices_dims, + const std::vector& block_col_indices_data, + const std::vector& block_col_indices_dims, + const std::string& expected_error) { + OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); + test.AddAttribute("num_heads", 2); + test.AddAttribute("kv_num_heads", 2); + test.AddAttribute("sparse_block_size", 16); + test.AddAttribute("scale", 1.0f); + test.AddAttribute("do_rotary", 0); + test.AddAttribute("rotary_interleaved", 0); + + // head_size=8, num_heads=2 => hidden_size=16 + // sequence_length=1, batch_size=1 + test.AddInput("query", {1, 1, 16}, std::vector(16, 0.0f)); + test.AddInput("key", {1, 1, 16}, std::vector(16, 0.0f)); + test.AddInput("value", {1, 1, 16}, std::vector(16, 0.0f)); + // past_key/value: (batch_size=1, kv_num_heads=2, max_cache_seq_len=32, head_size=8) + test.AddInput("past_key", {1, 2, 32, 8}, std::vector(512, 0.0f)); + test.AddInput("past_value", {1, 2, 32, 8}, std::vector(512, 0.0f)); + test.AddInput("block_row_indices", block_row_indices_dims, block_row_indices_data); + test.AddInput("block_col_indices", block_col_indices_dims, block_col_indices_data); + test.AddInput("total_sequence_length", {1}, {2}); + test.AddInput("key_total_sequence_lengths", {1}, {2}); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + test.AddOutput("output", {1, 1, 16}, std::vector(16, 0.0f)); + test.AddOutput("present_key", {1, 2, 32, 8}, std::vector(512, 0.0f)); + test.AddOutput("present_value", {1, 2, 32, 8}, std::vector(512, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, expected_error, {}, nullptr, &execution_providers); +} + +// block_row_indices[0][0] must be 0. +TEST(SparseAttentionTest, RejectsBlockRowIndicesFirstElementNonZero) { + // shape (1, 3) => max_blocks=2, max_sequence_length=32 + RunSparseAttentionCSRValidationTest( + {1, 1, 2}, {1, 3}, // row indices: first element is 1, not 0 + {0, 1}, {1, 2}, // col indices: valid + "block_row_indices[0][0] must be 0"); +} + +// block_row_indices must be monotonically non-decreasing. +TEST(SparseAttentionTest, RejectsBlockRowIndicesNonMonotonic) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 2, 1}, {1, 3}, // row indices: 2 > 1 at row 1 (non-monotonic) + {0, 1}, {1, 2}, // col indices: valid + "block_row_indices values are not monotonically non-decreasing"); +} + +// block_row_indices values must not exceed block_col_indices column count. +TEST(SparseAttentionTest, RejectsBlockRowIndicesExceedsColCount) { + // shape (1, 3) => max_blocks=2, col_count=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 3}, {1, 3}, // row indices: last element 3 > col_count=2 + {0, 1}, {1, 2}, // col indices shape (1, 2) + "block_row_indices values are not monotonically non-decreasing"); +} + +// block_col_indices values must be in [0, max_blocks). +TEST(SparseAttentionTest, RejectsBlockColIndicesOutOfRange) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 2}, {1, 3}, // row indices: valid + {0, 2}, {1, 2}, // col indices: value 2 >= max_blocks=2 + "block_col_indices[0][1]=2 is out of valid range [0, 2)"); +} + +// block_col_indices negative values must be rejected. +TEST(SparseAttentionTest, RejectsBlockColIndicesNegative) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 2}, {1, 3}, // row indices: valid + {0, -1}, {1, 2}, // col indices: negative value + "block_col_indices[0][1]=-1 is out of valid range [0, 2)"); +} + +// block_row_indices with negative values. +TEST(SparseAttentionTest, RejectsBlockRowIndicesNegative) { + RunSparseAttentionCSRValidationTest( + {0, -1, 2}, {1, 3}, // row indices: negative value at index 1 + {0, 1}, {1, 2}, // col indices: valid + "block_row_indices values are not monotonically non-decreasing"); +} + +// block_col_indices with large OOB value (the original vulnerability scenario). +TEST(SparseAttentionTest, RejectsBlockColIndicesLargeValue) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 2, 2}, {1, 3}, // row indices: valid CSR format + {0, 1048576}, {1, 2}, // col indices: 0x100000 far out of range + "block_col_indices[0][1]=1048576 is out of valid range [0, 2)"); +} + +// Multi-layout: invalid col index in second layout only. +TEST(SparseAttentionTest, RejectsBlockColIndicesInvalidInSecondLayout) { + // shape (2, 3) => num_layout=2, max_blocks=2 + // num_heads=2, so num_heads % num_layout == 0 + RunSparseAttentionCSRValidationTest( + {0, 1, 2, 0, 1, 2}, {2, 3}, // row indices: valid for both layouts + {0, 1, 0, 5}, {2, 2}, // col indices: layout 0 valid, layout 1 has 5 >= max_blocks=2 + "block_col_indices[1][1]=5 is out of valid range [0, 2)"); +} + +// Multi-layout: invalid row pointer in second layout only. +TEST(SparseAttentionTest, RejectsBlockRowIndicesInvalidInSecondLayout) { + // shape (2, 3) => num_layout=2, max_blocks=2 + RunSparseAttentionCSRValidationTest( + {0, 1, 2, 1, 1, 2}, {2, 3}, // row indices: layout 0 valid, layout 1 starts with 1 != 0 + {0, 1, 0, 1}, {2, 2}, // col indices: valid + "block_row_indices[1][0] must be 0"); +} + +// Col index invalid within NNZ range but padding would be fine. +// row pointers say nnz=1, col[0] is invalid, col[1] is padding (not checked). +TEST(SparseAttentionTest, RejectsBlockColIndicesInvalidWithinNNZ) { + // shape (1, 3) => max_blocks=2, row indices: {0, 1, 1} means row 0 has 1 entry, row 1 has 0 + // nnz = r[max_blocks] = r[2] = 1, so only col[0] is validated + RunSparseAttentionCSRValidationTest( + {0, 1, 1}, {1, 3}, // row indices: valid, nnz=1 + {99, 0}, {1, 2}, // col[0]=99 is out of range, col[1]=0 is padding (not checked) + "block_col_indices[0][0]=99 is out of valid range [0, 2)"); +} + +#if defined(USE_CUDA) +// CUDA-specific CSR validation tests. +// CUDA SparseAttention requires head_size=128, sparse_block_size=64, and MLFloat16 inputs. +// These tests verify that the device-side ValidateCSRIndicesOnDevice kernel correctly +// rejects invalid CSR indices. Error messages are less detailed than CPU (no per-element info) +// because the CUDA kernel reports via a single error code. +// Note: OpTester does not share past/present buffers (no IOBinding), but that is fine here +// because the CSR validation runs before the shared-buffer check in ComputeInternal. +// These tests expect failure from validation, not from compute. +static void RunSparseAttentionCudaCSRValidationTest( + const std::vector& block_row_indices_data, + const std::vector& block_row_indices_dims, + const std::vector& block_col_indices_data, + const std::vector& block_col_indices_dims, + const std::string& expected_error) { + OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); + test.AddAttribute("num_heads", 1); + test.AddAttribute("kv_num_heads", 1); + test.AddAttribute("sparse_block_size", 64); + test.AddAttribute("scale", 1.0f); + test.AddAttribute("do_rotary", 0); + test.AddAttribute("rotary_interleaved", 0); + + // head_size=128, num_heads=1 => hidden_size=128 + // sequence_length=1, batch_size=1 + const int64_t hidden_size = 128; + const int64_t max_cache_seq_len = 128; + test.AddInput("query", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("key", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("value", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("past_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddInput("past_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddInput("block_row_indices", block_row_indices_dims, block_row_indices_data); + test.AddInput("block_col_indices", block_col_indices_dims, block_col_indices_data); + test.AddInput("total_sequence_length", {1}, {2}); + test.AddInput("key_total_sequence_lengths", {1}, {2}); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + test.AddOutput("output", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + + // Run only on CUDA EP. CPU EP does not register MLFloat16 for SparseAttention with these params. + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, expected_error, {}, nullptr, &execution_providers); +} + +// CUDA: block_row_indices first element must be 0. +TEST(SparseAttentionTest, CudaRejectsBlockRowIndicesFirstElementNonZero) { + // shape (1, 3) => max_blocks=2 + RunSparseAttentionCudaCSRValidationTest( + {1, 1, 2}, {1, 3}, + {0, 1}, {1, 2}, + "block_row_indices first element must be 0 for all layouts"); +} + +// CUDA: block_row_indices must be monotonically non-decreasing. +TEST(SparseAttentionTest, CudaRejectsBlockRowIndicesNonMonotonic) { + RunSparseAttentionCudaCSRValidationTest( + {0, 2, 1}, {1, 3}, + {0, 1}, {1, 2}, + "block_row_indices values are not monotonically non-decreasing"); +} + +// CUDA: block_col_indices values must be in range. +TEST(SparseAttentionTest, CudaRejectsBlockColIndicesOutOfRange) { + RunSparseAttentionCudaCSRValidationTest( + {0, 1, 2}, {1, 3}, + {0, 99}, {1, 2}, + "block_col_indices value is out of valid range"); +} + +// CUDA: block_col_indices with large OOB value. +TEST(SparseAttentionTest, CudaRejectsBlockColIndicesLargeValue) { + RunSparseAttentionCudaCSRValidationTest( + {0, 2, 2}, {1, 3}, + {0, 1048576}, {1, 2}, + "block_col_indices value is out of valid range"); +} + +// CUDA: key_total_sequence_lengths out of range. +TEST(SparseAttentionTest, CudaRejectsKeyLengthOutOfRange) { + OpTester test("SparseAttention", 1, onnxruntime::kMSDomain); + test.AddAttribute("num_heads", 1); + test.AddAttribute("kv_num_heads", 1); + test.AddAttribute("sparse_block_size", 64); + test.AddAttribute("scale", 1.0f); + test.AddAttribute("do_rotary", 0); + test.AddAttribute("rotary_interleaved", 0); + + const int64_t hidden_size = 128; + const int64_t max_cache_seq_len = 128; + test.AddInput("query", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("key", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("value", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddInput("past_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddInput("past_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + // Valid CSR: shape (1, 3) => max_blocks=2 + test.AddInput("block_row_indices", {1, 3}, {0, 1, 2}); + test.AddInput("block_col_indices", {1, 2}, {0, 1}); + test.AddInput("total_sequence_length", {1}, {4}); + // Invalid key length: -5 is out of range [1, 4] + test.AddInput("key_total_sequence_lengths", {1}, {-5}); + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + test.AddOutput("output", {1, 1, hidden_size}, + std::vector(hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_key", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + test.AddOutput("present_value", {1, 1, max_cache_seq_len, hidden_size}, + std::vector(max_cache_seq_len * hidden_size, MLFloat16(0.0f))); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "key_total_sequence_lengths value is out of valid range", + {}, nullptr, &execution_providers); +} +#endif // USE_CUDA + } // namespace test } // namespace onnxruntime From 37a3b51c80fb0f69cc1400f66619d6caacedb963 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Fri, 26 Jun 2026 15:23:34 -0700 Subject: [PATCH 13/19] Fix bounds in WhisperDecoderSubgraph::CreateInitialFeeds initial feeds (#29239) ### Description `WhisperDecoderSubgraph::CreateInitialFeeds` constructed the decoder initial feeds using a single value that mixed a **byte count** with an **element count**. The total size was computed as `cur_len * batch_beam_size * sizeof(int)` (bytes) and then reused as: - the element count for the int32 staging buffer (`MakeUniquePtr`), and - the element count for the `gsl::span` source/destination passed to the device copy. Because the `input_ids` tensor is allocated for exactly `batch_beam_size * cur_len` int32 elements, the spans claimed 4x the real extent, so the device copy ran past the end of the buffer. The per-beam `memcpy` also used the same combined value as its length instead of a single sequence's byte size. This mirrors the correct T5 sibling (`subgraph_t5_decoder.cc`), which separates the element count (used for the spans/staging allocation) from the per-sequence byte count (used for the `memcpy`). ### Changes - `subgraph_whisper_decoder.cc`: `total_size` is now the element count `cur_len * batch_beam_size`; introduced `sequence_bytes = cur_len * sizeof(int32_t)` for the per-beam `memcpy`. The staging buffer and spans use `int32_t` consistently to match the `int32_t` tensors/sequences. - Added regression test `BeamSearchTest.DummyWhisperWithSequenceInputIds` (CPU, and CUDA under `USE_CUDA`) exercising the `use_sequence_as_input_ids` path, with a deterministic dummy model and its generator script. The test validates both the `sequences` and `scores` outputs. ### Related bool-tensor normalization fixes While exercising the Whisper path, bool tensors copied from raw data could hold non-canonical byte values (anything non-zero rather than strictly `{0, 1}`), causing provider-dependent behavior. To keep the fix self-contained, the following normalization changes are included: - `tensorprotoutils.cc`: `UnpackTensor` normalizes raw-data bytes to `{0, 1}` (with a `static_assert(sizeof(bool) == 1)` guarding the byte-wise loop). - `compress_impl.cu` (CUDA `Compress`): the prefix-sum sizing predicate normalizes bool bytes to `{0, 1}` so the output sizing agrees with the element-selection truthiness check. Since bool initializers are now normalized on unpack, the remaining exposure is runtime-produced bool condition tensors. - Added `CompressTest.Compress_cuda_non_canonical_bool_condition` (under `USE_CUDA`), which feeds a raw `0xFF` condition byte through a session-level run (`OpTester` normalizes bool inputs and so cannot reproduce this) and asserts the Compress output is sized by truthiness rather than by the sign-extended byte value. ### Motivation The decoder shares one implementation file across CPU/CUDA/ROCm, so this single change covers all execution providers. The previous behavior could overrun the staging/feed buffers for models that drive the sequence-as-input-ids decoder path. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../transformers/subgraph_whisper_decoder.cc | 15 +- .../core/framework/tensorprotoutils.cc | 12 +- .../providers/cuda/tensor/compress_impl.cu | 7 +- .../test/contrib_ops/beam_search_test.cc | 23 ++ .../test/framework/tensorutils_test.cc | 24 ++ .../providers/cpu/tensor/compress_op.test.cc | 84 +++++ .../testdata/dummy_whisper_model_generator.py | 328 ++++++++++++++++++ ...dummy_whisper_with_sequence_input_ids.onnx | Bin 0 -> 7072 bytes 8 files changed, 485 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/test/testdata/dummy_whisper_model_generator.py create mode 100644 onnxruntime/test/testdata/dummy_whisper_with_sequence_input_ids.onnx diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index ad778fb7ef907..1fe4ee65e5f31 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -168,9 +168,12 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = CPUAllocator::DefaultInstance(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); - int* seq_copy_ptr = seq_copy.get(); + // total_size is an element count: it sizes the int32 staging buffer and the copy spans. + // sequence_bytes is the byte count for copying a single beam's sequence per iteration. + size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); + size_t sequence_bytes = static_cast(cur_len) * sizeof(int32_t); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); + int32_t* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { ORT_RETURN_IF_ERROR(device_copy_int32_func( @@ -183,10 +186,10 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( gsl::span sequence = sequences.GetSequence(i); const int32_t* sequence_data = sequence.data(); long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + memcpy(seq_copy_ptr + seq_index, sequence_data, sequence_bytes); } - gsl::span temp_input(input_ids_data, total_size); - gsl::span temp_sequence(seq_copy_ptr, total_size); + gsl::span temp_input(input_ids_data, total_size); + gsl::span temp_sequence(seq_copy_ptr, total_size); ORT_RETURN_IF_ERROR(device_copy_int32_func( temp_input, temp_sequence, diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 8685655420a38..e6df3a9923ad5 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -907,7 +907,17 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d } if (raw_data != nullptr) { - return UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data); + ORT_RETURN_IF_ERROR(UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data)); + // raw_data is copied verbatim and may contain bytes outside the canonical {0, 1} set. + // Consumers rely on bool tensors holding {0, 1}; normalize any non-zero byte to 1 so every + // reader observes a single, consistent value. Operate on the byte representation to avoid + // loading a bool object that does not yet hold a valid value. + auto* bool_bytes = reinterpret_cast(p_data); + static_assert(sizeof(bool) == 1, "Normalization loop writes expected_size bytes assuming 1 byte per bool element"); + for (size_t i = 0; i < expected_size; ++i) { + bool_bytes[i] = bool_bytes[i] != 0 ? 1 : 0; + } + return Status::OK(); } if (static_cast(tensor.int32_data_size()) != expected_size) diff --git a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu index b06a640fb72a1..9abbb93d228f4 100644 --- a/onnxruntime/core/providers/cuda/tensor/compress_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/compress_impl.cu @@ -22,7 +22,12 @@ namespace cuda { // see https://github.com/NVIDIA/cub/issues/384 struct CastToInt32 { __host__ __device__ int32_t operator()(int8_t v) const { - return static_cast(v); + // Normalize to {0, 1} so the prefix-sum sizing path agrees with the truthiness predicate + // (condition_data[div]) used in _CompressKernel. A bool byte may hold any non-zero value; + // sign-extending it here would size the output differently from how elements are selected. + // bool initializers are normalized to {0, 1} when unpacked (see tensorprotoutils.cc), so the + // remaining source of non-canonical bytes is runtime-produced bool condition tensors. + return v != 0 ? 1 : 0; } }; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index d2001cfb9f2bd..e9e7e20271090 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -446,6 +446,29 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { tester.RunWithConfig(); } +TEST(BeamSearchTest, DummyWhisperWithSequenceInputIds) { + // dummy_whisper_with_sequence_input_ids.onnx model generated using following command: + // python onnxruntime/test/testdata/dummy_whisper_model_generator.py \ + // --output-path dummy_whisper_with_sequence_input_ids.onnx --sequence-as-input + // The decoder subgraph leaves input_ids second dim symbolic, so the decoder feeds are built from the + // running sequence (use_sequence_as_input_ids_ == true), exercising the multi-token initial feed path. + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_whisper_with_sequence_input_ids.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); + tester.AddInput("input_features", {1, 8, 5}, + {-0.3f, -0.2f, -0.1f, 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, -0.1f, + 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, -0.1f, 0.0f, 0.1f, 0.2f, + 0.3f, -0.3f, -0.2f, -0.1f, 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, + -0.1f, 0.0f, 0.1f, 0.2f, 0.3f, -0.3f, -0.2f, -0.1f, 0.0f, 0.1f}); + tester.AddInput("decoder_input_ids", {1, 2}, {2, 5}); + tester.AddOutput("sequences", {1, 1, 10}, {2, 5, 1, 1, 1, 1, 1, 1, 1, 1}); + tester.AddOutput("scores", {1, 1}, {-0.05625312775373459f}, false /* sort_output */, 1e-4f /* rel_error */, + 1e-4f /* abs_error */); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + TEST(BeamSearchTest, DummyT5PointerGenerator) { // dummy_t5_pointer_generator.onnx model generated using following command: // python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 38a0f15f2301d..34575f481d3cc 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -225,6 +225,30 @@ TEST(TensorProtoUtilsTest, UnpackTensor) { EXPECT_FALSE(status.IsOK()); } +// A bool initializer supplied through raw_data is copied verbatim, so its bytes are not +// restricted to {0, 1}. UnpackTensor must normalize them so downstream consumers (which assume +// canonical bool values) all observe the same result regardless of how they read the byte. +TEST(TensorProtoUtilsTest, UnpackBoolTensorWithRawDataNormalizesToZeroOne) { + std::filesystem::path model_path; + TensorProto bool_tensor_proto; + bool_tensor_proto.set_data_type(TensorProto_DataType_BOOL); + bool_tensor_proto.add_dims(4); + + // Bytes outside {0, 1}: 0x00 -> 0, 0x01 -> 1, 0x02 -> 1, 0xFF -> 1. + const unsigned char raw_bytes[] = {0x00, 0x01, 0x02, 0xFF}; + bool_tensor_proto.set_raw_data(std::string(reinterpret_cast(raw_bytes), sizeof(raw_bytes))); + + bool bool_data[4]; + auto status = UnpackTensor(bool_tensor_proto, model_path, bool_data, 4); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + const auto* bytes = reinterpret_cast(bool_data); + EXPECT_EQ(bytes[0], 0); + EXPECT_EQ(bytes[1], 1); + EXPECT_EQ(bytes[2], 1); + EXPECT_EQ(bytes[3], 1); +} + namespace { template std::vector CreateValues() { diff --git a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc index c3d91100605e9..495fea5735b32 100644 --- a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc @@ -4,6 +4,14 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#ifdef USE_CUDA +#include "core/graph/model.h" +#include "core/session/inference_session.h" +#include "test/test_environment.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/util/include/default_providers.h" +#endif + namespace onnxruntime { namespace test { @@ -148,5 +156,81 @@ TEST(CompressTest, Compress_3dims_neg_axis) { test.Run(); } +#ifdef USE_CUDA +// Regression test for the CUDA Compress prefix-sum sizing path. A bool condition byte may hold a +// non-canonical value (e.g. 0xFF) at runtime — initializers are normalized to {0, 1} on unpack, +// but runtime-produced bool tensors are not. Without normalizing the byte before the prefix sum, +// 0xFF would be summed as 255 (sizing the output for 255 selected elements) while _CompressKernel +// selects it as a single element via truthiness, so sizing and selection would disagree. This +// test feeds a raw 0xFF condition byte (which OpTester cannot produce, since it normalizes bool +// inputs to {0, 1}) and asserts the output is sized by truthiness. +TEST(CompressTest, Compress_cuda_non_canonical_bool_condition) { + // Build: output = Compress(input, condition, axis=0) + auto model = std::make_unique("compress_non_canonical_bool", false, + DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + + ONNX_NAMESPACE::TypeProto tensor_float; + tensor_float.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ONNX_NAMESPACE::TypeProto tensor_bool; + tensor_bool.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + + auto& input_arg = graph.GetOrCreateNodeArg("input", &tensor_float); + auto& condition_arg = graph.GetOrCreateNodeArg("condition", &tensor_bool); + auto& output_arg = graph.GetOrCreateNodeArg("output", &tensor_float); + + std::vector input_defs{&input_arg, &condition_arg}; + std::vector output_defs{&output_arg}; + auto& node = graph.AddNode("compress", "Compress", "Compress", input_defs, output_defs, nullptr, + onnxruntime::kOnnxDomain); + node.AddAttribute("axis", static_cast(0)); + ASSERT_STATUS_OK(graph.Resolve()); + + SessionOptions so; + so.session_logid = "CompressTest.Compress_cuda_non_canonical_bool_condition"; + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); + + std::string serialized_model; + ASSERT_TRUE(model->ToProto().SerializeToString(&serialized_model)); + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + ASSERT_STATUS_OK(session_object.Initialize()); + + AllocatorPtr cpu_allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + + OrtValue input_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({3, 2}), cpu_allocator, input_value); + const float input_data[6] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + memcpy(input_value.GetMutable()->MutableData(), input_data, sizeof(input_data)); + + // Condition {false, non-canonical-true, true}: write a raw 0xFF byte for the middle element to + // emulate a runtime-produced bool tensor outside the canonical {0, 1} set. + OrtValue condition_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({3}), cpu_allocator, condition_value); + auto* condition_bytes = + reinterpret_cast(condition_value.GetMutable()->MutableDataRaw()); + condition_bytes[0] = 0x00; + condition_bytes[1] = 0xFF; + condition_bytes[2] = 0x01; + + std::vector fetches; + ASSERT_STATUS_OK(session_object.Run( + std::unordered_map{{"input", input_value}, {"condition", condition_value}}, + std::vector{"output"}, &fetches)); + + ASSERT_EQ(fetches.size(), 1u); + const Tensor& output = fetches[0].Get(); + // Two non-zero condition bytes select two rows along axis 0 (not 256). + EXPECT_EQ(output.Shape(), TensorShape({2, 2})); + const auto output_span = output.DataAsSpan(); + const std::vector expected{3.0f, 4.0f, 5.0f, 6.0f}; + ASSERT_EQ(output_span.size(), expected.size()); + for (size_t i = 0; i < expected.size(); ++i) { + EXPECT_EQ(output_span[i], expected[i]); + } +} +#endif // USE_CUDA + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/dummy_whisper_model_generator.py b/onnxruntime/test/testdata/dummy_whisper_model_generator.py new file mode 100644 index 0000000000000..a90518255193b --- /dev/null +++ b/onnxruntime/test/testdata/dummy_whisper_model_generator.py @@ -0,0 +1,328 @@ +"""Script to generate a dummy ONNX model emulating a Whisper model with BeamSearch op. + +The model is intentionally tiny and produces deterministic (but meaningless) outputs. +Its only purpose is to exercise the WhisperBeamSearch encoder/decoder subgraph plumbing, +in particular the decoder "use sequence as input ids" path that builds the initial decoder +feeds from the full running sequences. +""" + +import argparse + +import numpy as np +import onnx + + +def create_model( + vocab_size: int, + embed_dim: int, + num_heads: int, + head_size: int, + feature_size: int, + beam_size: int, + min_length: int, + max_length: int, + length_penalty: float, + sequence_as_input: bool, +) -> onnx.ModelProto: + encoder_graph = create_encoder(vocab_size, embed_dim, num_heads, head_size, feature_size) + decoder_graph = create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input) + + # Top-level inputs: input_features (audio) and decoder_input_ids (initial transcript tokens). + input_features = onnx.helper.make_tensor_value_info( + "input_features", onnx.TensorProto.FLOAT, ["batch_size", feature_size, "encode_sequence_length"] + ) + decoder_input_ids = onnx.helper.make_tensor_value_info( + "decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "initial_decode_sequence_length"] + ) + + # Outputs: sequences, scores + sequences = onnx.helper.make_tensor_value_info( + "sequences", onnx.TensorProto.INT32, ["batch_size", "num_return_sequences", "decode_sequence_length"] + ) + scores = onnx.helper.make_tensor_value_info( + "scores", onnx.TensorProto.FLOAT, ["batch_size", "num_return_sequences"] + ) + + # Initializers for the BeamSearch parameters. + max_length_t = onnx.numpy_helper.from_array(np.array(max_length, dtype=np.int32), name="max_length") + min_length_t = onnx.numpy_helper.from_array(np.array(min_length, dtype=np.int32), name="min_length") + num_beams_t = onnx.numpy_helper.from_array(np.array(beam_size, dtype=np.int32), name="num_beams") + num_return_sequences_t = onnx.numpy_helper.from_array(np.array(1, dtype=np.int32), name="num_return_sequences") + length_penalty_t = onnx.numpy_helper.from_array( + np.array(length_penalty, dtype=np.float32), name="length_penalty_as_tensor" + ) + + # The Whisper BeamSearch op expects decoder_input_ids at input index 10. The intervening + # optional inputs (repetition_penalty, vocab_mask, prefix_vocab_mask, attention_mask) are + # left empty. + beam_search = onnx.helper.make_node( + "BeamSearch", + [ + "input_features", + "max_length", + "min_length", + "num_beams", + "num_return_sequences", + "length_penalty_as_tensor", + "", + "", + "", + "", + "decoder_input_ids", + ], + ["sequences", "scores"], + decoder_start_token_id=2, + eos_token_id=2, + early_stopping=0, + model_type=2, + pad_token_id=1, + decoder=decoder_graph, + encoder=encoder_graph, + domain="com.microsoft", + ) + + graph = onnx.helper.make_graph( + [beam_search], + "model", + [input_features, decoder_input_ids], + [sequences, scores], + [max_length_t, min_length_t, num_beams_t, num_return_sequences_t, length_penalty_t], + ) + + model = onnx.helper.make_model( + graph, opset_imports=[onnx.helper.make_opsetid("", 17), onnx.helper.make_opsetid("com.microsoft", 1)] + ) + + return model + + +def create_encoder(vocab_size, embed_dim, num_heads, head_size, feature_size) -> onnx.GraphProto: + # Inputs: encoder_input_ids (audio features, float), decoder_input_ids (int32) + encoder_input_ids = onnx.helper.make_tensor_value_info( + "encoder_input_ids", onnx.TensorProto.FLOAT, ["batch_size", feature_size, "encode_sequence_length"] + ) + decoder_input_ids = onnx.helper.make_tensor_value_info( + "decoder_input_ids", onnx.TensorProto.INT32, ["batch_size", "initial_decode_sequence_length"] + ) + + # Outputs: logits, encoder_hidden_states, present_key_self_0, present_value_self_0, + # present_key_cross_0, present_value_cross_0 + logits = onnx.helper.make_tensor_value_info( + "logits", onnx.TensorProto.FLOAT, ["batch_size", "initial_decode_sequence_length", vocab_size] + ) + encoder_hidden_states = onnx.helper.make_tensor_value_info( + "encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", embed_dim] + ) + present_key_self_0 = onnx.helper.make_tensor_value_info( + "present_key_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_value_self_0 = onnx.helper.make_tensor_value_info( + "present_value_self_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, 1, head_size] + ) + present_key_cross_0 = onnx.helper.make_tensor_value_info( + "present_key_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + present_value_cross_0 = onnx.helper.make_tensor_value_info( + "present_value_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ) + + # Initializers + feature_proj = onnx.numpy_helper.from_array( + np.random.randn(feature_size, embed_dim).astype(np.float32), name="feature_proj" + ) + decoder_embeddings = onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="encoder_decoder_embeddings" + ) + final_proj = onnx.numpy_helper.from_array( + np.random.randn(embed_dim, vocab_size).astype(np.float32), name="encoder_final_proj" + ) + num_heads_and_size = onnx.numpy_helper.from_array( + np.array([num_heads, head_size], dtype=np.int64), name="num_heads_and_size" + ) + self_state_shape = onnx.numpy_helper.from_array( + np.array([-1, 1, num_heads, head_size], dtype=np.int64), name="self_state_shape" + ) + + nodes = [ + # encoder_hidden_states = transpose(features)[B, Es, Fs] @ feature_proj[Fs, E] -> [B, Es, E] + onnx.helper.make_node("Transpose", ["encoder_input_ids"], ["features_t"], perm=[0, 2, 1]), + onnx.helper.make_node("MatMul", ["features_t", "feature_proj"], ["encoder_hidden_states"]), + # cross KV: reshape [B, Es, E] -> [B, Es, num_heads, head_size] -> transpose [B, num_heads, Es, head_size] + onnx.helper.make_node("Shape", ["encoder_hidden_states"], ["enc_batch_seq"], end=2), + onnx.helper.make_node("Concat", ["enc_batch_seq", "num_heads_and_size"], ["enc_cross_shape"], axis=0), + onnx.helper.make_node("Reshape", ["encoder_hidden_states", "enc_cross_shape"], ["enc_cross_reshaped"]), + onnx.helper.make_node("Transpose", ["enc_cross_reshaped"], ["present_key_cross_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["enc_cross_reshaped"], ["present_value_cross_0"], perm=[0, 2, 1, 3]), + # decoder hidden states from decoder_input_ids + onnx.helper.make_node("Gather", ["encoder_decoder_embeddings", "decoder_input_ids"], ["decoder_hidden_states"]), + # logits = decoder_hidden_states[B, Ds, E] @ final_proj[E, V] -> [B, Ds, V] + onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["enc_hidden_mean"], axes=[1]), + onnx.helper.make_node("Add", ["decoder_hidden_states", "enc_hidden_mean"], ["decoder_sum"]), + onnx.helper.make_node("MatMul", ["decoder_sum", "encoder_final_proj"], ["logits"]), + # self KV (length 1): reduce decoder hidden over Ds -> [B, 1, E] -> [B, 1, Hn, Hs] -> [B, Hn, 1, Hs] + onnx.helper.make_node("ReduceMean", ["decoder_sum"], ["self_hidden_mean"], axes=[1]), + onnx.helper.make_node("Reshape", ["self_hidden_mean", "self_state_shape"], ["self_state"]), + onnx.helper.make_node("Transpose", ["self_state"], ["present_key_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["self_state"], ["present_value_self_0"], perm=[0, 2, 1, 3]), + ] + + graph = onnx.helper.make_graph( + nodes, + "encoder", + [encoder_input_ids, decoder_input_ids], + [ + logits, + encoder_hidden_states, + present_key_self_0, + present_value_self_0, + present_key_cross_0, + present_value_cross_0, + ], + [feature_proj, decoder_embeddings, final_proj, num_heads_and_size, self_state_shape], + ) + return graph + + +def create_decoder(vocab_size, embed_dim, num_heads, head_size, sequence_as_input) -> onnx.GraphProto: + # Inputs: input_ids, encoder_hidden_states, past_key_self_0, past_value_self_0, + # past_key_cross_0, past_value_cross_0 + inputs = [ + onnx.helper.make_tensor_value_info( + "input_ids", onnx.TensorProto.INT32, ["batch_size", "decode_sequence_length" if sequence_as_input else 1] + ), + onnx.helper.make_tensor_value_info( + "encoder_hidden_states", onnx.TensorProto.FLOAT, ["batch_size", "encode_sequence_length", embed_dim] + ), + onnx.helper.make_tensor_value_info( + "past_key_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "past_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "past_value_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "past_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "past_key_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ), + onnx.helper.make_tensor_value_info( + "past_value_cross_0", onnx.TensorProto.FLOAT, ["batch_size", num_heads, "encode_sequence_length", head_size] + ), + ] + + outputs = [ + onnx.helper.make_tensor_value_info("logits", onnx.TensorProto.FLOAT, ["batch_size", 1, vocab_size]), + onnx.helper.make_tensor_value_info( + "present_key_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), + onnx.helper.make_tensor_value_info( + "present_value_self_0", + onnx.TensorProto.FLOAT, + ["batch_size", num_heads, "present_decode_sequence_length", head_size], + ), + ] + + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(vocab_size, embed_dim).astype(np.float32), name="decoder_embeddings" + ), + onnx.numpy_helper.from_array(np.random.randn(embed_dim, vocab_size).astype(np.float32), name="final_proj"), + onnx.numpy_helper.from_array( + np.array([-1, num_heads, head_size], dtype=np.int64), name="self_state_shape_no_batch" + ), + onnx.numpy_helper.from_array(np.array([-1, 1, embed_dim], dtype=np.int64), name="hidden_mean_shape"), + ] + + nodes = [ + onnx.helper.make_node("Gather", ["decoder_embeddings", "input_ids"], ["decoder_hidden_states"]), + # encoder signal from encoder_hidden_states mean -> [B, 1, E] + onnx.helper.make_node("ReduceMean", ["encoder_hidden_states"], ["enc_hidden_mean"], axes=[1]), + onnx.helper.make_node("Reshape", ["enc_hidden_mean", "hidden_mean_shape"], ["enc_hidden_mean_reshaped"]), + # reduce decoder hidden over the sequence dim -> [B, 1, E] + onnx.helper.make_node("ReduceMean", ["decoder_hidden_states"], ["decoder_hidden_mean"], axes=[1]), + onnx.helper.make_node("Add", ["decoder_hidden_mean", "enc_hidden_mean_reshaped"], ["decoder_sum"]), + onnx.helper.make_node("MatMul", ["decoder_sum", "final_proj"], ["logits"]), + # self KV for this step (length 1) concatenated with the running past + onnx.helper.make_node("Shape", ["decoder_sum"], ["decoder_batch"], end=1), + onnx.helper.make_node( + "Concat", ["decoder_batch", "self_state_shape_no_batch"], ["self_state_shape_dec"], axis=0 + ), + onnx.helper.make_node("Reshape", ["decoder_sum", "self_state_shape_dec"], ["self_state"]), + onnx.helper.make_node("Transpose", ["self_state"], ["single_key_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Transpose", ["self_state"], ["single_value_self_0"], perm=[0, 2, 1, 3]), + onnx.helper.make_node("Concat", ["past_key_self_0", "single_key_self_0"], ["present_key_self_0"], axis=2), + onnx.helper.make_node("Concat", ["past_value_self_0", "single_value_self_0"], ["present_value_self_0"], axis=2), + ] + + graph = onnx.helper.make_graph(nodes, "decoder", inputs, outputs, initializers) + return graph + + +def run_model(model_path, feature_size): + # Imported lazily so model *generation* only depends on `onnx`; running needs `onnxruntime`. + import onnxruntime as ort # noqa: PLC0415 + + ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + encode_length = 5 + # Fixed, deterministic inputs so a C++ regression test can reproduce the exact golden outputs. + input_features = (((np.arange(feature_size * encode_length, dtype=np.float32) % 7) - 3.0) * 0.1).reshape( + 1, feature_size, encode_length + ) + decoder_input_ids = np.array([[2, 5]], dtype=np.int32) + sequences, scores = ort_session.run( + None, {"input_features": input_features, "decoder_input_ids": decoder_input_ids} + ) + print("input_features (flat):", input_features.flatten().tolist()) + print("decoder_input_ids:", decoder_input_ids.tolist()) + print("sequences shape:", sequences.shape) + print("sequences:", sequences.tolist()) + print("scores:", scores.tolist()) + return sequences, scores + + +def arg_parser(): + parser = argparse.ArgumentParser(description="Generate a dummy ONNX model emulating Whisper with BeamSearch op.") + parser.add_argument("--output-path", type=str, default="dummy_whisper.onnx", help="Model output path") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--vocab-size", type=int, default=20, help="Vocab size") + parser.add_argument("--embed-dim", type=int, default=8, help="Embedding dimension") + parser.add_argument("--num-heads", type=int, default=2, help="Number of heads") + parser.add_argument("--head-size", type=int, default=4, help="Head size") + parser.add_argument("--feature-size", type=int, default=8, help="Encoder input feature size") + parser.add_argument("--beam-size", type=int, default=3, help="Beam size") + parser.add_argument("--min-length", type=int, default=1, help="Min length") + parser.add_argument("--max-length", type=int, default=10, help="Max length") + parser.add_argument("--length-penalty", type=float, default=1.1, help="Length penalty") + parser.add_argument("--sequence-as-input", action="store_true", help="Use sequence as input ids") + parser.add_argument( + "--no-run", + action="store_true", + help="Only generate and save the model; skip running it (avoids needing an onnxruntime install)", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = arg_parser() + np.random.seed(args.seed) + + model = create_model( + args.vocab_size, + args.embed_dim, + args.num_heads, + args.head_size, + args.feature_size, + args.beam_size, + args.min_length, + args.max_length, + args.length_penalty, + args.sequence_as_input, + ) + onnx.save(model, args.output_path) + + if not args.no_run: + run_model(args.output_path, args.feature_size) diff --git a/onnxruntime/test/testdata/dummy_whisper_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_whisper_with_sequence_input_ids.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ad37e9d265bd81b4e087f86acedad30109c3f65f GIT binary patch literal 7072 zcmc(kc~}!y_s1b@VFZOi0j;1?L5+2{C`;}Mh@dEofM`L9fkcBsNJ&7!4FshkiW}~t zQYdvnT>85FAorj}O5JM3DlV^8Y1OK@6|K^@ALeUKQg&<&Y3fJ&Yk=H z+{-t2T{4QP8o_jqRwxtH;+ZmuIw4-BVwhOTJaLRnF-tAyj-wSV$EJ#eSaGCG602gm zY7gUOoV`M036e<@qGW#Dv7i%U zD48cyiS&2sn=n31dt1h`m7`cCmndbdZM%3}Z#ABfI+z4C9Wi4W-N~|6?e5{v^YLJG z*)#78*VBjH@G6R^_Z45IrntOm^WqjOh);K0^R#70iMr`2M5}D|v|&fL$%E19$ue)_z3p}()2f-ZRB;s)BNNY&C2^I^B{tN-G$dZ4 zP$}b7GJz#ysFcOWiu6Q0k-o@4`}m=ayH&G$q9i6k_V)t#FhmPQBJ0h5|Yz)^r14dU`Y}b~Mtxa_XWm|@vV7n2&E1z$n@2O+h`KA~x4fVuJ zdOb*f!+sLs^b{KWRq&$N3R7Rp@Y0Q6$*kW$M-ux0e*9$voEi0$zV^9}2PYK4&2VZ+k>G-n594$gNquqjEAzf)9bnY_(SK5UWk=b3%^|6uE#MKUrA9=u| z^Cnm&x59N*S+pp~s;u(a2`qiFAN*Eb1b>6YRJhp+22a~TBF)B%-2AiQ(r-LEF;|WJ zo#*h>)Ic(Frw2xx9l>);9Z{gnM9;b_;CNU9`*-k4zfowwTa_#9wQw4iyf_KGm!>ek zJRSSq_oPev&%m5N4q$!vSSX)W2V)EmW0}1Lq^)+s6nQ=EQ)q)JgQw$m>Me9S9)gpS zGD&iUpXTg-HMI?u0eFICp-wsm zXBgkZz~Bpbyv|GbQA0YJkU9&`Cch5{_Fcm2kPOXw-c~HR>4ZHUJ%#38({Sz&<#gov z6d20Y)x#%~!F|hWgk2^OwDcGVa*X(VepjwOTfMWmx{MiOUlyW25%o%}pwGhq`kgT) z6}dg|_sIu^*P?sTJ4rd98nc9WUfV}9Z(pR#3hSxYynoZY7he;f6<4Tu`g#&kWhgS? ziDA{F_YX7G4&f#xZBW%c}2?3wrk*VP}UeeY~=Uz~FpjH3$458*RtZEy&XH1o1& zULTW~DPtkfQ-hz(Tnn=Tib(FxOZ4^ceSjZqDROdiC7FX3k+5M8>8$LJux51#1oRKb zpfziS>DhI(+WiGN`Q0AU^L#qQt>&TGpiJtyIvM53HT07~CkXGf2bfsJ;G*0JSh2JS zG(}aU+Vv@z+jwJj;0}DRs0JmA5SFC(gR*^58nda{#Mojrj>zaK$__Fl66Tih${81^ z_I^pOd&tQ0X};(ueMSb~eG12&_d#au8Q_1Rp*x;XDU=>JnX{e5N`vib9AOx_B5 zdh-Wm6QmKE9?3B@$8|Vt+Vcp;jaUs|?Ojgu^HyT*!M-As?08}ED5-nHu2ZEYmwm~o zDl3|HoyDJDcftw%jWsoPozP;TGuphmDx8tEgjO#pgN3^Xlr2880*)=-0kMY+%gWTn zbnzfcrw;f6|FHMLkNT~`nZCDR$0u0}3%#11VN)L(2?jxM7n> z)9xcp(szdWi8Unq!ZG?#x*o>O$ix>#cCg?(8~A=fF1_FI8ClqOCz`H|f!{)fbkMg( z_+;!dxEU}K6UTjl0vEo4g^B08r#fHTKbk*0y{3<0^RHl7qiaV=@zzDZ*_7Y>8|Kzp z0S#t8;2X3388en?oK?lD=wumV%UU;0)o=7#v#bl7@GeO-oEgg3XFuo|s7XP!55(~J z{4oCxlN;;CxNu=UY!`;d=P`EddyP|rwhGBs^X19M9(U-A!GyO)z$Ej4}p{z4A zl1p)@jug3Wd*|I6hD~JHR&(z^j27v{G>a{^C)0z=$GSs4`dpuMjSS(g+Cdz!gmLG5 zMtAVxF}+#)7W;SibYv)_Qy=`(xI6WY1a?ew?Zs26mF`YTB zAH(S{b(~Vk$+1R-C#dbtm^Af${P3II88^nV`Clt|U|OI>tZpsay0n&UTG8ghY1(GI z5vv)q=Hn)Xs~ZU32wscmWfG}MEKx`s2WAs^ zG?h*xAaoF9jdGV`g_gT-!!1oL(#oA>wP`gqbVLpxM#pK^s---%BDlGVTX&}p6+B*N zpl&UJniS^W3+2!BXt}kf@{AX^LOUJAsGFMpmx$;34l?UNemYocSX(gB_ID$0{s~jH zWf|6@kwaP`<-1wbh3R12+|tOs%`DoAciv(UX~Qu(-Y9=j0irhkI^P|C9)WEc-GYtW z%+UxLZ3d=q3UAV+*Q9@rIemEE5%Qm(gnOsr=;jk~wBY^IuqDw8o<}^Tqkd$d+t^y# zNBx+-D4$Cr?$%+?(>KuSc^6dWr=rP^JBgQCju91Kpkzx4+#5EFx=zd#Iz=48z0*2j z4RkD7=g$eAylEOwFP3rae$?Qwww7pI*vFWuX7sAlT})4X50iOy*r~Kt0|b ze4R6%IGPDz;QS$kclD8Gj9?4>P^7^lKBKX^`XQ|h$N=HikYfAgKdR(Jy@m z;kO%l!kJ&vV1`@)FE59p^`u-nZR>6ltVx34&=)kyj?l<$D{-RTW)u{(;MboF5l-$N z)DHyj%|*~Dd7Y+Ql>*`YkCQb{j#$kmy)kn6}oL3rKC3n|H7(>E7WIRG3^xThu!)+(i7jG zCz*<8kWoII6#P4l^q9LzGjYoc+&SWEY2D#O;mHC!k#~a^oh{u$9uMlQAschhXxv&T zUc_NUsxRc8lVka+!*D(BFnYOLL+ns9(dl&yvF4Qp8d{$fR@UfY-vt$jv$KS2>ND_P z${&_1 zoeYK=qsi`a^Z$b$sTovcGgwqQus0Mb!cpI9Iqo$6g(#||R8#2;BPTv5!ToYHNncIH z^t&f;+x8!5oqIS)-NH$PdIwGk9Y%u2og(wSBGGQP9r21$2(L|Ci}!~dCF`=gK+LX8 z`0dUzm^bAB`gK!c*i{2a&M*MMH3R6kE>U#bk%duVYar1@4eL)h!K>N3=rS%Hlmi+&=`8V$Y*?nV<>7biD4k`s}>x87+-m0i+;JnECT^ zbhICZ_6>s|pxzaO3a69I-LArM-#DSf*;JH1&H2NgpBxcn9 zLLCKpP&ME~^mDHvbNszw(u7#NaY3c2RR)wf^h(3+fvI%b(H(gG`33mC^HEImyG?K0 zG9m@@kHO^AM{(ty^|1Eye(wFe=HsPJkBI-}3FJayca+tw5?1+~qLuUS3uUPTU^DLw zXq35_o5mLz+^olyvG36rTdiT>&1~Wx-3#^XC)3>7ztL^CtZ~`mkD&DQ5wa|D7!H~^ z6Q+l5CnYw45I&?27`lX@vAhfB`2)Cm^n{BpdrPN`e=1zYEq%CrY=i4Mo599T0rcX% z2zo3(3l{hX!$7XKA9pVVk;OjudFuzm`0KgUp}v5)cp1@)1rW*WdtUu(Vui<{Vr>G@bEcvnlwK$g6a$aCkrGEU(Wj66 z(co1mp8dEx-mf(&yR*?41p+>g!!aH5%h#d+=vaWkG4$(qDLdy?XiOo!~zLW*$*Na4hNv?5v^&8^q!7Wa`}Z?M9Z8OevB@W``?0JWB{XE>i)L@ zwUlJVn91U}9g{X}WXbKoaPl!mtWHwOnsAA;Q%a<7?X;`xuI5p3v4dlywTe4#rkXWs zbP{y6F!pS Date: Sat, 27 Jun 2026 00:25:46 +0200 Subject: [PATCH 14/19] [CPU] Enable pre-packed weights sharing for MatMulNBits (#29163) ### Description Enable pre-packed weights sharing for `MatMulNBits` operator on CPU. When performing DQ + MatMul -> MatMulNBits fusion, the original weight names are lost, so the standard `AddInitializer` approach does not work. To overcome this, introduced the option for graph optimization pass to tag weights which are sharable across sessions (hashing the content and matching it across the sessions). ### Motivation and Context For executing ASG SLMs on CPU - there are two sessions, one for prefill stage and for decode stage (due to different shapes and session options). With this change, storing the weights in memory twice is avoided. The first sessions pre-packs the weights which the second session can reuse. Confirmed memory reduction through the WPA memory traces. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- include/onnxruntime/core/graph/graph.h | 32 +++ .../cpu/quantization/matmul_nbits.cc | 201 +++++++++++---- onnxruntime/core/framework/session_state.cc | 31 ++- .../core/optimizer/dq_matmulnbits_fusion.cc | 20 +- .../optimizer/matmul_nbits_sharing_identity.h | 55 ++++ .../selectors_actions/qdq_actions.cc | 18 +- .../test/contrib_ops/matmul_2bits_test.cc | 115 +++++++++ .../test/contrib_ops/matmul_4bits_test.cc | 64 +++++ .../test/contrib_ops/matmul_8bits_test.cc | 63 +++++ .../matmul_nbits_prepack_sharing_test_util.cc | 108 ++++++++ .../matmul_nbits_prepack_sharing_test_util.h | 33 +++ .../optimizer/dq_matmulnbits_fusion_test.cc | 167 +++++++++++++ .../qdq_matmulnbits_transformer_test.cc | 235 ++++++++++++++++++ 13 files changed, 1082 insertions(+), 60 deletions(-) create mode 100644 onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h create mode 100644 onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc create mode 100644 onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 815fc6aa69a60..a6d8eaecad0c0 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1608,6 +1608,34 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return *prepacked_weights_for_graph_; } + // Tags a fusion-generated initializer (whose name is not stable across sessions) with a stable, + // content-derived identity that SessionState uses to key cross-session pre-pack sharing. + // + // Single-consumer invariant: a MatMulNBits packed buffer folds in the *consuming* node's + // scales/zero_points/attributes, not B alone, so this id is meaningful only for a B initializer that + // has exactly one consumer. The DQ->MatMulNBits producers guarantee that -- each generated B has a + // unique name with a single consumer, and the fusion bails when the source weight/scale is shared (the + // DQMatMulNotConvertedToMatMulNBits_SharedWeight case). If a future change ever tags a multi-consumer + // initializer whose consumers differ in scales/zp/attrs, they would compute different ids for the same + // name and the last writer would silently mis-share. Enforce that a name is never re-tagged with a + // conflicting id so the invariant survives later refactors. + void SetSharedPrepackInitializerId(const std::string& initializer_name, std::string share_id) { + auto it = generated_shared_prepack_ids_.find(initializer_name); + if (it != generated_shared_prepack_ids_.end()) { + ORT_ENFORCE(it->second == share_id, "MatMulNBits pre-pack sharing id for initializer '", + initializer_name, "' was re-tagged with a different id; the single-consumer invariant ", + "is violated (a multi-consumer weight whose consumers differ in scales/zp/attrs)."); + return; + } + generated_shared_prepack_ids_.emplace(initializer_name, std::move(share_id)); + } + + // Returns the sharing identity for a generated initializer, or nullptr if it was not tagged. + const std::string* GetSharedPrepackInitializerId(const std::string& initializer_name) const { + auto it = generated_shared_prepack_ids_.find(initializer_name); + return it == generated_shared_prepack_ids_.end() ? nullptr : &it->second; + } + /** Returns the Node containing the GraphProto for this Graph instance if IsSubgraph is true */ const Node* ParentNode() const { return parent_node_; } @@ -2011,6 +2039,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // This is optional due to delayed construction. std::optional prepacked_weights_for_graph_; + // Maps a fusion-generated initializer name to its cross-session sharing identity. + // See SetSharedPrepackInitializerId. + InlinedHashMap generated_shared_prepack_ids_; + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // Runtime optimization storage. // Note: runtime_optimizations_ == *runtime_optimizations_ptr_ and must be initialized diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 6bd1690fca815..162d7257d0a4c 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" #include +#include #include #include @@ -162,6 +163,13 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_{}; size_t packed_b_size_{0}; + // True once PrePack(InputIndex::B) has folded the scales and (constant) zero points into packed_b_, + // leaving the CompInt8 buffer fully packed and compute-ready. Pre-packed weight sharing + // content-hashes the buffer right after the B PrePack returns, so everything that affects the + // packed bytes (in particular the block sum / BZpCorr, which depend on the zero points) must be + // folded in by then. Once set, the later scales/zero_point PrePack calls must not pack again: the + // CompInt8 packing is single-shot, and the buffer may by then be one shared from another session. + bool packed_b_finalized_{false}; IAllocatorUniquePtr scales_fp32_{}; IAllocatorUniquePtr bias_fp32_{}; @@ -227,7 +235,6 @@ template Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { - ORT_UNUSED_PARAMETER(prepacked_weights); is_packed = false; if (has_g_idx_) { return Status::OK(); @@ -308,10 +315,12 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All static_cast(packed_b_.get()), threadpool_ptr); - if (prepacked_weights != nullptr) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } + // Do not append packed_b_ here. Both the LUT and non-LUT branches share the single append + // after this if/else, so each records exactly one buffer. Appending here as well would move + // packed_b_ out now and then have the shared append record a second, moved-from/null buffer + // with a non-zero packed_b_size_. PrePackedWeights::GetHash() skips null buffers so sharing + // appears to work, but the prepacked-blob save path writes buffer_sizes_[i] bytes from + // buffers_[i].get() and would dereference that null pointer. } else { // For HQNBIT_CompInt8, route through SQNBIT_CompInt8 for sizing and packing. // This gets KleidiAI-sized buffer when available for 4-bit and packs B+scales correctly. @@ -341,24 +350,64 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + // The framework content-hashes this packed buffer to deduplicate pre-packed weights, both + // within a session and across sessions (the shared container). The session-state prepack pass + // (SessionState::PrepackConstantInitializedTensors) passes a non-null prepacked_weights on both + // the container and the default single-session paths, so this zero-fill runs on essentially + // every prepack at load, not only when a sharing container is configured -- the guard below + // only skips a caller that asks for no cacheable buffer. The pack routines need not write every + // byte (alignment padding between the CompInt8 sub-regions; any layout could gain padding) and + // the reserve allocation is not zero-filled, so the hash would otherwise depend on uninitialized + // bytes. Zeroing the whole buffer is a one-time O(packed_b_size_) load cost (the pack overwrites + // the data regions, leaving only padding zeroed); inference is unaffected. + if (prepacked_weights != nullptr) { + std::memset(packed_b_.get(), 0, packed_b_size_); + } MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, effective_compute_type, qptr, packed_b_.get(), scale_ptr, has_zp_input_, nullptr, threadpool_ptr, &mlas_backend_kernel_selector_config_); -#if defined(MLAS_TARGET_ARM64) - // For KleidiAI asymmetric 4-bit path: compute BZpCorr now while scales and zero_points are accessible. - if (compute_type_ == HQNBIT_CompInt8 && nbits_ == 4 && has_zp_input_ && scales_fp32_ && - MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, SQNBIT_CompInt8, has_zp_input_, &mlas_backend_kernel_selector_config_)) { - const Tensor* zp_tensor = nullptr; - OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); - if (zp_tensor != nullptr) { - auto zptr = zp_tensor->Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, SQNBIT_CompInt8, nullptr, packed_b_.get(), - scales_fp32_.get(), has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + // Fold the scales and (constant) zero points into packed_b_ now, during the B PrePack, instead + // of deferring them to the later scales/zero_points PrePack calls. Pre-packed weight sharing + // content-hashes this buffer immediately after the B PrePack returns; the CompInt8 block sum + // (and the KleidiAI BZpCorr) is a function of the zero points, so they must already be folded + // in for the hash to reflect them. Otherwise two initializers with identical B and scales but + // different zero points would hash equal and the second would wrongly adopt the first's buffer + // and silently compute wrong results. scales and zero_points are constant initializers, so they + // are available here. The B pack above only partially populates the buffer (on x64 the block sum + // is deferred; on ARM64 8-bit the scales are ignored during B packing), so issue one more pack + // call with QuantBData == nullptr to finalize it. This is byte-identical to the staged + // scales + zero_points packing it replaces. + bool finalize_scale_zp_into_packed_b = effective_compute_type == SQNBIT_CompInt8 && scale_ptr != nullptr; +#if !defined(MLAS_TARGET_AMD64_IX86) + // On ARM64 the scales/zero points are folded into B only for 8-bit, or for 4-bit when MLAS bakes + // them in (KleidiAI). For 4-bit non-KleidiAI they are applied at compute time and must not be + // passed to the packing routine, which would dereference the null QuantBData buffer. + finalize_scale_zp_into_packed_b = + finalize_scale_zp_into_packed_b && + (nbits_ == 8 || MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, effective_compute_type, + has_zp_input_, &mlas_backend_kernel_selector_config_)); +#endif + if (finalize_scale_zp_into_packed_b) { + const uint8_t* zp_ptr = nullptr; + if (has_zp_input_) { + const Tensor* zp_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); + if (zp_tensor != nullptr) { + zp_ptr = zp_tensor->Data(); + } } + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, effective_compute_type, nullptr /*QuantBData*/, + packed_b_.get(), scale_ptr, has_zp_input_, zp_ptr, nullptr, + &mlas_backend_kernel_selector_config_); + packed_b_finalized_ = true; } -#endif // MLAS_TARGET_ARM64 } is_packed = true; + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } } else if (compute_type_ == SQNBIT_CompInt8 && !prefer_lut_gemm_) { // Packing scales and zero points // Guard: for LUT-eligible nodes, scales/ZP are already packed inside @@ -376,7 +425,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All }(); if (should_pack_scale_and_zp_inputs) { - if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + // packed_b_ is already finalized during the B PrePack (scales and zero points folded in there so + // the sharing content hash captures them), so skip packing here. The CompInt8 packing is + // single-shot and packed_b_ may now be a buffer shared from another session. + if (input_idx == InputIndex::scales && packed_b_ != nullptr && !packed_b_finalized_) { auto sptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -384,7 +436,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } // Packing zero_point - if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && !packed_b_finalized_) { auto zptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -410,13 +462,21 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } scales_are_packed_ = true; - is_packed = true; - // For KleidiAI asymmetric 4-bit path: compute BZpCorr now while scales are still accessible. - // After this PrePack returns is_packed=true, ORT may erase scales from the constant - // input table (use count drops to 0), making them unavailable in later PrePack calls. - // Zero points haven't been PrePacked yet so they are still accessible. - if (has_zp_input_ && nbits_ == 4) { + // The scales were folded into packed_b_ during the B PrePack, so there is no separate packed + // scales buffer to cache or share. Report is_packed = false (as the x64 path already does for + // the scales input) so the framework does not engage pre-packed weight sharing for scales. + // Engaging it would require pushing a placeholder buffer, but the real scales live inside + // packed_b_ so the placeholder would be null - and PrePackedWeights::GetHash() skips null + // buffers, making the scales container key identical for every MatMulNBits node. That would + // falsely increment the shared-weights counter for unrelated nodes without sharing any real + // buffer. The quantized weight B (which carries the folded-in scales) is shared on its own. + is_packed = false; + + // BZpCorr was already folded into packed_b_ during the B PrePack (so the sharing content hash + // captures the zero points), so re-folding it here must be skipped: the packing is single-shot + // and packed_b_ may now be a buffer shared from another session. + if (has_zp_input_ && nbits_ == 4 && !packed_b_finalized_) { const Tensor* zp_tensor = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); if (zp_tensor != nullptr) { @@ -457,7 +517,14 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // BZpCorr was already computed during B packing in Step 1 (if applicable). scales_are_packed_ = true; - is_packed = true; + + // The scales were folded into the packed B buffer during the B PrePack, so there is no + // separate packed scales buffer to cache or share. Report is_packed = false (mirroring the + // x64 path and the SQNBIT_CompInt8 path above) so the framework does not engage sharing for + // the scales input; engaging it would push a null placeholder whose content hash is identical + // for every node, falsely incrementing the shared-weights counter without sharing any real + // buffer. + is_packed = false; } else #endif // MLAS_TARGET_ARM64 { @@ -471,7 +538,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // Pack scales separately only for 8-bit. For 4-bit on ARM64, scales are already packed // during B packing or used as a raw pointer at compute time (matching standard // SQNBIT_CompInt8 behavior where should_pack_scale_and_zp_inputs = (nbits_ == 8) on ARM64). - if (nbits_ == 8) { + // Skip when packed_b_ was already finalized during the B PrePack (scales/zero points folded + // in there for the sharing content hash); it may now be a buffer shared from another session. + if (nbits_ == 8 && !packed_b_finalized_) { MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, SQNBIT_CompInt8, nullptr, packed_b_.get(), scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -482,7 +551,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // Pack zero_points separately only for 8-bit (matching standard SQNBIT_CompInt8 behavior). // For 4-bit, zero_points are passed directly in data params or handled via KleidiAI BZpCorr. - if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && nbits_ == 8) { + if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && nbits_ == 8 && !packed_b_finalized_) { auto zptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, SQNBIT_CompInt8, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -540,8 +609,6 @@ template <> Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { - ORT_UNUSED_PARAMETER(prepacked_weights); - if (input_idx == InputIndex::scales || input_idx == InputIndex::bias) { auto sptr = tensor.Data(); auto tensor_size = static_cast(tensor.Shape().Size()); @@ -565,8 +632,12 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou if (input_idx == InputIndex::B) { const Tensor* scales = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); - if (scales && MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, - has_zp_input_, &mlas_backend_kernel_selector_config_)) { + // Convert the constant fp16 scales to fp32 up front so they (and the zero points) can be folded + // into packed_b_ during this B PrePack, mirroring the primary float PrePack above. Pre-packed + // weight sharing content-hashes the buffer right after this B PrePack returns, so for CompInt8 + // everything that affects the packed bytes (the scales, and the block sum / KleidiAI BZpCorr that + // depend on the zero points) must be folded in by now. + if (scales && compute_type_ == SQNBIT_CompInt8) { auto sptr = scales->Data(); auto scales_size = static_cast(scales->Shape().Size()); auto ptr = IAllocator::MakeUniquePtr(alloc, scales_size, true); @@ -581,25 +652,55 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + // See the primary PrePack() above: SessionState::PrepackConstantInitializedTensors passes a + // non-null prepacked_weights on both the container and the default single-session paths, so this + // zero-fill runs on essentially every prepack at load (the guard only skips a caller that asks for + // no cacheable buffer). It keeps the dedup content hash reproducible regardless of bytes the pack + // leaves uninitialized (alignment padding), for any compute type. One-time O(packed_b_size_) load + // cost; inference is unaffected. + if (prepacked_weights != nullptr) { + std::memset(packed_b_.get(), 0, packed_b_size_); + } MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); -#if defined(MLAS_TARGET_ARM64) - // For KleidiAI asymmetric 4-bit path: compute BZpCorr during B packing. - // The fp16 specialization packs B here (with scales already converted to fp32), - // so we also compute BZpCorr now while both scales and zero_points are accessible. - if (has_zp_input_ && nbits_ == 4 && scales_fp32_ != nullptr) { - const Tensor* zp_tensor = nullptr; - OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); - if (zp_tensor != nullptr) { - auto zptr = zp_tensor->Data(); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), - scales_fp32_.get(), has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); + // Fold the scales and (constant) zero points into packed_b_ now (see the primary PrePack above): + // the CompInt8 block sum and the KleidiAI BZpCorr depend on the zero points, so they must be + // folded in before the sharing content hash is taken. Otherwise two initializers with identical B + // and scales but different zero points would hash equal and the second would wrongly adopt the + // first's buffer. The B pack above only partially populates the buffer, so issue one more pack + // call with QuantBData == nullptr to finalize it. This is byte-identical to the staged + // scales + zero_points packing it replaces. + bool finalize_scale_zp_into_packed_b = compute_type_ == SQNBIT_CompInt8 && scales_fp32_ != nullptr; +#if !defined(MLAS_TARGET_AMD64_IX86) + // On ARM64 the scales/zero points are folded into B only for 8-bit, or for 4-bit when MLAS bakes + // them in (KleidiAI). For 4-bit non-KleidiAI they are applied at compute time and must not be + // passed to the packing routine, which would dereference the null QuantBData buffer. + finalize_scale_zp_into_packed_b = + finalize_scale_zp_into_packed_b && + (nbits_ == 8 || MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, + has_zp_input_, &mlas_backend_kernel_selector_config_)); +#endif + if (finalize_scale_zp_into_packed_b) { + const uint8_t* zp_ptr = nullptr; + if (has_zp_input_) { + const Tensor* zp_tensor = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor); + if (zp_tensor != nullptr) { + zp_ptr = zp_tensor->Data(); + } } + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr /*QuantBData*/, + packed_b_.get(), scales_fp32_.get(), has_zp_input_, zp_ptr, nullptr, + &mlas_backend_kernel_selector_config_); + packed_b_finalized_ = true; } -#endif // MLAS_TARGET_ARM64 - is_packed = true; + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } } else if (compute_type_ == SQNBIT_CompInt8) { bool should_pack_scale_and_zp = [&]() { #if defined(MLAS_TARGET_AMD64_IX86) @@ -610,11 +711,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou }(); if (should_pack_scale_and_zp) { - if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + if (input_idx == InputIndex::scales && packed_b_ != nullptr && !packed_b_finalized_) { MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), scales_fp32_.get(), has_zp_input_, nullptr, nullptr, &mlas_backend_kernel_selector_config_); is_packed = false; - } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr && !packed_b_finalized_) { auto zptr = tensor.Data(); MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr, &mlas_backend_kernel_selector_config_); @@ -635,6 +736,11 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& used_shared_buffers = false; if (input_idx == InputIndex::B && !prepacked_buffers.empty()) { + // The buffer handed back is fully finalized: the producing session folded the scales and zero + // points (block sums / KleidiAI BZpCorr) into it during its PrePack(B), which is also when this + // kernel set packed_b_finalized_ on its own (identical) B PrePack. The later scale/zero-point + // PrePack calls already skip the staged packing whenever packed_b_finalized_ is set, so simply + // adopt the shared buffer here - no extra bookkeeping is needed to avoid re-folding into it. packed_b_ = std::move(prepacked_buffers[0]); used_shared_buffers = true; @@ -643,6 +749,9 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_); } } + // Only the quantized weight B yields a separately cached pre-packed buffer. The scales (and zero + // points) are folded into packed_b_ during the B PrePack and reported with is_packed = false, so + // the framework never asks this kernel to adopt a shared buffer for them. return Status::OK(); } diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index ad92ddd797d3a..241eb8362ddfa 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -498,8 +498,15 @@ Status SessionState::PrepackConstantInitializedTensors( auto iter = initializers_to_share_map.find(input_name); bool is_shared_initializer = (iter != initializers_to_share_map.end()); - // Caching pre-packed weights is limited to shared initializers associated with the CPU EP for now - if (is_shared_initializer && should_cache_prepacked_weights_for_shared_initializers && + // CPU EP only. An initializer joins the shared pre-packed container either when it was + // registered via OrtApi::AddInitializer (is_shared_initializer) or when a graph transformer + // tagged this synthesized initializer with a sharing identity. Only the tag's *presence* + // matters here: it is the enrollment signal. The container key below is the packed-bytes + // hash, never the tag value (see the rationale at the key computation). + const bool enroll_tagged_initializer = + (st->graph_.GetSharedPrepackInitializerId(input_name) != nullptr); + if ((is_shared_initializer || enroll_tagged_initializer) && + should_cache_prepacked_weights_for_shared_initializers && node.GetExecutionProviderType() == kCpuExecutionProvider) { // caching of pre-packed weights' turned ON @@ -530,12 +537,18 @@ Status SessionState::PrepackConstantInitializedTensors( // TODO: Check if some version of the ONNX IR allows op_type to be empty ORT_ENFORCE(!op_type.empty(), "The op type of a node cannot be empty"); - // The key for the pre-packed weights container lookup is the op_type + hash of the prepacked-weight - // that we just got by invoking PrePack() on this kernel. - + // Key by the packed-bytes hash (op_type + a hash of the packed buffer), exactly as the + // AddInitializer path does, so only byte-identical packed buffers are ever shared. The + // tag is solely the enrollment signal that opted this fusion-generated initializer into + // the container; it must NOT be used as the key, because it is derived from the + // *unpacked* initializer content and so cannot distinguish packings that differ by node + // options/attributes that change the packed layout (e.g. mlas.use_lut_gemm or a CPU + // backend-selector difference). Two sessions that share a container but differ in such an + // option compute the same tag yet produce different packed bytes; keying by the packed + // bytes gives them distinct keys and prevents reusing an incompatible buffer + // (wrong results/crash). const std::string prepacked_weights_container_key = - GenerateKeyForPrepackedWeightsMap(op_type, - weights_to_be_filled_in); + GenerateKeyForPrepackedWeightsMap(op_type, weights_to_be_filled_in); bool container_contains_packed_weight = prepacked_weights_container_->HasWeight( prepacked_weights_container_key); @@ -624,11 +637,9 @@ Status SessionState::PrepackConstantInitializedTensors( is_packed, &weights_to_be_filled_in)); - // Some kernels (matmul_nbits and non-CPU related kernels) do not share their pre-packed results + // Some kernels (non-CPU related kernels) do not share their pre-packed results // even though they set is_packed = true so we leave it up to them. // We can change their behavior if we wish do so in a separate PR - // XXX: Interestingly enough, matmul_nbits does accept shared pre-packs, but does not - // produce them. if (is_packed && !weights_to_be_filled_in.buffers_.empty()) { const auto& op_type = node.OpType(); const std::string prepacked_weights_container_key = GenerateKeyForPrepackedWeightsMap( diff --git a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc index f3956d5e9e0f3..07fccef64fee1 100644 --- a/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc +++ b/onnxruntime/core/optimizer/dq_matmulnbits_fusion.cc @@ -12,6 +12,7 @@ #include "core/graph/graph_utils.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/matmul_nbits_sharing_identity.h" #include "core/optimizer/utils.h" #include @@ -447,7 +448,6 @@ std::vector CollectDirectDQMatches( return direct_matches; } -// --------------------------------------------------------------------------- // Pattern 1 rewriting: DQ+Reshape+Transpose+[Cast]+MatMul/Gemm -> MatMulNBits // --------------------------------------------------------------------------- @@ -569,6 +569,10 @@ void ApplyReshapeTransposeFusions( zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } + // Cross-session sharing identity for the generated weight group; computed before the tensors move. + const std::string share_id = + ComputeMatMulNBitsSharingId(weight_dst, scale_dst, zp_dst, N, K, block_size, /*bits*/ 4, accuracy_level); + NodeAttributes mnb_attrs; utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); @@ -578,7 +582,10 @@ void ApplyReshapeTransposeFusions( std::vector mnb_inputs; mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); - mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + NodeArg& b_weight_arg = graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst)); + // Tag the generated B weight for cross-session pre-pack sharing. + graph.SetSharedPrepackInitializerId(b_weight_arg.Name(), share_id); + mnb_inputs.push_back(&b_weight_arg); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); @@ -749,6 +756,10 @@ void ApplyDirectDQFusions( zp_mnb_tp.emplace(utils::TensorToTensorProto(*zp_dst, zp_dst_name, true)); } + // Cross-session sharing identity for the generated weight group; computed before the tensors move. + const std::string share_id = + ComputeMatMulNBitsSharingId(weight_dst, scale_dst, zp_dst, N, K, block_size, /*bits*/ 4, accuracy_level); + NodeAttributes mnb_attrs; utils::SetNodeAttribute(utils::MakeAttribute("K", K), mnb_attrs); utils::SetNodeAttribute(utils::MakeAttribute("N", N), mnb_attrs); @@ -758,7 +769,10 @@ void ApplyDirectDQFusions( std::vector mnb_inputs; mnb_inputs.push_back(const_cast(mm_node->InputDefs()[0])); - mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst))); + NodeArg& b_weight_arg = graph_utils::AddInitializerWithOrtValue(graph, weight_mnb_tp, std::move(weight_dst)); + // Tag the generated B weight for cross-session pre-pack sharing. + graph.SetSharedPrepackInitializerId(b_weight_arg.Name(), share_id); + mnb_inputs.push_back(&b_weight_arg); mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, scale_mnb_tp, std::move(scale_dst))); if (zp_mnb_tp) { mnb_inputs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, zp_mnb_tp.value(), std::move(*zp_dst))); diff --git a/onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h b/onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h new file mode 100644 index 0000000000000..829a78d3ebcf1 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_sharing_identity.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/framework/murmurhash3.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { + +// Stable, content-derived identity for a fusion-generated MatMulNBits weight group, used to share its +// pre-packed buffer across sessions. The id is identical for the same model in any session and differs +// whenever a semantic input differs. accuracy_level is hashed so buffers packed for different compute +// types never collide. Pass zero_point only when it is an actual kernel input. +inline std::string ComputeMatMulNBitsSharingId(const Tensor& weight, const Tensor& scale, + const std::optional& zero_point, + int64_t N, int64_t K, int64_t block_size, + int64_t bits, int64_t accuracy_level) { + // MurmurHash3 fmix64 finalizer: a bijection that avalanches a 64-bit value so each input bit affects + // every output bit. + auto fmix64 = [](uint64_t x) { + x ^= x >> 33; + x *= 0xff51afd7ed558ccdULL; + x ^= x >> 33; + x *= 0xc4ceb9fe1a85ec53ULL; + x ^= x >> 33; + return x; + }; + // Fold each segment's full 128-bit hash into the 64-bit accumulator and carry the whole accumulator + // forward, not just a 32-bit seed. Every bit of weight/scale/zero_point/params therefore reaches the + // id, so collision resistance tracks the 64-bit id width instead of the ~2^32 a chain forwarding only + // hash[0] would give. A collision would let one weight group adopt another's already-packed buffer and + // silently compute a wrong result, so the wider margin is worth the few extra mixing ops. + uint64_t acc = 0; + auto mix = [&acc, &fmix64](const void* data, size_t len) { + uint32_t h[4]; + MurmurHash3::x86_128(data, len, static_cast(acc), h); + acc = fmix64(acc ^ ((static_cast(h[1]) << 32) | h[0])); + acc = fmix64(acc ^ ((static_cast(h[3]) << 32) | h[2])); + }; + mix(weight.DataRaw(), weight.SizeInBytes()); + mix(scale.DataRaw(), scale.SizeInBytes()); + if (zero_point) { + mix(zero_point->DataRaw(), zero_point->SizeInBytes()); + } + const int64_t params[] = {N, K, block_size, bits, accuracy_level}; + mix(params, sizeof(params)); + return "MatMulNBits.DQ:" + std::to_string(acc); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index b9d7e898157bd..6bd5e157d8b65 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -7,6 +7,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/initializer.h" +#include "core/optimizer/matmul_nbits_sharing_identity.h" #include "core/graph/node_attr_utils.h" #include "core/graph/graph_utils.h" #include "core/framework/tensorprotoutils.h" @@ -646,8 +647,23 @@ Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits( graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, effective_bs, transposed)); + // Cross-session sharing identity for the generated B weight; computed before it is moved. + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* weight_shape = weight_arg->Shape(); + ORT_RETURN_IF_NOT(weight_shape != nullptr && weight_shape->dim_size() >= 2, + "Weight shape unavailable for DQ node ", dq_node->Name()); + const int64_t bits = DQWeightBits(weight_arg->TypeAsProto()->tensor_type().elem_type()); + const std::string share_id = ComputeMatMulNBitsSharingId( + transposed.weight, transposed.scale, transposed.zero_point, + weight_shape->dim(1).dim_value(), weight_shape->dim(0).dim_value(), + effective_bs, bits, accuracy_level_); + auto& input_defs = replacement_node.MutableInputDefs(); - input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight))); + NodeArg& b_weight_arg = + graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight)); + // Tag the generated B weight for cross-session pre-pack sharing. + graph.SetSharedPrepackInitializerId(b_weight_arg.Name(), share_id); + input_defs.push_back(&b_weight_arg); replacement_node.MutableInputArgsCount().push_back(1); input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.scale_proto, std::move(transposed.scale))); diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 9deb064a90853..8e133caa15d55 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -3,6 +3,7 @@ #ifndef ORT_MINIMAL_BUILD +#include #include #include "gtest/gtest.h" @@ -26,6 +27,9 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" +#include "core/graph/model.h" +#include "test/util/include/inference_session_wrapper.h" +#include "test/util/include/test/test_environment.h" #include "core/providers/webgpu/webgpu_provider_options.h" #ifdef USE_WEBGPU #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" @@ -461,6 +465,117 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256_Bias) { TestMatMul2BitsLutGemm(32, 256, 256, 32, /*has_zero_point=*/true, /*has_bias=*/true); } +// Regression test for the LUT GEMM pre-pack + prepacked-save path. A 2-bit MatMulNBits node pre-packed +// via the LUT path must record its packed B buffer exactly once. A prior bug appended packed_b_ twice +// on the LUT path (inside the LUT branch and again in the shared append at the end of the B block), so +// the second entry was a moved-from/null buffer paired with a non-zero packed_b_size_. The pre-packed +// content hash skips null buffers, so cross-session sharing appeared to work, but saving pre-packed +// initializers iterates every recorded buffer and writes buffer_sizes_[i] bytes from buffers_[i].get(), +// dereferencing the null pointer when mlas.use_lut_gemm=1. This drives mlas.use_lut_gemm=1 together with +// session.save_external_prepacked_constant_initializers=1 and a non-empty optimized_model_filepath, and +// asserts that initialization (which performs the save) and a subsequent run both succeed. +TEST(MatMulNBitsLutGemm, Float32_2Bits_PrepackSaveDoesNotCrash) { + constexpr int64_t M = 1, N = 128, K = 128, block_size = 32; + if (!MlasIsLutGemmAvailable(static_cast(N), static_cast(K), 2, static_cast(block_size))) { + GTEST_SKIP() << "LUT GEMM not available on this platform"; + } + + // Quantize random weights into valid 2-bit MatMulNBits B/scales/zero_points initializers. + RandomValueGenerator random{1234}; + std::vector b_fp32(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); + + int q_rows = 0, q_cols = 0; + MlasBlockwiseQuantizedShape(static_cast(block_size), /*columnwise*/ true, + static_cast(K), static_cast(N), q_rows, q_cols); + size_t q_data_size_in_bytes = 0, q_scale_size = 0, q_zp_size_in_bytes = 0; + MlasBlockwiseQuantizedBufferSizes(static_cast(block_size), /*columnwise*/ true, + static_cast(K), static_cast(N), + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + std::vector b_data(q_data_size_in_bytes); + std::vector scales(q_scale_size); + std::vector zp(q_zp_size_in_bytes); + + auto& ortenv = **ort_env.get(); + onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); + MlasQuantizeBlockwise(b_data.data(), scales.data(), zp.data(), b_fp32.data(), + static_cast(block_size), /*columnwise*/ true, + static_cast(K), static_cast(N), + static_cast(N), tp); + + // Single-node MatMulNBits model: A is a runtime input; B/scales/zero_points are constant initializers + // (so they are pre-packed at session initialization). + const int64_t k_blocks = (K + block_size - 1) / block_size; + const std::unordered_map domain_to_version{{"", 21}, {kMSDomain, 1}}; + Model model("matmul_2bits_lut_prepack_save", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + + ONNX_NAMESPACE::TypeProto float_2d; + float_2d.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); + float_2d.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(M); + float_2d.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(K); + NodeArg* A = &graph.GetOrCreateNodeArg("A", &float_2d); + NodeArg* Y = &graph.GetOrCreateNodeArg("Y", nullptr); + + NodeArg* B = builder.MakeInitializer( + {static_cast(q_cols), k_blocks, static_cast(q_rows) / k_blocks}, b_data); + NodeArg* scales_arg = builder.MakeInitializer({N, static_cast(q_scale_size) / N}, scales); + NodeArg* zero_points = + builder.MakeInitializer({N, static_cast(q_zp_size_in_bytes) / N}, zp); + + Node& node = builder.AddNode("MatMulNBits", {A, B, scales_arg, zero_points}, {Y}, kMSDomain); + node.AddAttribute("K", K); + node.AddAttribute("N", N); + node.AddAttribute("block_size", block_size); + node.AddAttribute("bits", static_cast(QBits)); + node.AddAttribute("accuracy_level", static_cast(0)); + + graph.SetOutputs(std::vector{Y}); + ASSERT_STATUS_OK(graph.Resolve()); + + std::string model_bytes; + ASSERT_TRUE(model.ToProto().SerializeToString(&model_bytes)); + + // Save the optimized model + pre-packed initializers into a unique temp dir. Writing the prepacked + // initializers is the path that dereferenced the duplicate null buffer before the fix. + namespace fs = std::filesystem; + const fs::path tmp_dir = fs::temp_directory_path() / "ort_matmul2bits_lut_prepack_save_test"; + std::error_code ec; + fs::remove_all(tmp_dir, ec); + ASSERT_TRUE(fs::create_directories(tmp_dir, ec)) << ec.message(); + const fs::path optimized_model_path = tmp_dir / "optimized.onnx"; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsSavePrePackedConstantInitializers, "1")); + so.optimized_model_filepath = optimized_model_path.native(); + + std::vector fetches; + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + // Initialization performs the LUT pre-pack and writes the optimized model with external + // pre-packed initializers. Before the fix this dereferenced the duplicate null packed buffer. + ASSERT_STATUS_OK(session.Initialize()); + + auto cpu_allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + std::vector a_data = random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f); + OrtValue a_value; + CreateMLValue(cpu_allocator, AsSpan({M, K}), a_data, &a_value); + NameMLValMap feeds{{"A", a_value}}; + + ASSERT_STATUS_OK(session.Run(RunOptions{}, feeds, std::vector{"Y"}, &fetches)); + } + + ASSERT_EQ(fetches.size(), static_cast(1)); + EXPECT_TRUE(fs::exists(optimized_model_path)); + + fs::remove_all(tmp_dir, ec); +} + // Float zero point tests — directed QAD scenario (zp=1.5) void RunTest2BitsFloatZP(int64_t M, int64_t N, int64_t K, int64_t block_size, float zp_value) { RandomValueGenerator random{1234}; diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 07b275b813aa7..bedf035d320f8 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -22,10 +22,14 @@ #include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" +#include "test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" #include "core/providers/webgpu/webgpu_provider_options.h" +#include "core/framework/prepacked_weights_container.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "test/util/include/test/test_environment.h" extern std::unique_ptr ort_env; @@ -87,6 +91,10 @@ struct TestOptions { bool legacy_shape{false}; // for backward compatibility + // When set, RunTest validates cross-session sharing of the pre-packed weights instead of doing a + // single run. The model is run in two sessions that use the same pre-packed weights container. + std::optional prepack_sharing_mode{}; + std::optional output_abs_error{}; std::optional output_rel_error{}; }; @@ -269,6 +277,13 @@ void RunTest(const TestOptions& opts, test.SetOutputRelErr("Y", *opts.output_rel_error); } + if (opts.prepack_sharing_mode.has_value()) { + // Pre-packed weight sharing is a CPU-EP-only feature; the helper runs the model on the CPU EP + // in two sessions and validates the sharing counters. + CheckSharedPrepackedWeights(test, *opts.prepack_sharing_mode, {N, k_blocks, blob_size}, input1_vals); + return; + } + if (!explicit_eps.empty()) { test.ConfigEps(std::move(explicit_eps)); } @@ -597,6 +612,55 @@ TEST(MatMulNBits, Float32_4b_Accuracy4_Batch) { RunTest(opts); } +#ifndef ENABLE_TRAINING +// Pre-packing (and therefore cross-session sharing of pre-packed weights) is disabled in a full +// training build, so there is nothing to exercise there. + +namespace { +// Builds a representative MatMulNBits TestOptions for the pre-packed weight sharing tests. +TestOptions MakeSharingTestOptions(int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, + bool has_zero_point, bool has_bias, PrepackSharingMode mode) { + TestOptions opts{}; + opts.M = 8; + opts.N = N; + opts.K = K; + opts.block_size = block_size; + opts.accuracy_level = accuracy_level; + opts.has_zero_point = has_zero_point; + opts.zp_is_4bit = true; + opts.has_bias = has_bias; + opts.prepack_sharing_mode = mode; + opts.output_abs_error = 0.1f; + opts.output_rel_error = 0.02f; + return opts; +} +} // namespace + +// Legacy sharing path: the weight B is registered as a shared initializer via +// SessionOptions::AddInitializer. Covers float and float16 activations, symmetric/asymmetric, +/- bias. +TEST(MatMulNBits, SharedPrepackedWeights_AddInitializer) { + for (bool has_zero_point : {false, true}) { + for (bool has_bias : {false, true}) { + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, has_zero_point, + has_bias, PrepackSharingMode::kAddInitializer)); + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, has_zero_point, + has_bias, PrepackSharingMode::kAddInitializer)); + } + } +} + +// Negative control: with the shared container present but neither opt-in mechanism enabled, no +// pre-packed weights are shared across sessions. +TEST(MatMulNBits, SharedPrepackedWeights_NotSharedWithoutOptIn) { + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, /*has_zero_point*/ true, + /*has_bias*/ true, PrepackSharingMode::kNoSharing)); + RunTest(MakeSharingTestOptions(32, 256, /*block_size*/ 32, /*accuracy_level*/ 0, + /*has_zero_point*/ false, /*has_bias*/ false, + PrepackSharingMode::kNoSharing)); +} + +#endif // !ENABLE_TRAINING + #endif #endif diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index f99334c4f33ef..411e83536c190 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -21,6 +21,7 @@ #include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "test/util/include/scoped_env_vars.h" +#include "test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -51,6 +52,10 @@ struct TestOptions8Bits { bool has_g_idx{false}; bool has_bias{false}; + // When set, RunTest8Bits validates cross-session sharing of the pre-packed weights instead of + // doing a single run. The model is run in two CPU sessions that use the same container. + std::optional prepack_sharing_mode{}; + std::optional output_abs_error{}; std::optional output_rel_error{}; }; @@ -221,6 +226,14 @@ void RunTest8Bits(const TestOptions8Bits& opts) { test.SetOutputRelErr("Y", *opts.output_rel_error); } + if (opts.prepack_sharing_mode.has_value()) { + // Pre-packed weight sharing is a CPU-EP-only feature; the helper runs the model on the CPU EP + // in two sessions and validates the sharing counters. + CheckSharedPrepackedWeights(test, *opts.prepack_sharing_mode, + {q_cols, k_blocks, q_rows / k_blocks}, input1_vals); + return; + } + std::vector> execution_providers; #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); @@ -671,6 +684,56 @@ TEST(MatMulNBits, BFloat16_Int8_Chunked_BFloat16ZeroPoint) { } #endif +#if !defined(USE_CUDA) && !defined(USE_WEBGPU) +#ifndef ENABLE_TRAINING +// Pre-packing (and therefore cross-session sharing of pre-packed weights) is disabled in a full +// training build and is only implemented for the CPU EP, so these tests are CPU-only. + +namespace { +// Builds a representative 8-bit MatMulNBits TestOptions for the pre-packed weight sharing tests. +// accuracy_level 4 selects the int8 compute type (SQNBIT_CompInt8 / HQNBIT_CompInt8), which is the +// 8-bit path that pre-packs the quantized B weight. +TestOptions8Bits MakeSharingTestOptions8Bits(int64_t block_size, bool has_zero_point, bool has_bias, + PrepackSharingMode mode) { + TestOptions8Bits opts{}; + opts.M = 8; + opts.N = 32; + opts.K = 256; + opts.block_size = block_size; + opts.accuracy_level = 4; + opts.has_zero_point = has_zero_point; + opts.has_bias = has_bias; + opts.prepack_sharing_mode = mode; + opts.output_abs_error = 0.1f; + opts.output_rel_error = 0.02f; + return opts; +} +} // namespace + +// Legacy sharing path for 8-bit weights: B is registered as a shared initializer via +// SessionOptions::AddInitializer. +TEST(MatMulNBits, SharedPrepackedWeights_8b_AddInitializer) { + for (bool has_zero_point : {false, true}) { + for (bool has_bias : {false, true}) { + RunTest8Bits(MakeSharingTestOptions8Bits(32, has_zero_point, has_bias, + PrepackSharingMode::kAddInitializer)); + RunTest8Bits(MakeSharingTestOptions8Bits(32, has_zero_point, has_bias, + PrepackSharingMode::kAddInitializer)); + } + } +} + +// Negative control for 8-bit weights: with the shared container present but neither opt-in mechanism +// enabled, no pre-packed weights are shared across sessions. +TEST(MatMulNBits, SharedPrepackedWeights_8b_NotSharedWithoutOptIn) { + RunTest8Bits(MakeSharingTestOptions8Bits(32, /*has_zero_point*/ true, /*has_bias*/ true, + PrepackSharingMode::kNoSharing)); + RunTest8Bits(MakeSharingTestOptions8Bits(32, /*has_zero_point*/ false, /*has_bias*/ false, + PrepackSharingMode::kNoSharing)); +} +#endif // !ENABLE_TRAINING +#endif // !USE_CUDA && !USE_WEBGPU + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc new file mode 100644 index 0000000000000..97566afe02489 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +#include "core/framework/tensor.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime { +namespace test { + +void CheckSharedPrepackedWeights(OpTester& test, PrepackSharingMode mode, + const std::vector& b_dims, + std::vector& b_data) { + SessionOptions so; + OrtValue b_ortvalue; + + switch (mode) { + case PrepackSharingMode::kAddInitializer: + // Register B as an explicitly shared initializer (the pre-existing sharing mechanism). + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(b_dims), b_data.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b_ortvalue); + ASSERT_STATUS_OK(so.AddInitializer("B", &b_ortvalue)); + break; + case PrepackSharingMode::kNoSharing: + // Neither opt-in mechanism is used. + break; + } + + // Have all sessions created by this OpTester use the same pre-packed weights container. + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited to the CPU EP, so the sharing behavior is only exercised there. + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_1, + &number_of_shared_pre_packed_weights_counter); + // Nothing can be shared yet because this is the first session. + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + const auto number_of_elements_in_shared_container = test.GetNumPrePackedWeightsShared(); + + if (mode == PrepackSharingMode::kNoSharing) { + // Without opting in, pre-packed weights must not be placed in the shared container. + ASSERT_EQ(number_of_elements_in_shared_container, static_cast(0)); + } + + // On some platforms/architectures MLAS may choose not to pre-pack, in which case there is nothing + // to share and we cannot meaningfully continue. + if (number_of_pre_packed_weights_counter_session_1 == 0) { + return; + } + + if (mode != PrepackSharingMode::kNoSharing) { + // At least the quantized weight B is content-addressed into the shared container. Some + // architectures (e.g. ARM64 KleidiAI) additionally pre-pack scales, but in the AddInitializer + // mode only the explicitly-registered B participates, so the container can hold fewer elements + // than the total number of pre-packed weights. + ASSERT_GT(number_of_elements_in_shared_container, static_cast(0)); + ASSERT_LE(number_of_elements_in_shared_container, number_of_pre_packed_weights_counter_session_1); + } + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_2, + &number_of_shared_pre_packed_weights_counter); + + // The same number of weights is pre-packed in both sessions. + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Every weight stored in the shared container is served from it (i.e. shared) in the second + // session. For the no-sharing control this is zero; otherwise it matches the container size. + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, number_of_elements_in_shared_container); + + if (mode == PrepackSharingMode::kNoSharing) { + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } else { + ASSERT_GT(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h new file mode 100644 index 0000000000000..1de0bbaa4bb85 --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_nbits_prepack_sharing_test_util.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include + +namespace onnxruntime { +namespace test { + +class OpTester; + +// How two sessions are configured to share the pre-packed weights of a MatMulNBits node. +enum class PrepackSharingMode { + // Legacy path: the weight is explicitly registered as a shared initializer via + // SessionOptions::AddInitializer. + kAddInitializer, + // Negative control: the shared container exists but neither opt-in mechanism is used, so no + // cross-session sharing must happen. + kNoSharing, +}; + +// Runs the already-configured MatMulNBits OpTester in two CPU sessions that share the same +// pre-packed weights container and asserts that the pre-packed weights are shared as expected. +// This logic is independent of the weight bit width, so it is shared by the 4-bit and 8-bit tests. +// `b_dims`/`b_data` describe the quantized B initializer and are only needed for the +// PrepackSharingMode::kAddInitializer path (to register B as a shared initializer). +void CheckSharedPrepackedWeights(OpTester& test, PrepackSharingMode mode, + const std::vector& b_dims, + std::vector& b_data); + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc index 8aa4c88052742..47e08802c9e20 100644 --- a/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc +++ b/onnxruntime/test/optimizer/dq_matmulnbits_fusion_test.cc @@ -4,17 +4,24 @@ // Unit tests for the DQMatMulNBitsFusion graph transformer. // Tests Pattern 1: DQ(3D,axis=2)->Reshape->Transpose([1,0])->[Cast]->MatMul/Gemm -> MatMulNBits // Tests Pattern 2: DQ(2D,axis=0)->MatMul/Gemm -> MatMulNBits +#include #include "core/common/span_utils.h" #include "core/framework/int4.h" +#include "core/framework/prepacked_weights_container.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/dq_matmulnbits_fusion.h" +#include "core/session/inference_session.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/test_environment.h" #include "test/unittest_util/framework_test_utils.h" #include "test/unittest_util/graph_transform_test_builder.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" #include "gtest/gtest.h" @@ -354,6 +361,166 @@ TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_NoZP) { TransformerLevel::Level1, 1, pre_check, post_check)); } +// Validates the cross-session-sharing tag the fusion attaches to the generated B weight. The tag is a +// stable, content-derived enrollment identity: identical source quantization groups yield the SAME +// identity, while a semantic difference -- here, different zero points -- yields a DIFFERENT identity. +// (The tag only enrolls B into the shared container; the actual sharing is keyed by the packed-bytes +// hash, so a stable, content-distinct tag just keeps enrollment deterministic across sessions.) +TEST_F(DQMatMulNBitsFusionTest, TagsGeneratedWeightWithStableContentIdentity) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(N * num_blocks * block_size)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + std::vector scale(static_cast(N * num_blocks)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + // Non-default (non-8) zero points so the fusion keeps them (it elides uniform-8 zero points). + std::vector zp_a(static_cast(N * num_blocks), 3); + std::vector zp_b(zp_a.size(), 5); + + // Runs the fusion on a Pattern-1 model built from the given zero points and returns the sharing + // identity tagged onto the generated MatMulNBits B weight. + auto tag_for = [&](const std::vector& zp) -> std::string { + std::string captured; + auto build = [&](ModelTestBuilder& builder) { + BuildPattern1Graph(builder, M, N, K, block_size, /*with_zp*/ true, /*with_cast*/ false, + /*use_gemm*/ false, &weight, &scale, &zp); + }; + auto pre_check = [](Graph&) -> Status { return Status::OK(); }; + auto post_check = [&](Graph& graph) -> Status { + int matmulnbits = 0; + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + ++matmulnbits; + const std::string& b_name = node.InputDefs()[1]->Name(); // input 1 == quantized B + const std::string* id = graph.GetSharedPrepackInitializerId(b_name); + EXPECT_NE(id, nullptr) << "generated B weight was not tagged for cross-session sharing"; + if (id != nullptr) { + captured = *id; + } + } + } + EXPECT_EQ(matmulnbits, 1); + return Status::OK(); + }; + auto transformer = std::make_unique(4); + EXPECT_TRUE(TestGraphTransformer(build, 21, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_check, post_check) + .IsOK()); + return captured; + }; + + const std::string id_a1 = tag_for(zp_a); + const std::string id_a2 = tag_for(zp_a); + const std::string id_b = tag_for(zp_b); + + ASSERT_FALSE(id_a1.empty()); + EXPECT_EQ(id_a1, id_a2); // stable: identical source quantization group -> identical identity + EXPECT_NE(id_a1, id_b); // collision-safe: different zero points -> different identity +} + +// Builds and serializes a Pattern-1 DQ->Reshape->Transpose->MatMul model (UINT4 constant weight). When +// loaded into a session with the DQ->MatMulNBits fusion enabled, it becomes a MatMulNBits whose B is +// tagged for cross-session sharing. +static void SerializeDQMatMulModel(int64_t M, int64_t N, int64_t K, int64_t block_size, + const std::vector& weight, const std::vector& scale, + const std::vector& zp, std::string& model_bytes) { + const std::unordered_map domain_to_version{{"", 21}, {kMSDomain, 1}}; + Model model("dq_matmulnbits_share", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), DefaultLoggingManager().DefaultLogger()); + ModelTestBuilder builder(model.MainGraph()); + BuildPattern1Graph(builder, M, N, K, block_size, /*with_zp*/ true, /*with_cast*/ false, + /*use_gemm*/ false, &weight, &scale, &zp); + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + ASSERT_TRUE(model.ToProto().SerializeToString(&model_bytes)); +} + +// Loads the serialized model on the CPU EP with the DQ->MatMulNBits fusion enabled and the supplied +// shared container. Reports whether the fusion produced a MatMulNBits and how many pre-packed weights +// this session served from the container. +static void RunSharedFusionSession(const std::string& model_bytes, PrepackedWeightsContainer& container, + bool& produced_matmulnbits, size_t& used_shared_count) { + SessionOptions so; + // This test exercises prepack-weight sharing, not parallel execution. Cap the intra-op thread pool + // to a single thread so we don't spin up one worker per core: under AddressSanitizer each thread adds + // fake-stack and thread-local allocator overhead, which on a high-core CI runner multiplies across the + // sessions every test creates (the sibling SessionStatePrepackingTest caps it for the same reason). + so.intra_op_param.thread_pool_size = 1; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsEnableDQMatMulNBitsFusion, "1")); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.AddPrePackedWeightsContainer(&container)); + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + ASSERT_STATUS_OK(session.Initialize()); + + produced_matmulnbits = false; + for (const auto& node : session.GetGraph().Nodes()) { + if (node.OpType() == "MatMulNBits") { + produced_matmulnbits = true; + break; + } + } + used_shared_count = session.GetSessionState().GetUsedSharedPrePackedWeightCounter(); +} + +// End-to-end: two sessions optimizing the same DQ+MatMul model share the fused MatMulNBits B weight +// through a common container WITHOUT any session option -- the fusion tags it to enroll it, and +// SessionState keys the sharing by the packed-bytes hash. A model whose quantized weight differs packs +// to different bytes, so it gets a different key and must NOT share. (A zero-point-only difference is +// intentionally NOT used: on the CompFp32 path the zero points are not folded into the packed B, so two +// such models pack identically and would correctly share a byte-identical buffer.) +TEST_F(DQMatMulNBitsFusionTest, SharesFusedWeightAcrossSessionsViaTag) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(N * num_blocks * block_size)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + // A different quantized weight -> different packed B on every compute type (unlike a zp-only change). + std::vector weight_other(weight.size()); + for (size_t i = 0; i < weight_other.size(); ++i) { + weight_other[i] = static_cast((i + 7) % 16); + } + std::vector scale(static_cast(N * num_blocks)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp(static_cast(N * num_blocks), 3); + + std::string model_a, model_other; + SerializeDQMatMulModel(M, N, K, block_size, weight, scale, zp, model_a); + SerializeDQMatMulModel(M, N, K, block_size, weight_other, scale, zp, model_other); + + PrepackedWeightsContainer container; + bool fused1 = false, fused2 = false, fused_other = false; + size_t used1 = 0, used2 = 0, used_other = 0; + + RunSharedFusionSession(model_a, container, fused1, used1); + ASSERT_TRUE(fused1) << "DQ -> MatMulNBits fusion did not run"; + if (container.GetNumberOfElements() == 0) { + GTEST_SKIP() << "MatMulNBits B was not pre-packed on this platform"; + } + EXPECT_EQ(used1, static_cast(0)); // first session: nothing to share yet + + // Second session over the SAME model shares the tagged B from the container. + RunSharedFusionSession(model_a, container, fused2, used2); + ASSERT_TRUE(fused2); + EXPECT_GT(used2, static_cast(0)); + + // A model with a different quantized weight packs to different bytes -> different key, so it must NOT + // reuse the buffer (on any compute type). + RunSharedFusionSession(model_other, container, fused_other, used_other); + ASSERT_TRUE(fused_other); + EXPECT_EQ(used_other, static_cast(0)); +} + TEST_F(DQMatMulNBitsFusionTest, Pattern1_MatMul_WithDefaultZP8) { constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc index a1c0f8adfffb7..b53577a81ff4a 100644 --- a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -2,10 +2,14 @@ // Licensed under the MIT License. #include +#include #include "core/common/span_utils.h" #include "core/common/float16.h" #include "core/framework/int4.h" +#include "core/framework/prepacked_weights_container.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" #include "core/graph/node_attr_utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" @@ -1462,6 +1466,237 @@ TEST(QDQTransformerTests, DQGemmNotConvertedToMatMulNBits_Alpha) { 1e-5, 2e-5); } +// --------------------------------------------------------------------------- +// Cross-session pre-pack sharing for the DEFAULT DQ->MatMulNBits path +// --------------------------------------------------------------------------- +// DQMatMulToMatMulNBitsAction (in the QDQ selector/action transformer) runs without the +// session.enable_dq_matmulnbits_fusion flag and synthesizes the MatMulNBits B/scales/zp initializers +// with names that are NOT stable across sessions. It tags the generated B weight with a sharing +// identity that SessionState treats as the enrollment signal opting the buffer into the cross-session +// container; the actual sharing is keyed by the packed-bytes hash (only byte-identical packed buffers +// are reused, exactly like the AddInitializer path), so packings that differ by compute type/options +// are never falsely shared. + +// Packs uint4 nibble values (row-major, 2 per byte) into UInt4x2 storage. +static std::vector PackUint4Nibbles(const std::vector& values) { + const size_t num_pairs = UInt4x2::CalcNumInt4Pairs(values.size()); + std::vector packed(num_pairs); + for (size_t i = 0; i < values.size(); i += 2) { + const uint8_t lo = values[i] & 0x0F; + const uint8_t hi = (i + 1 < values.size()) ? (values[i + 1] & 0x0F) : 0; + packed[i / 2] = UInt4x2(lo, hi); + } + return packed; +} + +// Builds a default-path model: a constant UINT4 weight [K, N] block-quantized along axis 0 feeding a +// DequantizeLinear whose output is the second input to a single MatMul. The QDQ selector/action +// transformer converts this into a MatMulNBits. Explicit weight/scale/zp give a deterministic identity. +static void BuildDefaultPathDQMatMul(ModelTestBuilder& builder, int64_t M, int64_t N, int64_t K, + int64_t block_size, const std::vector& weight, + const std::vector& scale, const std::vector& zp) { + const int64_t num_blocks = (K + block_size - 1) / block_size; + + auto* input_a = builder.MakeInput({M, K}, -1.0f, 1.0f); + auto* output = builder.MakeOutput(); + + auto* weight_arg = builder.MakeInitializer({K, N}, PackUint4Nibbles(weight)); + auto* scale_arg = builder.MakeInitializer({num_blocks, N}, scale); + auto* zp_arg = builder.MakeInitializer({num_blocks, N}, PackUint4Nibbles(zp)); + + NodeAttributes dq_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), dq_attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), dq_attrs); + auto* dq_output = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &dq_attrs); + + builder.AddNode("MatMul", {input_a, dq_output}, {output}); +} + +// Serializes a default-path DQ->MatMul model built from explicit quantization data. +static void SerializeDefaultPathModel(int64_t M, int64_t N, int64_t K, int64_t block_size, + const std::vector& weight, const std::vector& scale, + const std::vector& zp, std::string& model_bytes) { + const std::unordered_map domain_to_version{{"", 21}, {kMSDomain, 1}}; + Model model("dq_matmul_default_share", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + std::vector(), DefaultLoggingManager().DefaultLogger()); + ModelTestBuilder builder(model.MainGraph()); + BuildDefaultPathDQMatMul(builder, M, N, K, block_size, weight, scale, zp); + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + ASSERT_TRUE(model.ToProto().SerializeToString(&model_bytes)); +} + +// Loads the model on the CPU EP with the given shared container and DEFAULT options (no fusion flag). +// Reports whether a MatMulNBits was produced, the sharing identity tagged onto its B weight, and how +// many pre-packed weights this session served from the container. +static void RunDefaultPathSession(const std::string& model_bytes, PrepackedWeightsContainer& container, + bool& produced_matmulnbits, std::string& b_tag, size_t& used_shared_count, + int accuracy_level = -1) { + SessionOptions so; + // This test exercises prepack-weight sharing, not parallel execution. Cap the intra-op thread pool + // to a single thread so we don't spin up one worker per core: under AddressSanitizer each thread adds + // fake-stack and thread-local allocator overhead, which on a high-core CI runner multiplies across the + // sessions every test creates (the sibling SessionStatePrepackingTest caps it for the same reason). + so.intra_op_param.thread_pool_size = 1; + if (accuracy_level >= 0) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str())); + } + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.AddPrePackedWeightsContainer(&container)); + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + ASSERT_STATUS_OK(session.Initialize()); + + produced_matmulnbits = false; + b_tag.clear(); + const Graph& graph = session.GetGraph(); + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + produced_matmulnbits = true; + const std::string& b_name = node.InputDefs()[1]->Name(); // input 1 == quantized B + if (const std::string* id = graph.GetSharedPrepackInitializerId(b_name); id != nullptr) { + b_tag = *id; + } + break; + } + } + used_shared_count = session.GetSessionState().GetUsedSharedPrePackedWeightCounter(); +} + +// Verifies the default DQ->MatMulNBits path tags its generated B weight with a stable, content-derived +// enrollment identity: identical quantization data yields the SAME identity, while different zero points +// yield a DIFFERENT identity. (The tag only enrolls the buffer for sharing; the container keys by the +// packed-bytes hash. A stable, content-distinct tag keeps enrollment deterministic across sessions.) +TEST(QDQTransformerTests, DefaultPath_TagsGeneratedWeightWithStableContentIdentity) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(K * N)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + std::vector scale(static_cast(num_blocks * N)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp_a(static_cast(num_blocks * N), 3); + std::vector zp_b(zp_a.size(), 5); + + auto tag_for = [&](const std::vector& zp) -> std::string { + std::string model_bytes; + SerializeDefaultPathModel(M, N, K, block_size, weight, scale, zp, model_bytes); + PrepackedWeightsContainer container; + bool produced = false; + std::string tag; + size_t used = 0; + RunDefaultPathSession(model_bytes, container, produced, tag, used); + EXPECT_TRUE(produced) << "DQ -> MatMulNBits conversion did not run on the default path"; + return tag; + }; + + const std::string id_a1 = tag_for(zp_a); + const std::string id_a2 = tag_for(zp_a); + const std::string id_b = tag_for(zp_b); + + ASSERT_FALSE(id_a1.empty()) << "generated B weight was not tagged for cross-session sharing"; + EXPECT_EQ(id_a1, id_a2); // stable: identical quantization data -> identical identity + EXPECT_NE(id_a1, id_b); // collision-safe: different zero points -> different identity +} + +// End-to-end: two sessions converting the same model via the default path share the MatMulNBits B +// pre-packed buffer through a common container (no session option). A model whose quantized weight +// differs packs to different bytes -> different container key, so it must not reuse the buffer. (A +// zero-point-only difference is intentionally NOT used here: on the CompFp32 path the zero points are +// applied at compute time and left out of the packed B, so two such models pack identically and would +// correctly share -- packed-bytes keying only ever reuses byte-identical buffers.) +TEST(QDQTransformerTests, DefaultPath_SharesWeightAcrossSessionsViaTag) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(K * N)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + // A different quantized weight -> different packed B on every compute type (unlike a zp-only change). + std::vector weight_other(weight.size()); + for (size_t i = 0; i < weight_other.size(); ++i) { + weight_other[i] = static_cast((i + 7) % 16); + } + std::vector scale(static_cast(num_blocks * N)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp(static_cast(num_blocks * N), 3); + + std::string model_a, model_other; + SerializeDefaultPathModel(M, N, K, block_size, weight, scale, zp, model_a); + SerializeDefaultPathModel(M, N, K, block_size, weight_other, scale, zp, model_other); + + PrepackedWeightsContainer container; + bool produced1 = false, produced2 = false, produced_other = false; + std::string tag1, tag2, tag_other; + size_t used1 = 0, used2 = 0, used_other = 0; + + RunDefaultPathSession(model_a, container, produced1, tag1, used1); + ASSERT_TRUE(produced1) << "DQ -> MatMulNBits conversion did not run on the default path"; + if (container.GetNumberOfElements() == 0) { + GTEST_SKIP() << "MatMulNBits B was not pre-packed on this platform"; + } + EXPECT_EQ(used1, static_cast(0)); // first session: nothing to share yet + + // Second session over the SAME model reuses the tagged B from the container. + RunDefaultPathSession(model_a, container, produced2, tag2, used2); + ASSERT_TRUE(produced2); + EXPECT_GT(used2, static_cast(0)); + + // A model with a different quantized weight packs to different bytes -> different key, so it must NOT + // reuse the buffer (on any compute type). + RunDefaultPathSession(model_other, container, produced_other, tag_other, used_other); + ASSERT_TRUE(produced_other); + EXPECT_EQ(used_other, static_cast(0)); +} + +// accuracy_level participates in the enrollment identity, so the same weights requested at different +// accuracy levels get distinct identities. Whether the two sessions then share the packed buffer is +// platform-dependent (level 4 may pack as CompInt8 -- different bytes, no share -- or fall back to the +// same CompFp32 packing as level 0 and benignly reuse the byte-identical buffer); packed-bytes keying +// makes either outcome safe, so this asserts the identity is distinct, not a fixed sharing count. +TEST(QDQTransformerTests, DefaultPath_DifferentAccuracyLevelGetsDistinctIdentity) { + constexpr int64_t M = 4, N = 8, K = 32, block_size = 16; + const int64_t num_blocks = K / block_size; + + std::vector weight(static_cast(K * N)); + for (size_t i = 0; i < weight.size(); ++i) { + weight[i] = static_cast(i % 16); + } + std::vector scale(static_cast(num_blocks * N)); + for (size_t i = 0; i < scale.size(); ++i) { + scale[i] = 0.1f + 0.01f * static_cast(i % 10); + } + std::vector zp(static_cast(num_blocks * N), 3); + + std::string model_bytes; + SerializeDefaultPathModel(M, N, K, block_size, weight, scale, zp, model_bytes); + + PrepackedWeightsContainer container; + bool produced0 = false, produced4 = false; + std::string tag0, tag4; + size_t used0 = 0, used4 = 0; + + RunDefaultPathSession(model_bytes, container, produced0, tag0, used0, /*accuracy_level*/ 0); + ASSERT_TRUE(produced0) << "DQ -> MatMulNBits conversion did not run on the default path"; + + // Same model/weights, different accuracy level, sharing the same container. + RunDefaultPathSession(model_bytes, container, produced4, tag4, used4, /*accuracy_level*/ 4); + ASSERT_TRUE(produced4); + + ASSERT_FALSE(tag0.empty()); + ASSERT_FALSE(tag4.empty()); + EXPECT_NE(tag0, tag4); // accuracy_level participates in the enrollment identity +} + #endif // !defined(DISABLE_CONTRIB_OPS) } // namespace test From d857f77afe7d701e8abeac403a8e01b364d91481 Mon Sep 17 00:00:00 2001 From: Gopalakrishnan Nallasamy Date: Fri, 26 Jun 2026 15:43:53 -0700 Subject: [PATCH 15/19] Fix integer overflow in RKNPU implicit bias allocation (#29249) ### Description The RKNPU execution provider's ONNX converter creates implicit (all-zero) bias buffers when a `Conv` / `Gemm` / `QLinearConv` node omits its bias input. The buffer size was computed as `sizeof(T) * dim` (where `dim` derives from a model weight's shape) with no overflow check, and the raw allocation was tracked in a manually-freed `void*` list. This PR hardens the converter: - **Dimension validation:** ONNX `int64_t` dimensions are validated (via `ORT_ENFORCE` in a new `ToRknpuDim` helper) before being narrowed to the RKNPU `uint32_t` shape representation, rejecting negative or out-of-range values. This covers all four ingestion points (`HandleInitializer`, `GetInputOfOnnxModel`, `GetShape`, `GetSupportedNodes`). - **Overflow-checked allocation:** all four implicit-bias sites (`AddLayerConvImpl`, `AddLayerQLinearConvImpl`, `AddLayerDepthwiseConvImpl`, `AddLayerFC`) go through a shared `AllocZeroedBias` helper that computes the byte count with `SafeInt` (throws on overflow) and returns a zero-initialized `std::make_unique` buffer. - **RAII ownership:** `free_list_` is now `std::vector>`, so the bias buffers are freed automatically and `Clear()` no longer walks/`free()`s raw pointers. ### Motivation and Context A malicious ONNX model can provide dimensions that are unsafe for the RKNPU converter's 32-bit shape representation or for byte-size allocation arithmetic: - ONNX stores dimensions as `int64_t`, while the RKNPU converter/DDK uses `uint32_t` shape values. Silently narrowing a large or negative `int64_t` value can produce a misleading `uint32_t` dimension. - Even after a dimension is represented as `uint32_t`, the original `sizeof(T) * dim` could overflow `size_t` on 32-bit RKNPU targets. For example, `dim = 0x40000400` makes `sizeof(float) * dim` wrap to **4096 bytes** while the created tensor still advertises `dim` elements, which would corrupt the heap when the bias is consumed by the driver. - The original code also passed the `malloc` result to `memset` without a null check. The fix uses `SafeInt` (the ORT-standard idiom for memory-size arithmetic), validates ONNX dimensions before they enter the RKNPU `uint32_t` shape model, and replaces the manual `malloc`/`free` list with zero-initialized, RAII-owned `std::make_unique` buffers. **Validation:** the RKNPU EP requires the Rockchip DDK (`RKNPU_DDK_PATH`) and an ARM target, so it does not build on a typical x64 dev box and has no GPU-free CI leg. The change was validated with `clang-format`, `git diff --check`, and a standalone test against ORT's `SafeInt.hpp`, confirming the guard throws on the overflowing dimensions (`0x40000400`, `0xFFFFFFFF`) while preserving normal and zero-dimension behavior. **Testing:** No unit test is included because the RKNPU EP is not compiled in any CI leg (it requires the proprietary Rockchip DDK and an ARM target), and the `ToRknpuDim` / `AllocZeroedBias` helpers have internal (file-`static`) linkage, so they are not reachable from the gtest suites. The overflow/validation logic was instead exercised with a standalone test against ORT's `SafeInt.hpp` as noted above. --------- Co-authored-by: Gopalakrishnan Nallasamy --- .../core/providers/rknpu/onnx_converter.cc | 50 ++++++++++++------- .../core/providers/rknpu/onnx_converter.h | 2 +- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/rknpu/onnx_converter.cc b/onnxruntime/core/providers/rknpu/onnx_converter.cc index 30de1af8fa7bd..bab76de49519e 100644 --- a/onnxruntime/core/providers/rknpu/onnx_converter.cc +++ b/onnxruntime/core/providers/rknpu/onnx_converter.cc @@ -1,6 +1,7 @@ // Copyright 2020 rock-chips.com Inc. #include +#include #include #include #include @@ -10,7 +11,9 @@ #include #include #include +#include "core/common/common.h" #include "core/common/logging/logging.h" +#include "core/common/safeint.h" #include "onnx_converter.h" #include "node_attr_helper.h" @@ -119,12 +122,28 @@ OnnxConverter::CreateRknnTensor(const std::string& name, return graph_->CreateTensor(attr, (void*)data); } +static uint32_t ToRknpuDim(int64_t dim, const std::string& name) { + ORT_ENFORCE(dim >= 0 && dim <= static_cast(std::numeric_limits::max()), + "RKNPU: tensor dimension out of uint32_t range (name=", name, ", dim=", dim, ")"); + + return static_cast(dim); +} + +// Allocates a zero-initialized bias buffer for `count` elements of `element_size` +// bytes, used when a Conv/Gemm node omits its bias input. SafeInt provides +// overflow-checked size arithmetic (throws on size_t overflow); std::make_unique +// zero-initializes and owns the buffer. +static std::unique_ptr AllocZeroedBias(size_t element_size, uint32_t count) { + const size_t num_bytes = SafeInt(element_size) * count; + return std::make_unique(num_bytes); +} + void OnnxConverter::HandleInitializer() { for (const auto& tensor : model_proto_.graph().initializer()) { const std::string name = tensor.name(); std::vector dims; for (const auto dim : tensor.dims()) { - dims.push_back(static_cast(dim)); + dims.push_back(ToRknpuDim(dim, name)); } if (tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { const char* ptr = tensor.float_data().empty() @@ -186,7 +205,7 @@ std::vector> OnnxConverter::GetInputOfOnnxModel( for (const auto& dim : input.type().tensor_type().shape().dim()) { if (dim.value_case() == ONNX_NAMESPACE::TensorShapeProto_Dimension::kDimValue) { - shape.push_back(static_cast(dim.dim_value())); + shape.push_back(ToRknpuDim(dim.dim_value(), input.name())); } else { throw std::invalid_argument( "The input of graph doesn't have dim_value"); @@ -267,7 +286,7 @@ Shaper::Shape GetShape(const ONNX_NAMESPACE::ModelProto& model_proto, for (const auto& dim : value_info.type().tensor_type().shape().dim()) { if (dim.has_dim_value()) { - shape.push_back(dim.dim_value()); + shape.push_back(ToRknpuDim(dim.dim_value(), value_info.name())); } else { break; } @@ -548,7 +567,7 @@ std::vector> OnnxConverter::GetSupportedNodes( const std::string name = tensor.name(); std::vector dims; for (const auto dim : tensor.dims()) { - dims.push_back(static_cast(dim)); + dims.push_back(ToRknpuDim(dim, name)); } tensor_dims_[name] = dims; } @@ -814,9 +833,6 @@ void OnnxConverter::Clear() { rk_tensors_.clear(); shaper_.Clear(); - for (const auto p : free_list_) { - if (p) free(p); - } free_list_.clear(); } @@ -944,9 +960,8 @@ void OnnxConverter::AddLayerConvImpl(const std::string& input, } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(float) * dim); - memset(ptr, 0, sizeof(float) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(float), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST); @@ -1053,9 +1068,8 @@ void OnnxConverter::AddLayerQLinearConvImpl(const string& input, } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(int32_t) * dim); - memset(ptr, 0, sizeof(int32_t) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(int32_t), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST, @@ -1142,9 +1156,8 @@ void OnnxConverter::AddLayerDepthwiseConvImpl( } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(float) * dim); - memset(ptr, 0, sizeof(float) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(float), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST); @@ -1376,9 +1389,8 @@ void OnnxConverter::AddLayerFC(const std::string& input, } } else { uint32_t dim = shaper_[weight][0]; - void* ptr = (void*)malloc(sizeof(float) * dim); - memset(ptr, 0, sizeof(float) * dim); - free_list_.push_back(ptr); + free_list_.push_back(AllocZeroedBias(sizeof(float), dim)); + void* ptr = free_list_.back().get(); std::vector dims = {dim}; auto rk_bias = CreateRknnTensor(bias, dims, ptr, rk::nn::TensorRole::CONST); diff --git a/onnxruntime/core/providers/rknpu/onnx_converter.h b/onnxruntime/core/providers/rknpu/onnx_converter.h index 10cc09a9dba92..41d50d6b9401f 100644 --- a/onnxruntime/core/providers/rknpu/onnx_converter.h +++ b/onnxruntime/core/providers/rknpu/onnx_converter.h @@ -63,7 +63,7 @@ class OnnxConverter { // for GetSupportedNodes std::map> tensor_dims_; - std::vector free_list_; // remember free + std::vector> free_list_; // owns implicit-bias buffers std::pair, FuseCode> FindActivation(const ONNX_NAMESPACE::ModelProto& model_proto, From 37cccc89fa3a5f0eda14c38df32b8be18a7bfbc0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Jun 2026 15:52:49 -0700 Subject: [PATCH 16/19] Fix CUDA/cuDNN DLL preload paths for CUDA 13 consolidated wheel layout (#29202) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description Fix https://github.com/microsoft/onnxruntime/issues/29198. NVIDIA restructured the CUDA Python wheels starting with CUDA 13: the per-component CUDA Toolkit packages (cublas, cufft, cuda_runtime, cuda_nvrtc, curand, ...) were consolidated into a single `nvidia/cu{major}` package and the `-cuNN` suffix was dropped from those package names. This PR updates the DLL/shared-library preload logic and the wheel dependency metadata so `onnxruntime-gpu` (and `onnxruntime-trt-rtx`) keep working on both the legacy CUDA 12 layout and the new CUDA 13 consolidated layout. ## Summary of Changes ### Preload logic (`onnxruntime/__init__.py`) | File | Change | |------|--------| | `onnxruntime/__init__.py` | `_get_nvidia_dll_paths` now detects the CUDA 13+ consolidated layout and resolves CUDA libraries under `nvidia/cu{major}` — Windows uses an architecture sub-folder (`bin/`, e.g. `bin/x86_64`), Linux uses a flat `lib`. The legacy CUDA 12 per-component paths are preserved. | | `onnxruntime/__init__.py` | Added `build_cuda_version` and `arch` parameters (for testability/arch override); cuDNN paths factored out since cuDNN keeps its own `nvidia/cudnn` package layout in both schemes. | | `onnxruntime/__init__.py` | `print_debug_info` drops the `-cuNN` suffix from CUDA Toolkit package names for CUDA 13+ (cuDNN keeps its suffixed name). | ### Wheel dependency metadata (`setup.py`) | File | Change | |------|--------| | `setup.py` | `onnxruntime-gpu` `cuda` extras drop the `-cuNN` suffix for CUDA 13+ (`nvidia-cuda-nvrtc`, `nvidia-cuda-runtime`, `nvidia-cufft`, `nvidia-curand`); cuDNN dependency keeps the suffixed name. | | `setup.py` | `onnxruntime-trt-rtx` CUDA Runtime dependency drops the `-cuNN` suffix for CUDA 13+. | ### Tests (`onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py`) - New unit tests pin the expected relative paths for the CUDA 12 (legacy) and CUDA 13 (consolidated) layouts on both Windows and Linux, the Windows arch override, the Linux flat-`lib` layout, the unchanged cuDNN layout, and the `cuda`/`cudnn` toggles. ## Testing - Run the new tests: `python -m pytest onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py` (or `python -m unittest onnxruntime.test.python.onnxruntime_test_python_preload_dlls`). - Backward compatibility: CUDA 12 paths and the cuDNN layout are unchanged; only CUDA 13+ takes the new consolidated paths and unsuffixed package names. - Build in Linux and Windows, and `pip install onnxruntime-gpu*.whl[cuda,cudnn]`, then `import onnxruntime; onnxruntime.preload_dlls()` can run successfully in python. ## Checklist - [x] Tests added/updated - [x] No breaking changes (CUDA 12 behavior preserved) --- cmake/onnxruntime_providers_cuda.cmake | 32 ++++---- onnxruntime/__init__.py | 81 ++++++++++++++----- .../onnxruntime_test_python_preload_dlls.py | 79 ++++++++++++++++++ setup.py | 17 ++-- 4 files changed, 168 insertions(+), 41 deletions(-) create mode 100644 onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index f692f1f5e0a57..2aa31276cc395 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -208,7 +208,7 @@ target_compile_definitions(onnxruntime_providers_cuda PRIVATE FILE_NAME=\"onnxruntime_providers_cuda.dll\") endif() - # Work around a CUDA 13.x cudafe++ (EDG front-end) regression that mis-parses CCCL's + # Work around a CUDA 13.3 cudafe++ (EDG front-end) regression that mis-parses CCCL's # global-qualified partial specializations, e.g. in : # template # struct ::cuda::proclaims_copyable_arguments<...> : ::cuda::std::true_type {}; @@ -218,7 +218,7 @@ # corrected copies of the affected headers into the build tree and place that directory # ahead of the toolkit cccl include path. This is a no-op on toolkits whose headers do not # contain the offending pattern (e.g. once NVIDIA fixes it), so it is safe to keep enabled. - function(ort_cuda13_patch_cccl_header src dst) + function(ort_cuda133_patch_cccl_header src dst) if (NOT EXISTS "${src}") return() endif() @@ -412,19 +412,21 @@ if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) foreach(inc_dir ${CUDAToolkit_INCLUDE_DIRS}) if (EXISTS "${inc_dir}/cccl") - # Generate cudafe++-parseable copies of the CCCL headers that contain global-qualified - # partial specializations (see ort_cuda13_patch_cccl_header above) and put the fixed - # directory ahead of the toolkit cccl include so the corrected headers win. - set(_ort_cccl_fix_dir "${CMAKE_CURRENT_BINARY_DIR}/cccl_cuda13_fix") - ort_cuda13_patch_cccl_header( - "${inc_dir}/cccl/cub/device/device_transform.cuh" - "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh") - ort_cuda13_patch_cccl_header( - "${inc_dir}/cccl/cub/device/dispatch/tuning/tuning_transform.cuh" - "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") - if (EXISTS "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh" OR - EXISTS "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") - target_include_directories(${target} BEFORE PRIVATE "${_ort_cccl_fix_dir}") + if (UNIX AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.3 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 13.4) + # Generate cudafe++-parseable copies of the CCCL headers that contain global-qualified + # partial specializations (see ort_cuda133_patch_cccl_header above) and put the fixed + # directory ahead of the toolkit cccl include so the corrected headers win. + set(_ort_cccl_fix_dir "${CMAKE_CURRENT_BINARY_DIR}/cccl_cuda13_fix") + ort_cuda133_patch_cccl_header( + "${inc_dir}/cccl/cub/device/device_transform.cuh" + "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh") + ort_cuda133_patch_cccl_header( + "${inc_dir}/cccl/cub/device/dispatch/tuning/tuning_transform.cuh" + "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") + if (EXISTS "${_ort_cccl_fix_dir}/cub/device/device_transform.cuh" OR + EXISTS "${_ort_cccl_fix_dir}/cub/device/dispatch/tuning/tuning_transform.cuh") + target_include_directories(${target} BEFORE PRIVATE "${_ort_cccl_fix_dir}") + endif() endif() # Add the cccl subdirectory to the include path so can be found diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index df14bc8c57f24..0a06156fe78d8 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -150,25 +150,57 @@ def _extract_cuda_major_version(version_str: str) -> str: return version_str.split(".", maxsplit=1)[0] if version_str else "12" -def _get_cufft_version(cuda_major: str) -> str: +def _get_cufft_version(cuda_major_version: str) -> str: """Get cufft library version based on CUDA major version. Args: - cuda_major: CUDA major version as string (e.g., "12", "13") + cuda_major_version: CUDA major version as string (e.g., "12", "13") Returns: cufft version as string """ # cufft versions: CUDA 12.x -> 11, CUDA 13.x -> 12 - return "12" if cuda_major == "13" else "11" + return "12" if int(cuda_major_version) >= 13 else "11" def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = True): - # Dynamically determine CUDA major version from build info + # Dynamically determine CUDA major version from build info. + # build_cuda_version defaults to the version this package was built with; it is a parameter for testability. cuda_major_version = _extract_cuda_major_version(cuda_version) cufft_version = _get_cufft_version(cuda_major_version) - if is_windows: + # Starting with CUDA 13, NVIDIA consolidated the per-component CUDA Toolkit wheels + # (cublas, cufft, cuda_runtime, cuda_nvrtc, curand, ...) into a single "nvidia/cu{major}" + # package and dropped the "-cuNN" suffix from those package names. On Windows the DLLs + # moved into an architecture sub-folder ("bin/", e.g. "bin/x86_64"); on Linux the + # libraries are placed directly in "lib" (the wheel itself is architecture specific, so + # there is no arch sub-folder). cuDNN keeps its own "nvidia/cudnn" package and layout. + use_consolidated_layout = cuda_major_version.isdigit() and int(cuda_major_version) >= 13 + + if use_consolidated_layout: + cuda_dir = f"cu{cuda_major_version}" + if is_windows: + import platform # noqa: PLC0415 + + arch = "arm64" if platform.machine().lower() in ("arm64", "aarch64") else "x86_64" + cuda_dll_paths = [ + ("nvidia", cuda_dir, "bin", arch, f"cublasLt64_{cuda_major_version}.dll"), + ("nvidia", cuda_dir, "bin", arch, f"cublas64_{cuda_major_version}.dll"), + ("nvidia", cuda_dir, "bin", arch, f"cufft64_{cufft_version}.dll"), + ("nvidia", cuda_dir, "bin", arch, f"cudart64_{cuda_major_version}.dll"), + ] + else: # Linux + # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first. + cuda_dll_paths = [ + ("nvidia", cuda_dir, "lib", f"libcublasLt.so.{cuda_major_version}"), + ("nvidia", cuda_dir, "lib", f"libcublas.so.{cuda_major_version}"), + ("nvidia", cuda_dir, "lib", f"libnvrtc.so.{cuda_major_version}"), + ("nvidia", cuda_dir, "lib", "libcurand.so.10"), + ("nvidia", cuda_dir, "lib", f"libcufft.so.{cufft_version}"), + ("nvidia", cuda_dir, "lib", f"libcudart.so.{cuda_major_version}"), + ] + elif is_windows: + # CUDA 12 and earlier: each component ships its own "nvidia/" package. # Path is relative to site-packages directory. cuda_dll_paths = [ ("nvidia", "cublas", "bin", f"cublasLt64_{cuda_major_version}.dll"), @@ -176,16 +208,6 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru ("nvidia", "cufft", "bin", f"cufft64_{cufft_version}.dll"), ("nvidia", "cuda_runtime", "bin", f"cudart64_{cuda_major_version}.dll"), ] - cudnn_dll_paths = [ - ("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_engines_precompiled64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_heuristic64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_ops64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_adv64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_graph64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn64_9.dll"), - ("nvidia", "cudnn", "bin", "cudnn_engines_tensor_ir64_9.dll"), - ] else: # Linux # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first. cuda_dll_paths = [ @@ -197,6 +219,19 @@ def _get_nvidia_dll_paths(is_windows: bool, cuda: bool = True, cudnn: bool = Tru ("nvidia", "cuda_runtime", "lib", f"libcudart.so.{cuda_major_version}"), ] + # cuDNN keeps its own "nvidia/cudnn" package layout in both old and consolidated schemes. + if is_windows: + cudnn_dll_paths = [ + ("nvidia", "cudnn", "bin", "cudnn_engines_runtime_compiled64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_engines_precompiled64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_heuristic64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_ops64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_adv64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_graph64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn64_9.dll"), + ("nvidia", "cudnn", "bin", "cudnn_engines_tensor_ir64_9.dll"), + ] + else: # Linux # Do not load cudnn sub DLLs (they will be dynamically loaded later) to be consistent with PyTorch in Linux. cudnn_dll_paths = [ ("nvidia", "cudnn", "lib", "libcudnn.so.9"), @@ -238,15 +273,19 @@ def print_debug_info(): # Print version of installed packages that is related to CUDA or cuDNN DLLs. cuda_major = _extract_cuda_major_version(cuda_version) + # Starting with CUDA 13, NVIDIA dropped the "-cuNN" suffix from the per-component + # CUDA Toolkit packages (cuDNN keeps its suffixed package name). + cuda_pkg_suffix = "" if (cuda_major.isdigit() and int(cuda_major) >= 13) else f"-cu{cuda_major}" + packages = [ "torch", - f"nvidia-cuda-runtime-cu{cuda_major}", + f"nvidia-cuda-runtime{cuda_pkg_suffix}", f"nvidia-cudnn-cu{cuda_major}", - f"nvidia-cublas-cu{cuda_major}", - f"nvidia-cufft-cu{cuda_major}", - f"nvidia-curand-cu{cuda_major}", - f"nvidia-cuda-nvrtc-cu{cuda_major}", - f"nvidia-nvjitlink-cu{cuda_major}", + f"nvidia-cublas{cuda_pkg_suffix}", + f"nvidia-cufft{cuda_pkg_suffix}", + f"nvidia-curand{cuda_pkg_suffix}", + f"nvidia-cuda-nvrtc{cuda_pkg_suffix}", + f"nvidia-nvjitlink{cuda_pkg_suffix}", ] for package in packages: directory_name = "nvidia" if package.startswith("nvidia-") else None diff --git a/onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py b/onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py new file mode 100644 index 0000000000000..a8ce794f5fdd3 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_preload_dlls.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# pylint: disable=C0114,C0115,C0116,W0212 +import unittest + +import onnxruntime + + +class TestGetNvidiaDllPaths(unittest.TestCase): + """Unit tests for the private _get_nvidia_dll_paths helper that locates CUDA/cuDNN + libraries inside the NVIDIA site-packages folders. + + NVIDIA restructured the CUDA Python wheels starting with CUDA 13: the per-component + packages (cublas, cufft, cuda_runtime, ...) were consolidated into a single + "nvidia/cu{major}" tree. These tests pin down the expected relative paths for the + old (CUDA 12) and new (CUDA 13) layouts on both Windows and Linux. + """ + + def _paths(self, **kwargs): + return onnxruntime._get_nvidia_dll_paths(**kwargs) + + # ---- CUDA 12 (legacy per-component layout) -------------------------------------- + def test_cuda12_windows(self): + paths = self._paths(is_windows=True, build_cuda_version="12.4", cudnn=False) + self.assertIn(("nvidia", "cublas", "bin", "cublasLt64_12.dll"), paths) + self.assertIn(("nvidia", "cublas", "bin", "cublas64_12.dll"), paths) + self.assertIn(("nvidia", "cufft", "bin", "cufft64_11.dll"), paths) + self.assertIn(("nvidia", "cuda_runtime", "bin", "cudart64_12.dll"), paths) + + def test_cuda12_linux(self): + paths = self._paths(is_windows=False, build_cuda_version="12.4", cudnn=False) + self.assertIn(("nvidia", "cublas", "lib", "libcublasLt.so.12"), paths) + self.assertIn(("nvidia", "cublas", "lib", "libcublas.so.12"), paths) + self.assertIn(("nvidia", "cuda_nvrtc", "lib", "libnvrtc.so.12"), paths) + self.assertIn(("nvidia", "curand", "lib", "libcurand.so.10"), paths) + self.assertIn(("nvidia", "cufft", "lib", "libcufft.so.11"), paths) + self.assertIn(("nvidia", "cuda_runtime", "lib", "libcudart.so.12"), paths) + + # ---- CUDA 13 (consolidated "cu13" layout) --------------------------------------- + def test_cuda13_windows_x86_64(self): + paths = self._paths(is_windows=True, build_cuda_version="13.2", cudnn=False, arch="x86_64") + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cublasLt64_13.dll"), paths) + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cublas64_13.dll"), paths) + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cufft64_12.dll"), paths) + self.assertIn(("nvidia", "cu13", "bin", "x86_64", "cudart64_13.dll"), paths) + + def test_cuda13_windows_arch_override(self): + paths = self._paths(is_windows=True, build_cuda_version="13.2", cudnn=False, arch="arm64") + self.assertIn(("nvidia", "cu13", "bin", "arm64", "cudart64_13.dll"), paths) + + def test_cuda13_linux_is_flat(self): + paths = self._paths(is_windows=False, build_cuda_version="13.2", cudnn=False) + # Linux consolidated layout has no architecture sub-folder (flat "lib"). + self.assertIn(("nvidia", "cu13", "lib", "libcublasLt.so.13"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcublas.so.13"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libnvrtc.so.13"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcurand.so.10"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcufft.so.12"), paths) + self.assertIn(("nvidia", "cu13", "lib", "libcudart.so.13"), paths) + + # ---- cuDNN keeps its own package/layout in both schemes ------------------------- + def test_cudnn_layout_unchanged(self): + for build_cuda_version in ("12.4", "13.2"): + win = self._paths(is_windows=True, build_cuda_version=build_cuda_version, cuda=False) + self.assertIn(("nvidia", "cudnn", "bin", "cudnn64_9.dll"), win) + + linux = self._paths(is_windows=False, build_cuda_version=build_cuda_version, cuda=False) + self.assertEqual(linux, [("nvidia", "cudnn", "lib", "libcudnn.so.9")]) + + # ---- toggles -------------------------------------------------------------------- + def test_cuda_and_cudnn_toggles(self): + self.assertEqual(self._paths(is_windows=False, build_cuda_version="13.2", cuda=False, cudnn=False), []) + + cuda_only = self._paths(is_windows=False, build_cuda_version="13.2", cuda=True, cudnn=False) + self.assertTrue(all(p[1] == "cu13" for p in cuda_only)) + + +if __name__ == "__main__": + unittest.main() diff --git a/setup.py b/setup.py index 3b8bb9b81d20a..62ced38819f2c 100644 --- a/setup.py +++ b/setup.py @@ -817,7 +817,9 @@ def reformat_run_count(count_str): # Adding CUDA Runtime as dependency for NV TensorRT RTX python wheel if package_name == "onnxruntime-trt-rtx": major = cuda_major_version or "12" # Default to CUDA 12 - install_requires.append(f"nvidia-cuda-runtime-cu{major}~={major}.0") + # CUDA 13 dropped the "-cuNN" suffix from the CUDA Runtime package name. + runtime_pkg = "nvidia-cuda-runtime" if int(major) >= 13 else f"nvidia-cuda-runtime-cu{major}" + install_requires.append(f"{runtime_pkg}~={major}.0") def save_build_and_package_info(package_name, version_number, cuda_version, qnn_version): @@ -862,13 +864,18 @@ def save_build_and_package_info(package_name, version_number, cuda_version, qnn_ if package_name == "onnxruntime-gpu" and cuda_major_version: # Determine cufft version: CUDA 13 uses cufft 12, CUDA 12 uses cufft 11 cufft_version = "12.0" if cuda_major_version == "13" else "11.0" + + # Starting with CUDA 13, NVIDIA renamed the per-component CUDA Toolkit packages by + # dropping the "-cuNN" suffix (e.g. "nvidia-cuda-runtime-cu12" -> "nvidia-cuda-runtime"). + # cuDNN keeps the suffixed package name ("nvidia-cudnn-cu13"). + cuda_pkg_suffix = "" if int(cuda_major_version) >= 13 else f"-cu{cuda_major_version}" extras_require.update( { "cuda": [ - f"nvidia-cuda-nvrtc-cu{cuda_major_version}~={cuda_major_version}.0", - f"nvidia-cuda-runtime-cu{cuda_major_version}~={cuda_major_version}.0", - f"nvidia-cufft-cu{cuda_major_version}~={cufft_version}", - f"nvidia-curand-cu{cuda_major_version}~=10.0", + f"nvidia-cuda-nvrtc{cuda_pkg_suffix}~={cuda_major_version}.0", + f"nvidia-cuda-runtime{cuda_pkg_suffix}~={cuda_major_version}.0", + f"nvidia-cufft{cuda_pkg_suffix}~={cufft_version}", + f"nvidia-curand{cuda_pkg_suffix}~=10.0", ], "cudnn": [ f"nvidia-cudnn-cu{cuda_major_version}~=9.0", From 49efb3280faae9b365a8bbe564970172fc8cb63e Mon Sep 17 00:00:00 2001 From: Gopalakrishnan Nallasamy Date: Fri, 26 Jun 2026 17:26:15 -0700 Subject: [PATCH 17/19] Crypto support : App supply I/O callbacks to EP + callback and fallback helpers (#28624) ## Summary This PR adds an opt-in mechanism that lets an application supply its own I/O callbacks for an execution provider's EPContext binary data, so the data can live somewhere other than a plain file on disk (for example, an encrypted store or an in-memory buffer). It introduces the callback APIs end-to-end and demonstrates their use with a sample helper in the AutoEP example plugin EP. When an EP compiles a model into an EPContext model, it may emit the compiled blob either embedded in the ONNX model or as a separate external payload. For the external case, ORT previously assumed the payload is a file. These callbacks let the application own that read/write instead, while ORT core stays policy-neutral and never imposes a storage format. ### What this PR adds - **Write callback (`OrtWriteNamedBufferFunc`) + setter `OrtCompileApi::ModelCompilationOptions_SetEpContextDataWriteFunc`.** Set on `OrtModelCompilationOptions`, because writing EPContext binary data happens only during **compilation**. Passing a NULL callback clears a previously set one. - **Read callback (`OrtReadNamedBufferFunc`) + setter `OrtApi::SessionOptions_SetEpContextDataReadFunc`.** Set on `OrtSessionOptions`, because reading external EPContext binary data happens during **session load / inference**. Passing a NULL callback clears a previously set one. - **EP-facing access via `OrtEpContextConfig`.** Both callbacks are surfaced to execution providers through a single unified handle, `OrtEpContextConfig`, obtained via `OrtEpApi::SessionOptions_GetEpContextConfig` (getters `EpContextConfig_GetEpContextDataReadFunc` / `EpContextConfig_GetEpContextDataWriteFunc`, released with `ReleaseEpContextConfig`). This keeps the application-facing setters scoped to the correct lifecycle while giving EPs one consistent place to retrieve both callbacks. Each setter's doc comment cross-references the other so the split is discoverable. - **Experimental API surface + C++ accessors.** These functions ship through ORT's experimental API mechanism (declared in `include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc`), so they are reached via the generated `Ort::Experimental::Get__SinceV28_Fn(...)` / `...FnOrThrow(...)` accessors rather than fixed `OrtApi` slots. A move-only RAII wrapper, **`Ort::Experimental::EpContextConfig`** (in `onnxruntime_experimental_cxx_api.h`), owns an `OrtEpContextConfig` and exposes `GetReadFunc()` / `GetWriteFunc()`; it can be constructed directly from a C++ `SessionOptions` / `ConstSessionOptions`. - **Sample-only helper utilities** (`onnxruntime/test/autoep/library/ep_context_data_utils.h`) implementing callback-or-file fallback behavior: if a callback is supplied it is used, otherwise the helper falls back to direct file I/O. The AutoEP example plugin EP uses this helper for its external EPContext read/write paths. Because the names read on the load side originate from the untrusted EPContext model (`ep_cache_context` attribute), the helper validates them: it rejects absolute/rooted paths, `..` traversal, and directory-like names (`.` or a trailing separator), and confines model-relative names to the model directory (resolving `.`/`..` and symlinks via `std::filesystem::weakly_canonical`). It reports all failures via `OrtStatus*` (no exceptions) and lives outside the public C API / EP ABI, so it is purely illustrative and imposes no policy on ORT core; its doc comments note that production EPs should still apply their own sandboxing and payload size limits. The callback typedef names (`OrtReadNamedBufferFunc` / `OrtWriteNamedBufferFunc`) are intentionally generic. They are currently used for EPContext binary data, but the contract is deliberately storage-agnostic so future APIs can reuse the same callback shape for other named data payloads. ### Note on the Android workflow change `.github/workflows/android.yml` bumps the minimal-build binary-size threshold (`1436672` -> `1438720` bytes) to accommodate the small size increase from compiling the new experimental API into the Android minimal build. ## Testing - Built and tested in RelWithDebInfo: `python tools/ci_build/build.py --config RelWithDebInfo --build --parallel --test --build_dir build\Windows`. - Focused EPContext suites: - Public C/C++ API: `onnxruntime_shared_lib_test.exe --gtest_filter=EpContextDataApiTest.*` -> 9 passed. - AutoEP helper + compile/load end-to-end (callbacks and file fallback): `onnxruntime_autoep_test.exe --gtest_filter=*EpContext*` -> 17 passed, 1 skipped (`EpContextDataUtils_ResolvePathRejectsSymlinkEscape` requires the Windows "create symbolic link" privilege). - `clang-format` clean on touched C++ files; `git diff --check`: clean. Test layout: public EPContext API tests in `onnxruntime/test/shared_lib/test_ep_context_data_api.cc`; sample-helper unit tests in `onnxruntime/test/autoep/ep_context_data_utils_test.cc`; compile/load end-to-end tests in `onnxruntime/test/autoep/test_execution.cc`. --------- Co-authored-by: Gopalakrishnan Nallasamy Co-authored-by: Gopalakrishnan Nallasamy Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- .github/workflows/android.yml | 4 +- cmake/onnxruntime_unittests.cmake | 2 + .../session/onnxruntime_experimental_c_api.h | 62 +++ .../onnxruntime_experimental_c_api.inc | 114 ++++ .../onnxruntime_experimental_cxx_api.h | 74 +++ .../core/framework/ep_context_options.cc | 4 + .../core/framework/ep_context_options.h | 14 + onnxruntime/core/framework/session_options.h | 6 + .../core/session/experimental_c_api.cc | 106 ++++ .../core/session/model_compilation_options.cc | 9 + .../core/session/model_compilation_options.h | 8 + .../test/autoep/ep_context_data_callbacks.h | 65 +++ .../test/autoep/ep_context_data_utils_test.cc | 327 ++++++++++++ .../autoep/library/ep_context_data_utils.h | 501 ++++++++++++++++++ .../autoep/library/example_plugin_ep/ep.cc | 64 ++- .../autoep/library/example_plugin_ep/ep.h | 12 +- .../library/example_plugin_ep/ep_factory.cc | 18 +- onnxruntime/test/autoep/test_execution.cc | 266 ++++++++++ onnxruntime/test/autoep/test_model_package.cc | 5 + .../test/framework/ep_plugin_provider_test.cc | 2 + .../shared_lib/test_ep_context_data_api.cc | 331 ++++++++++++ 21 files changed, 1977 insertions(+), 17 deletions(-) create mode 100644 onnxruntime/test/autoep/ep_context_data_callbacks.h create mode 100644 onnxruntime/test/autoep/ep_context_data_utils_test.cc create mode 100644 onnxruntime/test/autoep/library/ep_context_data_utils.h create mode 100644 onnxruntime/test/shared_lib/test_ep_context_data_api.cc diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 252ea7281d981..954c3313faf25 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -78,8 +78,8 @@ jobs: run: | set -e -x BINARY_SIZE_THRESHOLD_ARGS="" - echo "Binary size threshold in bytes: 1436672" - BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1436672" + echo "Binary size threshold in bytes: 1438720" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1438720" # Ensure ANDROID_NDK_HOME is available and get its real path if [ -z "$ANDROID_NDK_HOME" ]; then diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index c32aa7f4ae75a..bb171b056b400 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -603,6 +603,7 @@ set (onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_data_copy.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ep_context_data_api.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_experimental_api.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc @@ -2174,6 +2175,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND # file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep/*.cc" + "${TEST_SRC_DIR}/autoep/library/ep_context_data_utils.h" "${TEST_SRC_DIR}/autoep/library/plugin_ep_utils.h") onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src}) target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h index e5b9bd4713a1c..ce76bd385cd85 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h @@ -46,6 +46,68 @@ ORT_RUNTIME_CLASS(ModelPackageOptions); ORT_RUNTIME_CLASS(ModelPackageContext); ORT_RUNTIME_CLASS(ModelPackageComponentContext); +// Opaque handle holding the EPContext callbacks and opaque state extracted from an OrtSessionOptions instance. Used by +// the experimental OrtEpApi_* EPContext data functions. Create via OrtEpApi_SessionOptions_GetEpContextConfig and +// release with OrtEpApi_ReleaseEpContextConfig. +ORT_RUNTIME_CLASS(EpContextConfig); + +/** \brief Function called to write named binary data. + * + * This callback is currently used for EPContext binary data, but its contract is intentionally generic so future APIs + * can reuse it for other named data payloads. The callback is called synchronously by the component that receives it. + * ORT does not own or retain buffer after the callback returns. ORT does not serialize invocations made by different + * EP instances or worker threads. + * + * Each callback invocation represents one complete write operation for name. The callback signature does not + * provide an offset, sequence number, or final-chunk marker, so the component invoking the callback must define any + * chunked ordering and completion contract with the application. Current EPContext use should prefer a single callback + * invocation per EPContext binary unless chunking semantics are documented by the EP. + * + * The application's implementation can process the data in any way (e.g., encrypt and store, upload to cloud storage, + * or compress) before persisting it. + * + * \param[in] state Opaque pointer holding the user's state. ORT does not own or manage this pointer. The application + * must keep it valid for the duration required by the API that accepted the callback and must provide + * any synchronization required if it can be used concurrently. + * \param[in] name The file name or logical data identifier as a null-terminated UTF-8 string. + * \param[in] buffer The buffer containing data to write. + * \param[in] buffer_num_bytes The size of the buffer in bytes. + * + * \return OrtStatus* Write status. Return nullptr on success. + * On failure, use CreateStatus to provide error info with an appropriate OrtErrorCode + * (e.g., ORT_FAIL); ORT propagates the returned code. ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtWriteNamedBufferFunc)(_In_ void* state, + _In_ const char* name, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes); + +/** \brief Function called to read named binary data. + * + * This callback is currently used for EPContext binary data, but its contract is intentionally generic so future APIs + * can reuse it for other named data payloads. The application reads, processes (e.g., decrypts, decompresses, + * downloads), and returns the requested data. ORT provides an allocator so the application can allocate the output + * buffer directly. The callback is called synchronously by the component that receives it. ORT does not serialize + * invocations made by different EP instances or worker threads. + * + * \param[in] state Opaque pointer holding the user's state. ORT does not own or manage this pointer. The application + * must keep it valid for the duration required by the API that accepted the callback and must provide + * any synchronization required if it can be used concurrently. + * \param[in] name The file name or logical data identifier to read as a null-terminated UTF-8 string. + * \param[in] allocator ORT-provided allocator. The application must use this to allocate the output buffer. + * \param[out] buffer Set by the implementation to the allocated buffer containing the output data. + * \param[out] data_size Set by the implementation to the size of the output data in bytes. + * + * \return OrtStatus* Read status. Return nullptr on success. + * On failure, use CreateStatus to provide error info with an appropriate OrtErrorCode + * (e.g., ORT_FAIL); ORT propagates the returned code. ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* OrtReadNamedBufferFunc)(_In_ void* state, + _In_ const char* name, + _In_ OrtAllocator* allocator, + _Outptr_ void** buffer, + _Out_ size_t* data_size); + // // C function pointer typedefs and name constants // diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc index 57a4e472b6f6d..b1140485f7ff1 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc @@ -282,3 +282,117 @@ ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateSession, _In_ OrtModelPackageComponentContext* context, _In_opt_ const OrtSessionOptions* session_options, _Outptr_ OrtSession** session) + +/** \brief Registers a callback to provide EPContext binary data during session load. + * + * When loading a compiled model with external (non-embedded) EPContext binary data, an execution provider can + * retrieve this callback from OrtEpContextConfig and call it instead of reading the binary data from disk. + * + * Reading happens at session load, so this callback is configured on OrtSessionOptions. The corresponding write + * callback runs only at compile time and is configured on OrtModelCompilationOptions via + * OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc. + * + * The state pointer is stored as-is and is not owned by ORT. It must remain valid while any session or EP created + * from these options may call the callback. If the same state may be used by multiple EPs or threads, the application + * is responsible for synchronization. + * + * \param[in] options The OrtSessionOptions instance. + * \param[in] read_func The OrtReadNamedBufferFunc callback. Pass NULL to clear a previously set callback (any + * previously set state is cleared as well). + * \param[in] state Opaque state passed to read_func. Can be NULL. Ignored when read_func is NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_SessionOptions_SetEpContextDataReadFunc, + _Inout_ OrtSessionOptions* options, _In_opt_ OrtReadNamedBufferFunc read_func, _In_opt_ void* state) + +/** \brief Sets a callback for writing EPContext binary data during compilation. + * + * When EPContext embed mode is disabled, execution providers can retrieve this callback from OrtEpContextConfig and + * call it instead of writing EPContext binary data directly to disk. + * + * This callback may be used together with OrtCompileApi::ModelCompilationOptions_SetEpContextBinaryInformation. The + * binary information still describes the compiled model/output location that EPs may use to generate stable logical + * EPContext data names or as a file-fallback location. If this callback is configured, EPs should call it for + * EPContext binary data instead of writing that data to the fallback file path. + * + * Writing happens only at compile time, so this callback is configured on OrtModelCompilationOptions. The + * corresponding read callback runs at session load and is configured on OrtSessionOptions via + * OrtApi_SessionOptions_SetEpContextDataReadFunc. + * + * The state pointer is stored as-is and is not owned by ORT. It must remain valid for the duration of the compile + * operation that may call the callback. If the same state may be used by multiple EPs or threads, the application is + * responsible for synchronization. + * + * Like OrtApi_SessionOptions_SetEpContextDataReadFunc, passing a NULL write_func clears any previously set callback + * (any previously set state is cleared as well). Calling this multiple times overwrites the previously configured + * callback. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] write_func The OrtWriteNamedBufferFunc callback used to write EPContext bytes. Pass NULL to clear a + * previously set callback (any previously set state is cleared as well). + * \param[in] state Opaque state passed to write_func. Can be NULL. Ignored when write_func is NULL. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_opt_ OrtWriteNamedBufferFunc write_func, _In_opt_ void* state) + +/** \brief Extracts the EPContext configuration (callbacks and state) from an OrtSessionOptions instance. + * + * The EP should call this during CreateEp() while session_options is still valid, and store the returned handle for + * use during Compile(). On success, `*config` is set to a non-NULL handle that must be released with + * OrtEpApi_ReleaseEpContextConfig. On failure, an error status is returned and `*config` is not modified. + * + * The returned handle owns only ORT's copy of callback function pointers and opaque state pointer values. It does not + * own the application-provided state. The application is responsible for keeping callback state valid and + * synchronized while an EP may call callbacks retrieved from this config. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[out] config The extracted OrtEpContextConfig. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtEpApi_SessionOptions_GetEpContextConfig, + _In_ const OrtSessionOptions* session_options, _Outptr_ OrtEpContextConfig** config) + +/** \brief Release an OrtEpContextConfig instance. + * + * \param[in] config The OrtEpContextConfig instance to release. May be NULL. + */ +ORT_EXPERIMENTAL_API(28, void, OrtEpApi_ReleaseEpContextConfig, _Frees_ptr_opt_ OrtEpContextConfig* config) + +/** \brief Get the application-provided EPContext data read callback. + * + * Returns the OrtReadNamedBufferFunc and opaque state pointer registered via + * OrtApi_SessionOptions_SetEpContextDataReadFunc. If no callback was registered, *read_func and *state are set to + * NULL. The EP is responsible for calling the callback when present and for using its own normal read path when no + * callback is present. + * + * \param[in] config The OrtEpContextConfig from OrtEpApi_SessionOptions_GetEpContextConfig. + * \param[out] read_func The registered read callback, or NULL if none was registered. + * \param[out] state Opaque state pointer passed to read_func, or NULL if none was registered. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtEpApi_EpContextConfig_GetEpContextDataReadFunc, + _In_ const OrtEpContextConfig* config, _Out_ OrtReadNamedBufferFunc* read_func, + _Out_ void** state) + +/** \brief Get the application-provided EPContext data write callback. + * + * Returns the OrtWriteNamedBufferFunc and opaque state pointer registered via + * OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc. If no callback was registered, *write_func and + * *state are set to NULL. The EP is responsible for calling the callback when present and for using its own normal + * write path when no callback is present. + * + * \param[in] config The OrtEpContextConfig from OrtEpApi_SessionOptions_GetEpContextConfig. + * \param[out] write_func The registered write callback, or NULL if none was registered. + * \param[out] state Opaque state pointer passed to write_func, or NULL if none was registered. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc, + _In_ const OrtEpContextConfig* config, _Out_ OrtWriteNamedBufferFunc* write_func, + _Out_ void** state) diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h index fbd7ba3659435..d3f3a37eca33d 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_experimental_cxx_api.h @@ -108,5 +108,79 @@ namespace Experimental { // C++ wrapper types or helpers go here in the `Ort::Experimental` namespace. // +// Move-only RAII owner for an OrtEpContextConfig handle, which carries the EPContext read/write callbacks and opaque +// state extracted from an OrtSessionOptions instance. The handle is released via OrtEpApi_ReleaseEpContextConfig when +// the wrapper is destroyed. +// +// Typical EP usage: construct from the session options during CreateEp(), keep the wrapper for the EP's lifetime, and +// query the callbacks via GetReadFunc() / GetWriteFunc(). +class EpContextConfig { + public: + explicit EpContextConfig(std::nullptr_t) noexcept {} + + explicit EpContextConfig(const SessionOptions& session_options) : EpContextConfig{session_options.GetConst()} {} + + // Extracts the EPContext config from `session_options`. Throws Ort::Exception (ORT_NOT_IMPLEMENTED) if the + // experimental functions are not available in this build, or propagates any error from the extraction. + explicit EpContextConfig(ConstSessionOptions session_options) { + const OrtApi* api = &GetApi(); + // Ensure the release function is available before creating a handle, so the handle can always be freed. + Get_OrtEpApi_ReleaseEpContextConfig_SinceV28_FnOrThrow(api); + auto* get_config = Get_OrtEpApi_SessionOptions_GetEpContextConfig_SinceV28_FnOrThrow(api); + ThrowOnError(get_config(static_cast(session_options), &config_)); + } + + EpContextConfig(EpContextConfig&& other) noexcept : config_{other.config_} { other.config_ = nullptr; } + + EpContextConfig& operator=(EpContextConfig&& other) noexcept { + if (this != &other) { + reset(); + config_ = other.config_; + other.config_ = nullptr; + } + return *this; + } + + EpContextConfig(const EpContextConfig&) = delete; + EpContextConfig& operator=(const EpContextConfig&) = delete; + + ~EpContextConfig() { reset(); } + + OrtEpContextConfig* get() const noexcept { return config_; } + explicit operator bool() const noexcept { return config_ != nullptr; } + + // Relinquishes ownership of the handle without releasing it. + OrtEpContextConfig* release() noexcept { + OrtEpContextConfig* released = config_; + config_ = nullptr; + return released; + } + + // Releases any owned handle and resets to empty. + void reset() noexcept { + if (config_ != nullptr) { + if (auto* release_fn = Get_OrtEpApi_ReleaseEpContextConfig_SinceV28_Fn(&GetApi())) { + release_fn(config_); + } + config_ = nullptr; + } + } + + // Returns the configured read callback and opaque state (both nullptr if none was set). Throws on failure. + void GetReadFunc(OrtReadNamedBufferFunc& read_func, void*& state) const { + auto* get_read_func = Get_OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28_FnOrThrow(&GetApi()); + ThrowOnError(get_read_func(config_, &read_func, &state)); + } + + // Returns the configured write callback and opaque state (both nullptr if none was set). Throws on failure. + void GetWriteFunc(OrtWriteNamedBufferFunc& write_func, void*& state) const { + auto* get_write_func = Get_OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28_FnOrThrow(&GetApi()); + ThrowOnError(get_write_func(config_, &write_func, &state)); + } + + private: + OrtEpContextConfig* config_ = nullptr; +}; + } // namespace Experimental } // namespace Ort diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc index 99fa21b1e4be8..b53a99084152f 100644 --- a/onnxruntime/core/framework/ep_context_options.cc +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -56,6 +56,10 @@ const BufferWriteFuncHolder* ModelGenOptions::TryGetOutputModelWriteFunc() const return std::get_if(&output_model_location); } +const EpContextDataWriteFuncHolder* ModelGenOptions::TryGetEpContextDataWriteFunc() const { + return ep_context_data_write_func.write_func != nullptr ? &ep_context_data_write_func : nullptr; +} + bool ModelGenOptions::AreInitializersEmbeddedInOutputModel() const { return std::holds_alternative(initializers_location); } diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index 6643516bfb4c3..344ec5f7f6b58 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -7,6 +7,9 @@ #include #include "core/framework/allocator.h" #include "core/framework/config_options.h" +// Needed for OrtWriteNamedBufferFunc (used by EpContextDataWriteFuncHolder below). This include can be removed +// once the experimental EPContext data callback APIs are promoted to the stable C API. +#include "core/session/onnxruntime_experimental_c_api.h" namespace onnxruntime { namespace epctx { @@ -27,6 +30,14 @@ struct BufferWriteFuncHolder { void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. }; +/// +/// Holds the opaque state and write function that EPs use to write EPContext binary data. +/// +struct EpContextDataWriteFuncHolder { + OrtWriteNamedBufferFunc write_func = nullptr; + void* state = nullptr; +}; + /// /// Holds path and size threshold used to write out initializers to an external file. /// @@ -84,10 +95,13 @@ struct ModelGenOptions { InitializerHandler> // Custom function called for every initializer to determine location. initializers_location = std::monostate{}; + EpContextDataWriteFuncHolder ep_context_data_write_func = {}; + bool HasOutputModelLocation() const; const std::filesystem::path* TryGetOutputModelPath() const; const BufferHolder* TryGetOutputModelBuffer() const; const BufferWriteFuncHolder* TryGetOutputModelWriteFunc() const; + const EpContextDataWriteFuncHolder* TryGetEpContextDataWriteFunc() const; bool AreInitializersEmbeddedInOutputModel() const; const ExternalInitializerFileInfo* TryGetExternalInitializerFileInfo() const; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index b328fc916f885..ddd21074afe8a 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -16,6 +16,9 @@ #include "core/framework/ep_context_options.h" #include "core/framework/ort_value.h" #include "core/session/onnxruntime_c_api.h" +// Needed for OrtReadNamedBufferFunc, the type of the EPContext data read callback stored in this struct. This include +// can be removed once the experimental EPContext data callback APIs are promoted to the stable C API. +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/optimizer/graph_transformer_level.h" #include "core/util/thread_utils.h" @@ -226,6 +229,9 @@ struct SessionOptions { bool has_explicit_ep_context_gen_options = false; epctx::ModelGenOptions ep_context_gen_options = {}; epctx::ModelGenOptions GetEpContextGenerationOptions() const; + + OrtReadNamedBufferFunc ep_context_data_read_func = nullptr; + void* ep_context_data_read_state = nullptr; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/session/experimental_c_api.cc b/onnxruntime/core/session/experimental_c_api.cc index 458c47bcb58cd..5b1a8460d1e1f 100644 --- a/onnxruntime/core/session/experimental_c_api.cc +++ b/onnxruntime/core/session/experimental_c_api.cc @@ -6,12 +6,30 @@ #include #include +#include +#include "core/common/common.h" #include "core/framework/error_code_helper.h" +#include "core/framework/ep_context_options.h" +#include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_experimental_c_api.h" #include "core/session/ort_apis.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/model_compilation_options.h" +#endif // !defined(ORT_MINIMAL_BUILD) + +// Backing definition of the OrtEpContextConfig handle used by the experimental OrtEpApi_* EPContext data functions. +// Holds copies of the application's EPContext read/write callbacks and opaque state extracted from an +// OrtSessionOptions instance. +struct OrtEpContextConfig { + OrtWriteNamedBufferFunc write_func = nullptr; + void* write_state = nullptr; + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; +}; + // --------------------------------------------------------------------------- // Experimental function implementations // --------------------------------------------------------------------------- @@ -40,6 +58,94 @@ ORT_API_STATUS_IMPL(OrtApi_ExperimentalApiTest_SinceV28, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28, _Inout_ OrtSessionOptions* options, + _In_opt_ OrtReadNamedBufferFunc read_func, _In_opt_ void* state) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(options == nullptr, ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL"); + + // Passing a null read_func clears any previously set callback. Clear the state too so a stale state pointer is + // never paired with a missing callback. + options->value.ep_context_data_read_func = read_func; + options->value.ep_context_data_read_state = read_func != nullptr ? state : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_opt_ OrtWriteNamedBufferFunc write_func, _In_opt_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF(ort_model_compile_options == nullptr, ORT_INVALID_ARGUMENT, "OrtModelCompilationOptions is NULL"); + + // A null write_func clears any previously set callback (symmetric with OrtApi_SessionOptions_SetEpContextDataReadFunc + // and consistent with calling this multiple times to overwrite the callback). + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + model_compile_options->SetEpContextDataWriteFunc(write_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtEpApi_SessionOptions_GetEpContextConfig_SinceV28, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtEpContextConfig** config) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(session_options == nullptr, ORT_INVALID_ARGUMENT, "OrtSessionOptions is NULL"); + ORT_API_RETURN_IF(config == nullptr, ORT_INVALID_ARGUMENT, "Output OrtEpContextConfig is NULL"); + + auto ep_context_config = std::make_unique(); + if (const auto* write_config = session_options->value.ep_context_gen_options.TryGetEpContextDataWriteFunc()) { + ep_context_config->write_func = write_config->write_func; + ep_context_config->write_state = write_config->state; + } + ep_context_config->read_func = session_options->value.ep_context_data_read_func; + ep_context_config->read_state = session_options->value.ep_context_data_read_state; + + *config = ep_context_config.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtEpApi_ReleaseEpContextConfig_SinceV28, _Frees_ptr_opt_ OrtEpContextConfig* config) { + delete config; +} + +ORT_API_STATUS_IMPL(OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28, + _In_ const OrtEpContextConfig* config, + _Out_ OrtReadNamedBufferFunc* read_func, + _Out_ void** state) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(config == nullptr, ORT_INVALID_ARGUMENT, "OrtEpContextConfig is NULL"); + ORT_API_RETURN_IF(read_func == nullptr, ORT_INVALID_ARGUMENT, "Output read_func is NULL"); + ORT_API_RETURN_IF(state == nullptr, ORT_INVALID_ARGUMENT, "Output state is NULL"); + + *read_func = config->read_func; + *state = config->read_func != nullptr ? config->read_state : nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28, + _In_ const OrtEpContextConfig* config, + _Out_ OrtWriteNamedBufferFunc* write_func, + _Out_ void** state) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(config == nullptr, ORT_INVALID_ARGUMENT, "OrtEpContextConfig is NULL"); + ORT_API_RETURN_IF(write_func == nullptr, ORT_INVALID_ARGUMENT, "Output write_func is NULL"); + ORT_API_RETURN_IF(state == nullptr, ORT_INVALID_ARGUMENT, "Output state is NULL"); + + *write_func = config->write_func; + *state = config->write_func != nullptr ? config->write_state : nullptr; + return nullptr; + API_IMPL_END +} + } // namespace OrtExperimentalApis // --------------------------------------------------------------------------- diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index a17802bdd7573..9f6d1f9f1a9bc 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -133,6 +133,15 @@ void ModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( }; } +void ModelCompilationOptions::SetEpContextDataWriteFunc(OrtWriteNamedBufferFunc write_func, void* state) { + // A null write_func clears any previously set callback. Clear the state too so a stale state pointer is never + // paired with a missing callback. + session_options_.value.ep_context_gen_options.ep_context_data_write_func = epctx::EpContextDataWriteFuncHolder{ + write_func, + write_func != nullptr ? state : nullptr, + }; +} + Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::filesystem::path& output_directory, const std::filesystem::path& model_name) { if (output_directory.empty() || model_name.empty()) { diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 47529e794677e..a15af565c4d54 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -13,6 +13,7 @@ #include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { @@ -97,6 +98,13 @@ class ModelCompilationOptions { void SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, void* state); + /// + /// Sets a user-provided function to handle EPContext binary data writes. + /// + /// The user-provided OrtWriteNamedBufferFunc callback used to write EPContext data. + /// The user's state. + void SetEpContextDataWriteFunc(OrtWriteNamedBufferFunc write_func, void* state); + /// /// Sets information relate to EP context binary file. /// EP use this information to decide the location and context binary file name. diff --git a/onnxruntime/test/autoep/ep_context_data_callbacks.h b/onnxruntime/test/autoep/ep_context_data_callbacks.h new file mode 100644 index 0000000000000..d25c7a5c571ee --- /dev/null +++ b/onnxruntime/test/autoep/ep_context_data_callbacks.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { +namespace test { + +// Shared EPContext read/write callback test doubles, used by both the EpContextDataUtils unit tests +// (ep_context_data_utils_test.cc) and the PluginEp end-to-end EPContext tests (test_execution.cc). +struct EpContextDataCallbackState { + bool write_called = false; + bool read_called = false; + std::string write_file_name; + std::string read_file_name; + std::vector payload; +}; + +inline OrtStatus* ORT_API_CALL StoreEpContextDataCallback(void* state, const char* file_name, const void* buffer, + size_t buffer_size) { + auto* callback_state = static_cast(state); + callback_state->write_called = true; + callback_state->write_file_name = file_name; + callback_state->payload.clear(); + if (buffer_size != 0) { + if (buffer == nullptr) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, + "StoreEpContextDataCallback received a null buffer for non-empty data"); + } + callback_state->payload.assign(static_cast(buffer), static_cast(buffer) + buffer_size); + } + return nullptr; +} + +inline OrtStatus* ORT_API_CALL LoadEpContextDataCallback(void* state, const char* file_name, OrtAllocator* allocator, + void** buffer, size_t* data_size) { + auto* callback_state = static_cast(state); + callback_state->read_called = true; + callback_state->read_file_name = file_name; + + *buffer = nullptr; + *data_size = callback_state->payload.size(); + if (callback_state->payload.empty()) { + return nullptr; + } + + OrtStatus* status = Ort::GetApi().AllocatorAlloc(allocator, callback_state->payload.size(), buffer); + if (status != nullptr) { + return status; + } + + std::copy(callback_state->payload.begin(), callback_state->payload.end(), static_cast(*buffer)); + return nullptr; +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/autoep/ep_context_data_utils_test.cc b/onnxruntime/test/autoep/ep_context_data_utils_test.cc new file mode 100644 index 0000000000000..8425280e4845b --- /dev/null +++ b/onnxruntime/test/autoep/ep_context_data_utils_test.cc @@ -0,0 +1,327 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Unit tests for the sample-only EPContext data helpers in +// onnxruntime/test/autoep/library/ep_context_data_utils.h. + +#include +#include +#include +#include + +#include +#include + +#include "core/graph/model_editor_api_types.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_experimental_cxx_api.h" + +#include "test/autoep/ep_context_data_callbacks.h" +#include "test/autoep/library/ep_context_data_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { + +namespace { + +OrtStatus* ORT_API_CALL LoadInvalidEpContextDataCallback(void* state, const char* file_name, + OrtAllocator* /*allocator*/, void** buffer, + size_t* data_size) { + auto* callback_state = static_cast(state); + callback_state->read_called = true; + callback_state->read_file_name = file_name; + + *buffer = nullptr; + *data_size = 1; + return nullptr; +} + +void ExpectOrtStatusError(OrtStatus* status_ptr, OrtErrorCode expected_code, std::string_view expected_message) { + Ort::Status status{status_ptr}; + ASSERT_NE(status_ptr, nullptr) << "Expected a failure status, but the API returned nullptr (OK)."; + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), expected_code); + EXPECT_THAT(std::string{status.GetErrorMessage()}, ::testing::HasSubstr(std::string{expected_message})); +} + +std::filesystem::path PrepareTempTestDir(std::string_view name) { + std::filesystem::path test_dir = std::string{name}; + std::filesystem::remove_all(test_dir); + std::filesystem::create_directories(test_dir); + return test_dir; +} + +} // namespace + +TEST(OrtEpLibrary, EpContextDataUtils_PathHelpersRoundTrip) { + const auto& api = Ort::GetApi(); + const std::string file_name = "context_data.bin"; + +#ifdef _WIN32 + std::wstring wide_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8ToWideString(api, file_name, wide_file_name)); + ASSERT_FALSE(wide_file_name.empty()); + std::string round_tripped_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WideToUtf8String(api, wide_file_name, round_tripped_file_name)); + EXPECT_EQ(round_tripped_file_name, file_name); + + const std::string invalid_utf8(1, static_cast(0xff)); + std::wstring invalid_wide; + ExpectOrtStatusError(ep_context_data_utils::Utf8ToWideString(api, invalid_utf8, invalid_wide), + ORT_INVALID_ARGUMENT, "not valid UTF-8"); + EXPECT_TRUE(invalid_wide.empty()); +#endif + + std::filesystem::path file_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8Path(api, file_name.c_str(), file_path)); + ASSERT_FALSE(file_path.empty()); + std::string round_tripped_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, file_path, round_tripped_path)); + EXPECT_EQ(round_tripped_path, file_name); + + std::filesystem::path empty_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8Path(api, nullptr, empty_path)); + EXPECT_TRUE(empty_path.empty()); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::Utf8Path(api, "", empty_path)); + EXPECT_TRUE(empty_path.empty()); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ResolvePathAndInvalidArguments) { + const auto& api = Ort::GetApi(); + std::filesystem::path data_path; + + data_path = "stale.ctx"; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, nullptr, nullptr, data_path), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + EXPECT_TRUE(data_path.empty()); + + data_path = "stale.ctx"; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "", nullptr, data_path), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + EXPECT_TRUE(data_path.empty()); + + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ResolveEpContextDataPath(api, "relative.ctx", nullptr, data_path)); + std::string resolved_data_path; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, data_path, resolved_data_path)); + EXPECT_EQ(resolved_data_path, "relative.ctx"); + + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataToFile(api, "unused.ctx", nullptr, nullptr, 1), + ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback(api, nullptr, "unused.ctx", nullptr, + nullptr, 1), + ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + + std::vector data; + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataWithFileFallback(api, nullptr, "", nullptr, data), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback(api, nullptr, "", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "logical_context_data.bin", "", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data fallback file name must not be empty"); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ResolvePathRejectsUnsafeNames) { + const auto& api = Ort::GetApi(); + std::filesystem::path data_path; + + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "../escape.ctx", nullptr, data_path), + ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + EXPECT_TRUE(data_path.empty()); + +#ifdef _WIN32 + const char* absolute_file_name = "C:\\temp\\escape.ctx"; + const char* drive_relative_file_name = "C:escape.ctx"; + const char* root_relative_file_name = "\\escape.ctx"; +#else + const char* absolute_file_name = "/tmp/escape.ctx"; +#endif + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ResolveEpContextDataPath(api, absolute_file_name, nullptr, data_path)); + EXPECT_TRUE(data_path.is_absolute()); + +#ifdef _WIN32 + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, drive_relative_file_name, "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, root_relative_file_name, "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); +#endif + + std::vector data; + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataFromFile(api, "../escape.ctx", nullptr, data), + ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, absolute_file_name, "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + + ModelEditorGraph empty_model_path_graph; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "../escape.ctx", + empty_model_path_graph.ToExternal(), data_path), + ORT_INVALID_ARGUMENT, "requires a model path"); + + // A model-derived name that designates a directory ("." or a trailing separator with an empty filename) is + // rejected up front, rather than resolving to a directory and failing later with a confusing I/O error. + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, ".", empty_model_path_graph.ToExternal(), + data_path), + ORT_INVALID_ARGUMENT, "must refer to a file"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, ".", "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "must refer to a file"); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "sub/", "unused.ctx", nullptr, nullptr, 0), + ORT_INVALID_ARGUMENT, "must refer to a file"); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ResolvePathRejectsSymlinkEscape) { + const auto& api = Ort::GetApi(); + const std::filesystem::path test_dir = PrepareTempTestDir("ort_ep_context_data_utils_symlink_escape_test"); + auto cleanup = gsl::finally([&]() { std::filesystem::remove_all(test_dir); }); + + const std::filesystem::path model_dir = test_dir / "model_dir"; + const std::filesystem::path outside_dir = test_dir / "outside_dir"; + ASSERT_TRUE(std::filesystem::create_directories(model_dir)); + ASSERT_TRUE(std::filesystem::create_directories(outside_dir)); + + const std::filesystem::path symlink_path = model_dir / "linked_outside"; + // Relative symlink targets are resolved by the OS relative to the link's own directory, not the test's working + // directory. Point to the sibling outside_dir using a link-relative target; using the test_dir-relative + // `outside_dir` path here would create a dangling link under model_dir, and weakly_canonical() would not traverse it. + const std::filesystem::path symlink_target = std::filesystem::path{".."} / outside_dir.filename(); + std::error_code symlink_error; + std::filesystem::create_directory_symlink(symlink_target, symlink_path, symlink_error); + if (symlink_error) { + GTEST_SKIP() << "Unable to create directory symlink for containment test: " << symlink_error.message(); + } + + ModelEditorGraph graph; + graph.model_path = model_dir / "model.onnx"; + + std::filesystem::path data_path; + ExpectOrtStatusError(ep_context_data_utils::ResolveEpContextDataPath(api, "linked_outside/escape.ctx", + graph.ToExternal(), data_path), + ORT_INVALID_ARGUMENT, "resolve to a path within the model directory"); + EXPECT_TRUE(data_path.empty()); +} + +TEST(OrtEpLibrary, EpContextDataUtils_FileFallbackReadsAndWrites) { + const auto& api = Ort::GetApi(); + const std::filesystem::path test_dir = PrepareTempTestDir("ort_ep_context_data_utils_file_fallback_test"); + auto cleanup = gsl::finally([&]() { std::filesystem::remove_all(test_dir); }); + + const std::string payload = "file fallback payload"; + const std::filesystem::path data_path = test_dir / "context_data.bin"; + std::string data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, data_path, data_file_name)); + + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataToFile(api, data_file_name.c_str(), nullptr, + payload.data(), payload.size())); + + std::vector data; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataFromFile(api, data_file_name.c_str(), nullptr, data)); + EXPECT_EQ(std::string(data.begin(), data.end()), payload); + + const std::filesystem::path wrapper_data_path = test_dir / "wrapper_context_data.bin"; + std::string wrapper_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, wrapper_data_path, wrapper_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, wrapper_data_file_name.c_str(), nullptr, payload.data(), payload.size())); + + data.clear(); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, nullptr, wrapper_data_file_name.c_str(), nullptr, data)); + EXPECT_EQ(std::string(data.begin(), data.end()), payload); + + const std::filesystem::path fallback_data_path = test_dir / "fallback_context_data.bin"; + std::string fallback_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, fallback_data_path, fallback_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "logical_context_data.bin", fallback_data_file_name.c_str(), nullptr, payload.data(), + payload.size())); + + const std::filesystem::path unsafe_logical_fallback_path = test_dir / "unsafe_logical_context_data.bin"; + std::string unsafe_logical_fallback_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, unsafe_logical_fallback_path, + unsafe_logical_fallback_file_name)); + ExpectOrtStatusError(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, "../logical_context_data.bin", + unsafe_logical_fallback_file_name.c_str(), nullptr, payload.data(), payload.size()), + ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + + data.clear(); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataFromFile(api, fallback_data_file_name.c_str(), nullptr, + data)); + EXPECT_EQ(std::string(data.begin(), data.end()), payload); + + const std::filesystem::path empty_data_path = test_dir / "empty_context_data.bin"; + std::string empty_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, empty_data_path, empty_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, nullptr, empty_data_file_name.c_str(), nullptr, nullptr, 0)); + + data.assign({'s', 't', 'a', 'l', 'e'}); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, nullptr, empty_data_file_name.c_str(), nullptr, data)); + EXPECT_TRUE(data.empty()); + + const std::filesystem::path missing_data_path = test_dir / "missing_context_data.bin"; + std::string missing_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(api, missing_data_path, missing_data_file_name)); + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataFromFile(api, missing_data_file_name.c_str(), nullptr, + data), + ORT_FAIL, "Failed to open EPContext data file for read"); +} + +TEST(OrtEpLibrary, EpContextDataUtils_CallbackFallbackUsesCallbacks) { + const auto& api = Ort::GetApi(); + + EpContextDataCallbackState read_callback_state; + read_callback_state.payload = {'c', 'a', 'l', 'l', 'b', 'a', 'c', 'k'}; + EpContextDataCallbackState write_callback_state; + + std::vector data; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, LoadEpContextDataCallback, &read_callback_state, "callback_context.bin", nullptr, data)); + ASSERT_TRUE(read_callback_state.read_called); + EXPECT_EQ(read_callback_state.read_file_name, "callback_context.bin"); + EXPECT_EQ(data, read_callback_state.payload); + + const std::string payload = "callback write payload"; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, StoreEpContextDataCallback, &write_callback_state, "callback_write_context.bin", + "callback_write_context.bin", nullptr, payload.data(), payload.size())); + ASSERT_TRUE(write_callback_state.write_called); + EXPECT_EQ(write_callback_state.write_file_name, "callback_write_context.bin"); + EXPECT_EQ(std::string(write_callback_state.payload.begin(), write_callback_state.payload.end()), payload); + + write_callback_state = {}; + const std::string payload_with_unused_fallback = "callback write payload with unused fallback"; + // With a callback present the file fallback is never used, so the empty fallback name is accepted (not validated). + ASSERT_ORTSTATUS_OK(ep_context_data_utils::WriteEpContextDataWithFileFallback( + api, StoreEpContextDataCallback, &write_callback_state, + "callback_write_context_unused_fallback.bin", "", nullptr, + payload_with_unused_fallback.data(), payload_with_unused_fallback.size())); + ASSERT_TRUE(write_callback_state.write_called); + EXPECT_EQ(write_callback_state.write_file_name, "callback_write_context_unused_fallback.bin"); + EXPECT_EQ(std::string(write_callback_state.payload.begin(), write_callback_state.payload.end()), + payload_with_unused_fallback); +} + +TEST(OrtEpLibrary, EpContextDataUtils_ReadCallbackRejectsNullBufferForNonEmptyPayload) { + const auto& api = Ort::GetApi(); + + EpContextDataCallbackState read_callback_state; + + std::vector data; + ExpectOrtStatusError(ep_context_data_utils::ReadEpContextDataWithFileFallback( + api, LoadInvalidEpContextDataCallback, &read_callback_state, + "invalid_callback_context.bin", nullptr, data), + ORT_FAIL, "OrtReadNamedBufferFunc returned a null buffer for non-empty EPContext data"); + ASSERT_TRUE(read_callback_state.read_called); + EXPECT_EQ(read_callback_state.read_file_name, "invalid_callback_context.bin"); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/autoep/library/ep_context_data_utils.h b/onnxruntime/test/autoep/library/ep_context_data_utils.h new file mode 100644 index 0000000000000..a3f9a377ee92f --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_context_data_utils.h @@ -0,0 +1,501 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +// Define NOMINMAX (and WIN32_LEAN_AND_MEAN) before so the min/max macros it would otherwise pull in do +// not clobber std::numeric_limits<...>::max() and std::min/std::max used in this header. +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +#include "plugin_ep_utils.h" +#include "onnxruntime_experimental_cxx_api.h" + +// Sample-only EPContext data helpers shared by the example plugin EP and its tests. These are intentionally outside +// the ORT C and EP ABI and are provided as a reference for EP authors that need to handle external (non-embedded) +// EPContext binary data. +// +// The intended entry points for EP implementers are the ReadEpContextDataWithFileFallback / +// WriteEpContextDataWithFileFallback overloads: they prefer an application-supplied OrtReadNamedBufferFunc / +// OrtWriteNamedBufferFunc (carried by OrtEpContextConfig) and fall back to file I/O when no callback is configured. +// The other functions are lower-level building blocks. Production EPs should additionally apply their own sandboxing, +// size limits, and path policies; see the per-function notes on how untrusted, model-derived names are treated. +namespace ep_context_data_utils { + +#ifdef _WIN32 +inline std::string WindowsLastErrorMessage(std::string_view message, DWORD error_code) { + return std::string{message} + " GetLastError=" + std::to_string(error_code); +} + +// Converts a UTF-8 string to a wide string. Reports conversion failures (e.g., invalid UTF-8) via OrtStatus* instead +// of silently returning an empty string. An empty input yields an empty output and a success status. +inline OrtStatus* Utf8ToWideString(const OrtApi& api, std::string_view value, std::wstring& wide_value) { + wide_value.clear(); + if (value.empty()) { + return nullptr; + } + if (value.size() > static_cast(std::numeric_limits::max())) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is too long to convert"); + } + + const int wide_length = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, value.data(), + static_cast(value.size()), nullptr, 0); + if (wide_length <= 0) { + const std::string message = WindowsLastErrorMessage( + "EPContext data file name is not valid UTF-8 or could not be converted to a wide string.", GetLastError()); + return api.CreateStatus(ORT_INVALID_ARGUMENT, message.c_str()); + } + + wide_value.resize(static_cast(wide_length)); + const int converted = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, value.data(), + static_cast(value.size()), wide_value.data(), wide_length); + if (converted != wide_length) { + wide_value.clear(); + const std::string message = WindowsLastErrorMessage("Failed to convert EPContext data file name to a wide string.", + GetLastError()); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + return nullptr; +} + +// Converts a wide string to UTF-8. Reports conversion failures via OrtStatus* instead of silently returning an empty +// string. An empty input yields an empty output and a success status. +inline OrtStatus* WideToUtf8String(const OrtApi& api, std::wstring_view value, std::string& utf8_value) { + utf8_value.clear(); + if (value.empty()) { + return nullptr; + } + if (value.size() > static_cast(std::numeric_limits::max())) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is too long to convert"); + } + + const int utf8_length = WideCharToMultiByte(CP_UTF8, 0, value.data(), static_cast(value.size()), + nullptr, 0, nullptr, nullptr); + if (utf8_length <= 0) { + const std::string message = WindowsLastErrorMessage( + "EPContext data file name could not be converted to UTF-8.", GetLastError()); + return api.CreateStatus(ORT_INVALID_ARGUMENT, message.c_str()); + } + + utf8_value.resize(static_cast(utf8_length)); + const int converted = WideCharToMultiByte(CP_UTF8, 0, value.data(), static_cast(value.size()), + utf8_value.data(), utf8_length, nullptr, nullptr); + if (converted != utf8_length) { + utf8_value.clear(); + const std::string message = WindowsLastErrorMessage("Failed to convert EPContext data file name to UTF-8.", + GetLastError()); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + return nullptr; +} +#endif + +// Converts a UTF-8 path to a std::filesystem::path. A null or empty input yields an empty path and a success status; +// conversion failures are reported via OrtStatus*. +inline OrtStatus* Utf8Path(const OrtApi& api, const char* path, std::filesystem::path& out_path) { + out_path.clear(); + if (path == nullptr || path[0] == '\0') { + return nullptr; + } + +#ifdef _WIN32 + std::wstring wide_path; + RETURN_IF_ERROR(Utf8ToWideString(api, path, wide_path)); + out_path = std::filesystem::path{wide_path}; +#else + (void)api; + out_path = std::filesystem::path{path}; +#endif + return nullptr; +} + +inline OrtStatus* PathToUtf8String(const OrtApi& api, const std::filesystem::path& path, std::string& utf8_path) { + utf8_path.clear(); +#ifdef _WIN32 + RETURN_IF_ERROR(WideToUtf8String(api, path.wstring(), utf8_path)); +#else + (void)api; + utf8_path = path.string(); +#endif + return nullptr; +} + +inline std::string PathToUtf8StringForMessage(const std::filesystem::path& path) { + std::string utf8_path; + Ort::Status status{PathToUtf8String(Ort::GetApi(), path, utf8_path)}; + return status.IsOK() ? utf8_path : std::string{""}; +} + +// Lexical check for a ".." component. This is a coarse guard used when there is no filesystem base directory to +// contain against: logical callback-namespace names and trusted (graph == nullptr) physical paths. It is NOT a +// containment mechanism: it does not resolve symlinks and it rejects benign cases such as "a/b/c/../file.txt". +// Filesystem containment against a model directory is done by IsResolvedPathWithinBase() below, which the untrusted +// (model-relative) resolution path uses. +inline bool ContainsPathTraversal(const std::filesystem::path& path) { + const std::filesystem::path parent_dir{".."}; + for (const auto& component : path) { + if (component == parent_dir) { + return true; + } + } + return false; +} + +inline bool HasAbsoluteOrRootedPath(const std::filesystem::path& path) { + return path.is_absolute() || path.has_root_name() || path.has_root_directory(); +} + +// Returns true if the final component of `path` is empty (e.g., a trailing separator like "sub/") or is the +// current-directory entry ".", i.e. the name designates a directory rather than a file (".." is handled separately by +// ContainsPathTraversal()). Such a name resolves to a directory and would only surface later as a confusing file I/O +// failure, so model-derived names like these are rejected up front. +inline bool IsDirectoryOrEmptyName(const std::filesystem::path& path) { + const std::filesystem::path leaf = path.filename(); + return leaf.empty() || leaf == std::filesystem::path{"."}; +} + +// Returns true if `candidate_full` (a base-relative name already combined with `base`) resolves to a location inside +// `base`. Both are normalized with std::filesystem::weakly_canonical, which resolves "." / ".." and any symlinks in +// the existing portion of the path, so a name that escapes `base` directly or through a symlink is rejected. On +// success the canonical resolved path is written to `resolved`. +inline bool IsResolvedPathWithinBase(const std::filesystem::path& base, const std::filesystem::path& candidate_full, + std::filesystem::path& resolved) { + std::error_code ec; + const std::filesystem::path base_for_canon = base.empty() ? std::filesystem::path{"."} : base; + const std::filesystem::path canonical_base = std::filesystem::weakly_canonical(base_for_canon, ec); + if (ec) { + return false; + } + std::filesystem::path candidate_resolved = std::filesystem::weakly_canonical(candidate_full, ec); + if (ec) { + return false; + } + const std::filesystem::path relative = candidate_resolved.lexically_relative(canonical_base); + if (relative.empty() || *relative.begin() == std::filesystem::path{".."}) { + return false; + } + + resolved = std::move(candidate_resolved); + return true; +} + +inline OrtStatus* ValidateEpContextDataName(const OrtApi& api, const char* file_name, + std::filesystem::path& data_name) { + data_name.clear(); + + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + std::filesystem::path candidate_path; + RETURN_IF_ERROR(Utf8Path(api, file_name, candidate_path)); + if (candidate_path.empty()) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is not a valid path"); + } + + if (HasAbsoluteOrRootedPath(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + } + + if (ContainsPathTraversal(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + } + + if (IsDirectoryOrEmptyName(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must refer to a file, not a directory"); + } + + data_name = candidate_path; + return nullptr; +} + +// Resolves `file_name` to a filesystem path for reading or writing EPContext data (used by both the read path and +// the write-fallback path). +// +// When `graph` is null the caller is trusted and owns the path: `file_name` is returned as-is and may be absolute (a +// lexical ".." is still rejected as a coarse guard). When `graph` is non-null, `file_name` originates from the +// untrusted EPContext model "ep_cache_context" attribute: the graph must have a model path, the name must be +// relative, and after combining it with the model's directory the result must stay within that directory. Symlinks and +// ".." are resolved (via weakly_canonical), so a name that escapes the model directory - including through a symlink - +// is rejected. +// This helper only decides whether a model-derived file name resolves inside the model directory. Production EPs +// should still choose an application-approved storage root (sandbox), reject special files/locations as appropriate, +// and cap the number of bytes they will read or write for a single EPContext payload. +inline OrtStatus* ResolveEpContextDataPath(const OrtApi& api, const char* file_name, const OrtGraph* graph, + std::filesystem::path& data_path) { + data_path.clear(); + + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + std::filesystem::path candidate_path; + RETURN_IF_ERROR(Utf8Path(api, file_name, candidate_path)); + if (candidate_path.empty()) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name is not a valid path"); + } + + // Trusted direct callers (graph == nullptr) own the path and may pass an absolute physical path. + if (graph == nullptr) { + if (ContainsPathTraversal(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not contain path traversal"); + } + data_path = candidate_path; + return nullptr; + } + + // Untrusted (model-derived) name: must be relative and must resolve within the model directory. + if (HasAbsoluteOrRootedPath(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be absolute or rooted"); + } + + if (IsDirectoryOrEmptyName(candidate_path)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must refer to a file, not a directory"); + } + + const ORTCHAR_T* model_path = nullptr; + RETURN_IF_ERROR(api.Graph_GetModelPath(graph, &model_path)); + if (model_path == nullptr || model_path[0] == 0) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, + "EPContext data file fallback requires a model path to resolve relative names"); + } + + const std::filesystem::path base_dir = std::filesystem::path{model_path}.parent_path(); + std::filesystem::path resolved; + if (!IsResolvedPathWithinBase(base_dir, base_dir / candidate_path, resolved)) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, + "EPContext data file name must resolve to a path within the model directory"); + } + + data_path = resolved; + return nullptr; +} + +inline OrtStatus* WriteEpContextDataToResolvedFile(const OrtApi& api, const std::filesystem::path& data_path, + const void* buffer, size_t buffer_size) { + std::ofstream output_stream(data_path, std::ios::binary); + if (!output_stream) { + const std::string message = "Failed to open EPContext data file for write: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + + if (buffer_size != 0) { + if (buffer_size > static_cast(std::numeric_limits::max())) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data buffer is too large to write"); + } + + output_stream.write(static_cast(buffer), static_cast(buffer_size)); + if (!output_stream) { + const std::string message = "Failed to write EPContext data file: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + } + + return nullptr; +} + +inline OrtStatus* ReadEpContextDataFromFile(const OrtApi& api, const char* file_name, const OrtGraph* graph, + std::vector& data) { + data.clear(); + + std::filesystem::path data_path; + RETURN_IF_ERROR(ResolveEpContextDataPath(api, file_name, graph, data_path)); + + std::ifstream input_stream(data_path, std::ios::binary); + if (!input_stream) { + const std::string message = "Failed to open EPContext data file for read: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + + data.assign(std::istreambuf_iterator{input_stream}, std::istreambuf_iterator{}); + if (!input_stream) { + const std::string message = "Failed to read EPContext data file: " + + PathToUtf8StringForMessage(data_path); + return api.CreateStatus(ORT_FAIL, message.c_str()); + } + + return nullptr; +} + +inline OrtStatus* WriteEpContextDataToFile(const OrtApi& api, const char* file_name, const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + if (buffer == nullptr && buffer_size != 0) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + } + + std::filesystem::path data_path; + RETURN_IF_ERROR(ResolveEpContextDataPath(api, file_name, graph, data_path)); + return WriteEpContextDataToResolvedFile(api, data_path, buffer, buffer_size); +} + +// Low-level overload that takes the read callback and its opaque state directly. Production EPs should use the +// overload below that takes an OrtEpContextConfig; this overload exists so unit tests can inject a callback without +// constructing an OrtEpContextConfig. When `read_func` is null the data is read from a file. +inline OrtStatus* ReadEpContextDataWithFileFallback( + const OrtApi& api, + OrtReadNamedBufferFunc read_func, void* read_state, + const char* file_name, const OrtGraph* graph, + std::vector& data) { + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + if (read_func == nullptr) { + return ReadEpContextDataFromFile(api, file_name, graph, data); + } + + // Use the C allocator API (not Ort::AllocatorWithDefaultOptions, whose constructor throws) so this OrtStatus*-based + // helper stays exception-free. The default allocator is owned by ORT and must not be released here. + OrtAllocator* allocator = nullptr; + RETURN_IF_ERROR(api.GetAllocatorWithDefaultOptions(&allocator)); + void* ep_context_data = nullptr; + size_t ep_context_data_size = 0; + OrtStatus* status = read_func(read_state, file_name, allocator, &ep_context_data, &ep_context_data_size); + auto buffer_deleter = [&api, allocator](void* buffer_to_free) { + if (buffer_to_free != nullptr) { + // Best-effort free during cleanup; release any returned status without throwing. + Ort::Status free_status{api.AllocatorFree(allocator, buffer_to_free)}; + static_cast(free_status); + } + }; + std::unique_ptr ep_context_data_guard(ep_context_data, buffer_deleter); + + if (status != nullptr) { + return status; + } + + if (ep_context_data_size != 0 && ep_context_data == nullptr) { + return api.CreateStatus( + ORT_FAIL, "OrtReadNamedBufferFunc returned a null buffer for non-empty EPContext data"); + } + + data.clear(); + if (ep_context_data != nullptr) { + const char* ep_context_data_begin = static_cast(ep_context_data); + data.assign(ep_context_data_begin, ep_context_data_begin + ep_context_data_size); + } + + return nullptr; +} + +// Reads EPContext binary data named `file_name`. If the session configured an OrtReadNamedBufferFunc (carried by +// `ep_context_config`), it is used; otherwise the data is read from a file. When `graph` is non-null it is the +// EPContext model graph: untrusted absolute/rooted/traversal names are rejected and relative names are resolved +// against the model directory. Pass `graph == nullptr` only for trusted callers supplying a physical path. `data` is +// cleared first and receives the bytes on success. +inline OrtStatus* ReadEpContextDataWithFileFallback( + const OrtApi& api, + const OrtEpContextConfig* ep_context_config, + const char* file_name, const OrtGraph* graph, + std::vector& data) { + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; + if (ep_context_config != nullptr) { + auto get_read_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28_Fn(&api); + if (get_read_func == nullptr) { + return api.CreateStatus(ORT_NOT_IMPLEMENTED, + "OrtEpApi_EpContextConfig_GetEpContextDataReadFunc is not available"); + } + RETURN_IF_ERROR(get_read_func(ep_context_config, &read_func, &read_state)); + } + return ReadEpContextDataWithFileFallback(api, read_func, read_state, file_name, graph, data); +} + +// Low-level overload that takes the write callback and its opaque state directly. Production EPs should use the +// overloads below that take an OrtEpContextConfig; this overload exists so unit tests can inject a callback without +// constructing an OrtEpContextConfig. When `write_func` is null the data is written to the file fallback. +inline OrtStatus* WriteEpContextDataWithFileFallback( + const OrtApi& api, + OrtWriteNamedBufferFunc write_func, void* write_state, + const char* file_name, const char* fallback_file_name, + const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + if (file_name == nullptr || file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data file name must not be empty"); + } + + if (buffer == nullptr && buffer_size != 0) { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data buffer must not be null for non-empty data"); + } + + // The app-supplied write callback owns its own logical namespace, so file_name is passed through unmodified. + // Only the file-fallback path below maps a name onto the filesystem, so it validates the logical name there. + if (write_func != nullptr) { + return write_func(write_state, file_name, buffer, buffer_size); + } + + // Even when the physical fallback path is supplied separately, `file_name` is the logical name written into the + // EPContext model's ep_cache_context attribute. Validate it as a safe relative name so a generated model cannot + // contain an unsafe logical reference that later reaches the read-side resolver. + std::filesystem::path logical_path; + RETURN_IF_ERROR(ValidateEpContextDataName(api, file_name, logical_path)); + + if (fallback_file_name == nullptr || fallback_file_name[0] == '\0') { + return api.CreateStatus(ORT_INVALID_ARGUMENT, "EPContext data fallback file name must not be empty"); + } + + std::filesystem::path data_path; + RETURN_IF_ERROR(ResolveEpContextDataPath(api, fallback_file_name, graph, data_path)); + return WriteEpContextDataToResolvedFile(api, data_path, buffer, buffer_size); +} + +// Writes EPContext binary data. If the compilation configured an OrtWriteNamedBufferFunc (carried by +// `ep_context_config`), it is used and `file_name` is passed through unmodified as the logical name. Otherwise the +// data is written to a file at `fallback_file_name`, which is resolved against the model directory when `graph` is +// non-null (and rejected if absolute or rooted in that case). `graph == nullptr` denotes a trusted caller that may +// supply an absolute physical path. `buffer` may be null only when `buffer_size` is 0. +inline OrtStatus* WriteEpContextDataWithFileFallback( + const OrtApi& api, + const OrtEpContextConfig* ep_context_config, + const char* file_name, const char* fallback_file_name, + const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + OrtWriteNamedBufferFunc write_func = nullptr; + void* write_state = nullptr; + if (ep_context_config != nullptr) { + auto get_write_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28_Fn(&api); + if (get_write_func == nullptr) { + return api.CreateStatus(ORT_NOT_IMPLEMENTED, + "OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc is not available"); + } + RETURN_IF_ERROR(get_write_func(ep_context_config, &write_func, &write_state)); + } + return WriteEpContextDataWithFileFallback(api, write_func, write_state, file_name, fallback_file_name, graph, buffer, + buffer_size); +} + +// Convenience overload that uses `file_name` as both the logical callback name and the file-fallback path. +// Because `file_name` doubles as the fallback path, it must be a safe relative name (this overload validates it and +// rejects absolute/rooted paths and `..` traversal). To write the file fallback to an absolute physical path (a +// trusted caller with `graph == nullptr`), use the overload above that takes a separate `fallback_file_name`. +inline OrtStatus* WriteEpContextDataWithFileFallback( + const OrtApi& api, + const OrtEpContextConfig* ep_context_config, + const char* file_name, const OrtGraph* graph, + const void* buffer, size_t buffer_size) { + return WriteEpContextDataWithFileFallback(api, ep_context_config, file_name, file_name, graph, buffer, buffer_size); +} + +} // namespace ep_context_data_utils diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index fecf7ac9a4038..90ad4e7976824 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -4,15 +4,17 @@ #include "ep.h" #include +#include #include #include -#include +#include #include #include #include +#include #include -#include +#include "../ep_context_data_utils.h" #include "ep_factory.h" #include "ep_stream_support.h" @@ -167,13 +169,15 @@ struct EpContextNodeComputeInfo : NodeComputeInfoBase { ExampleEp& ep; }; -ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger) +ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger, + Ort::Experimental::EpContextConfig ep_context_config) : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized ApiPtrs{static_cast(factory)}, factory_{factory}, name_{name}, config_{config}, - logger_{logger} { + logger_{logger}, + ep_context_config_{std::move(ep_context_config)} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. // Initialize the execution provider's function table @@ -193,8 +197,6 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C ORT_FILE, __LINE__, __FUNCTION__)); } -ExampleEp::~ExampleEp() = default; - /*static*/ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept { const auto* ep = static_cast(this_ptr); @@ -409,6 +411,26 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const auto fused_node_name = fused_node.GetName(); if (is_ep_context_node) { + Ort::ConstOpAttr embed_mode_attr; + RETURN_IF_ERROR(nodes[0].GetAttributeByName("embed_mode", embed_mode_attr)); + int64_t embed_mode = 1; + RETURN_IF_ERROR(embed_mode_attr.GetValue(embed_mode)); + + if (embed_mode == 0) { + Ort::ConstOpAttr ep_cache_context_attr; + RETURN_IF_ERROR(nodes[0].GetAttributeByName("ep_cache_context", ep_cache_context_attr)); + std::string ep_cache_context; + RETURN_IF_ERROR(ep_cache_context_attr.GetValue(ep_cache_context)); + + // This example only exercises the load-side read flow (callback first, file fallback otherwise) to show how + // an EP retrieves EPContext binary data during compile. A real EP would consume `ep_context_data` (e.g., + // initialize a kernel/engine from it); here it is intentionally read and then discarded. + std::vector ep_context_data; + RETURN_IF_ERROR(ep_context_data_utils::ReadEpContextDataWithFileFallback( + ep->ort_api, ep->ep_context_config_.get(), ep_cache_context.c_str(), ort_graphs[0], + ep_context_data)); + } + // Create EpContextKernel for EPContext nodes - clearly separates from MulKernel ep->ep_context_kernels_.emplace(fused_node_name, std::make_unique(ep->ort_api, ep->logger_)); @@ -449,7 +471,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const // Create EpContext nodes for the fused nodes we compiled (only for Mul, not EPContext). if (ep->config_.enable_ep_context) { assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + RETURN_IF_ERROR(ep->CreateEpContextNodes(ort_graphs[0], gsl::span(fused_nodes, count), gsl::span(ep_context_nodes, count))); } } @@ -479,7 +501,8 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, // Creates EPContext nodes from the given fused nodes. // This is an example implementation that can be used to generate an EPContext model. However, this example EP // cannot currently run the EPContext model. -OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes, +OrtStatus* ExampleEp::CreateEpContextNodes(const OrtGraph* graph, + gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes) { try { assert(fused_nodes.size() == ep_context_nodes.size()); @@ -512,11 +535,32 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes collect_input_output_names(fused_node_outputs, /*out*/ output_names); int64_t is_main_context = (i == 0); - int64_t embed_mode = 1; + int64_t embed_mode = config_.embed_ep_context_in_model ? 1 : 0; // Create node attributes. The CreateNode() function copies the attributes. std::array attributes = {}; - std::string ep_ctx = "binary_data"; + std::string ep_ctx = config_.embed_ep_context_in_model ? "binary_data" : fused_node_name + ".ctx"; + if (!config_.embed_ep_context_in_model) { + const std::string ep_context_data = "binary_data"; + std::string fallback_ep_ctx = ep_ctx; + const OrtGraph* fallback_graph = graph; + if (!config_.ep_context_output_model_path.empty()) { + std::filesystem::path output_model_path; + RETURN_IF_ERROR(ep_context_data_utils::Utf8Path(ort_api, config_.ep_context_output_model_path.c_str(), + output_model_path)); + const std::filesystem::path output_model_dir = output_model_path.parent_path(); + if (!output_model_dir.empty()) { + std::filesystem::path ep_ctx_path; + RETURN_IF_ERROR(ep_context_data_utils::Utf8Path(ort_api, ep_ctx.c_str(), ep_ctx_path)); + RETURN_IF_ERROR(ep_context_data_utils::PathToUtf8String(ort_api, output_model_dir / ep_ctx_path, + fallback_ep_ctx)); + } + fallback_graph = nullptr; + } + RETURN_IF_ERROR(ep_context_data_utils::WriteEpContextDataWithFileFallback( + ort_api, ep_context_config_.get(), ep_ctx.c_str(), fallback_ep_ctx.c_str(), fallback_graph, + ep_context_data.data(), ep_context_data.size())); + } attributes[0] = Ort::OpAttr("ep_cache_context", ep_ctx.data(), static_cast(ep_ctx.size()), ORT_OP_ATTR_STRING); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 5dcd9f07bef1f..4112abb723d39 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -6,6 +6,7 @@ #include #include "../plugin_ep_utils.h" +#include "onnxruntime_experimental_cxx_api.h" class ExampleEpFactory; @@ -61,13 +62,16 @@ class ExampleEp : public OrtEp, public ApiPtrs { public: struct Config { bool enable_ep_context = false; + bool embed_ep_context_in_model = false; bool enable_weightless_ep_context_nodes = false; + std::string ep_context_output_model_path; // Other EP configs (typically extracted from OrtSessionOptions or OrtHardwareDevice(s)) }; - ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger); + ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger, + Ort::Experimental::EpContextConfig ep_context_config); - ~ExampleEp(); + ~ExampleEp() = default; std::unordered_map>& MulKernels() { return mul_kernels_; @@ -108,7 +112,8 @@ class ExampleEp : public OrtEp, public ApiPtrs { static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr, _Outptr_ const OrtMemoryDevice** device) noexcept; - OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, + OrtStatus* CreateEpContextNodes(const OrtGraph* graph, + gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); // Returns true if the EP should save constant initializers so that they are available during inference. @@ -122,6 +127,7 @@ class ExampleEp : public OrtEp, public ApiPtrs { std::string name_; Config config_{}; const OrtLogger& logger_; + Ort::Experimental::EpContextConfig ep_context_config_; std::unordered_map> mul_kernels_; std::unordered_map> ep_context_kernels_; std::unordered_map float_initializers_; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 875f70bd29f3c..a323ec0e8c15e 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -196,6 +196,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, const OrtSessionOptions* session_options, const OrtLogger* logger, OrtEp** ep) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN auto* factory = static_cast(this_ptr); *ep = nullptr; @@ -219,20 +220,33 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateEpImpl(OrtEpFactory* this_ptr, // Create EP configuration from session options, if needed. // Note: should not store a direct reference to the session options object as its lifespan is not guaranteed. std::string ep_context_enable; + std::string ep_context_embed_mode; + std::string ep_context_output_model_path; std::string weightless_ep_context_nodes_enable; RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextEnable, "0", ep_context_enable)); + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextEmbedMode, "0", + ep_context_embed_mode)); + RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpContextFilePath, "", + ep_context_output_model_path)); RETURN_IF_ERROR(GetSessionConfigEntryOrDefault(*session_options, kOrtSessionOptionEpEnableWeightlessEpContextNodes, "0", weightless_ep_context_nodes_enable)); ExampleEp::Config config = {}; config.enable_ep_context = ep_context_enable == "1"; + config.embed_ep_context_in_model = ep_context_embed_mode == "1"; + config.ep_context_output_model_path = std::move(ep_context_output_model_path); config.enable_weightless_ep_context_nodes = weightless_ep_context_nodes_enable == "1"; - auto dummy_ep = std::make_unique(*factory, factory->ep_name_, config, *logger); - + // The EpContextConfig wrapper extracts the EPContext callbacks from the session options and owns the handle. It + // throws if the experimental functions are unavailable or extraction fails; EXCEPTION_TO_RETURNED_STATUS_END + // converts that (and any other exception thrown in this function) into an OrtStatus. + auto dummy_ep = std::make_unique( + *factory, factory->ep_name_, config, *logger, + Ort::Experimental::EpContextConfig{Ort::ConstSessionOptions{session_options}}); *ep = dummy_ep.release(); return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END } /*static*/ diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 93633f9a375bb..e95918c719324 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include +#include #include // #include #include @@ -12,11 +14,14 @@ #include "core/graph/constants.h" #include "core/graph/onnx_protobuf.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_experimental_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "nlohmann/json.hpp" +#include "test/autoep/ep_context_data_callbacks.h" #include "test/autoep/test_autoep_utils.h" +#include "test/autoep/library/ep_context_data_utils.h" #include "test/autoep/library/example_plugin_ep/ep_test_hooks.h" #include "test/shared_lib/utils.h" #include "test/util/include/api_asserts.h" @@ -29,6 +34,51 @@ namespace test { namespace { +// Invokes the experimental EPContext read setter on the public C API. +void SetEpContextDataReadFunc(Ort::SessionOptions& session_options, OrtReadNamedBufferFunc read_func, void* state) { + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&Ort::GetApi()); + ASSERT_ORTSTATUS_OK(set_read_func(session_options, read_func, state)); +} + +// Invokes the experimental EPContext write setter on the public C API. +void SetEpContextDataWriteFunc(Ort::ModelCompilationOptions& compile_options, OrtWriteNamedBufferFunc write_func, + void* state) { + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &Ort::GetApi()); + ASSERT_ORTSTATUS_OK(set_write_func(compile_options, write_func, state)); +} + +void LoadModelProtoFromFile(const ORTCHAR_T* model_file, ONNX_NAMESPACE::ModelProto& model_proto) { + std::ifstream model_stream{std::filesystem::path(model_file), std::ios::binary}; + ASSERT_TRUE(model_stream.is_open()); + ASSERT_TRUE(model_proto.ParseFromIstream(&model_stream)); +} + +std::vector GetEpContextNodes(const ONNX_NAMESPACE::ModelProto& model_proto) { + std::vector ep_context_nodes; + + for (const auto& node : model_proto.graph().node()) { + if (node.domain() == kMSDomain && node.op_type() == "EPContext") { + ep_context_nodes.push_back(&node); + } + } + + return ep_context_nodes; +} + +const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const ONNX_NAMESPACE::NodeProto& node, + std::string_view attribute_name) { + for (const auto& attribute : node.attribute()) { + if (attribute.name() == attribute_name) { + return &attribute; + } + } + + return nullptr; +} + void RunMulModelWithPluginEp(const ORTCHAR_T* model_path, const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, model_path, session_options); @@ -521,6 +571,222 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { } } +TEST(OrtEpLibrary, PluginEp_GenEpContextModel_EmbedModeDoesNotUseCallbacks) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_embedded_ctx.onnx"); + std::filesystem::remove(output_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(output_model_file); }); + + EpContextDataCallbackState write_callback_state; + EpContextDataCallbackState compile_read_callback_state; + { + Ort::SessionOptions session_options; + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataReadFunc(session_options, LoadEpContextDataCallback, &compile_read_callback_state)); + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(true); + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataWriteFunc(compile_options, StoreEpContextDataCallback, &write_callback_state)); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + } + + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + EXPECT_FALSE(write_callback_state.write_called); + EXPECT_FALSE(compile_read_callback_state.read_called); + + ONNX_NAMESPACE::ModelProto compiled_model; + ASSERT_NO_FATAL_FAILURE(LoadModelProtoFromFile(output_model_file, compiled_model)); + + auto ep_context_nodes = GetEpContextNodes(compiled_model); + ASSERT_EQ(ep_context_nodes.size(), 1u); + + const ONNX_NAMESPACE::AttributeProto* embed_mode_attr = GetNodeAttribute(*ep_context_nodes[0], "embed_mode"); + ASSERT_NE(embed_mode_attr, nullptr); + EXPECT_EQ(embed_mode_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + EXPECT_EQ(embed_mode_attr->i(), 1); + + const ONNX_NAMESPACE::AttributeProto* ep_cache_context_attr = GetNodeAttribute(*ep_context_nodes[0], + "ep_cache_context"); + ASSERT_NE(ep_cache_context_attr, nullptr); + EXPECT_EQ(ep_cache_context_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); + EXPECT_EQ(ep_cache_context_attr->s(), "binary_data"); + + EpContextDataCallbackState load_read_callback_state; + { + Ort::SessionOptions session_options; + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataReadFunc(session_options, LoadEpContextDataCallback, &load_read_callback_state)); + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, output_model_file, session_options); + } + + EXPECT_FALSE(load_read_callback_state.read_called); +} + +TEST(OrtEpLibrary, PluginEp_GenAndLoadEpContextModel_ExternalDataUsesFileFallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_file_ctx.onnx"); + std::vector files_to_cleanup{std::filesystem::path{output_model_file}}; + for (const auto& path : files_to_cleanup) { + std::filesystem::remove(path); + } + auto cleanup = gsl::finally([&]() { + for (const auto& path : files_to_cleanup) { + std::filesystem::remove(path); + } + }); + + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(false); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + } + + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + + ONNX_NAMESPACE::ModelProto compiled_model; + ASSERT_NO_FATAL_FAILURE(LoadModelProtoFromFile(output_model_file, compiled_model)); + + auto ep_context_nodes = GetEpContextNodes(compiled_model); + ASSERT_EQ(ep_context_nodes.size(), 1u); + + const ONNX_NAMESPACE::AttributeProto* embed_mode_attr = GetNodeAttribute(*ep_context_nodes[0], "embed_mode"); + ASSERT_NE(embed_mode_attr, nullptr); + EXPECT_EQ(embed_mode_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + EXPECT_EQ(embed_mode_attr->i(), 0); + + const ONNX_NAMESPACE::AttributeProto* ep_cache_context_attr = GetNodeAttribute(*ep_context_nodes[0], + "ep_cache_context"); + ASSERT_NE(ep_cache_context_attr, nullptr); + EXPECT_EQ(ep_cache_context_attr->type(), ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); + ASSERT_FALSE(ep_cache_context_attr->s().empty()); + + const std::filesystem::path output_model_dir = std::filesystem::path{output_model_file}.parent_path(); + std::filesystem::path ep_cache_context_rel; + ASSERT_ORTSTATUS_OK( + ep_context_data_utils::Utf8Path(Ort::GetApi(), ep_cache_context_attr->s().c_str(), ep_cache_context_rel)); + const std::filesystem::path context_data_path = output_model_dir / ep_cache_context_rel; + files_to_cleanup.push_back(context_data_path); + ASSERT_TRUE(std::filesystem::exists(context_data_path)); + + std::vector context_data; + std::string context_data_file_name; + ASSERT_ORTSTATUS_OK(ep_context_data_utils::PathToUtf8String(Ort::GetApi(), context_data_path, + context_data_file_name)); + ASSERT_ORTSTATUS_OK(ep_context_data_utils::ReadEpContextDataFromFile(Ort::GetApi(), context_data_file_name.c_str(), + nullptr, context_data)); + EXPECT_EQ(std::string(context_data.begin(), context_data.end()), "binary_data"); + + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, output_model_file, session_options); + } +} + +TEST(OrtEpLibrary, PluginEp_GenEpContextModel_ExternalDataUsesWriteCallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_external_ctx.onnx"); + std::filesystem::remove(output_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(output_model_file); }); + + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + EpContextDataCallbackState callback_state; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(false); + ASSERT_NO_FATAL_FAILURE(SetEpContextDataWriteFunc(compile_options, StoreEpContextDataCallback, &callback_state)); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + ASSERT_TRUE(callback_state.write_called); + EXPECT_FALSE(callback_state.write_file_name.empty()); + EXPECT_EQ(std::string(callback_state.payload.begin(), callback_state.payload.end()), "binary_data"); +} + +TEST(OrtEpLibrary, PluginEp_LoadEpContextModel_ExternalDataUsesReadCallback) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* compiled_model_file = ORT_TSTR("plugin_ep_mul_1_external_ctx_load.onnx"); + std::filesystem::remove(compiled_model_file); + auto cleanup = gsl::finally([&]() { std::filesystem::remove(compiled_model_file); }); + + EpContextDataCallbackState write_callback_state; + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(compiled_model_file); + compile_options.SetEpContextEmbedMode(false); + ASSERT_NO_FATAL_FAILURE( + SetEpContextDataWriteFunc(compile_options, StoreEpContextDataCallback, &write_callback_state)); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(compiled_model_file)); + ASSERT_TRUE(write_callback_state.write_called); + } + + EpContextDataCallbackState read_callback_state; + read_callback_state.payload = write_callback_state.payload; + { + Ort::SessionOptions session_options; + ASSERT_NO_FATAL_FAILURE(SetEpContextDataReadFunc(session_options, LoadEpContextDataCallback, &read_callback_state)); + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, compiled_model_file, session_options); + } + + ASSERT_TRUE(read_callback_state.read_called); + EXPECT_EQ(read_callback_state.read_file_name, write_callback_state.write_file_name); +} + TEST(OrtEpLibrary, PluginEp_GenWeightlessEpContextModel) { RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); diff --git a/onnxruntime/test/autoep/test_model_package.cc b/onnxruntime/test/autoep/test_model_package.cc index ee5c8bb567e1e..c9065878da1ce 100644 --- a/onnxruntime/test/autoep/test_model_package.cc +++ b/onnxruntime/test/autoep/test_model_package.cc @@ -621,6 +621,11 @@ TEST(ModelPackageTest, CheckCompiledModelCompatibilityInfo) { compile_options.SetInputModelPath(input_model_file); compile_options.SetOutputModelPath(output_model_file); + // Embed the EPContext binary data inside the compiled model so the model is self-contained. + // This test copies only the compiled .onnx into the model package, so it must not rely on a + // separate sidecar EPContext data file (which non-embedded mode would produce). + compile_options.SetEpContextEmbedMode(true); + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); ASSERT_TRUE(std::filesystem::exists(output_model_file)); } diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index fef185e20f341..4820f4e5c8898 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -7,12 +7,14 @@ #include #include #include +#include #include #include #include "gsl/gsl" #include "gtest/gtest.h" #include "core/common/logging/sinks/file_sink.h" +#include "core/common/path_string.h" #include "core/framework/config_options.h" #include "core/framework/kernel_def_builder.h" #include "core/framework/op_kernel.h" diff --git a/onnxruntime/test/shared_lib/test_ep_context_data_api.cc b/onnxruntime/test/shared_lib/test_ep_context_data_api.cc new file mode 100644 index 0000000000000..ec8107f92aa7a --- /dev/null +++ b/onnxruntime/test/shared_lib/test_ep_context_data_api.cc @@ -0,0 +1,331 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_experimental_cxx_api.h" + +#include "gmock/gmock.h" +#include "gsl/gsl" +#include "gtest/gtest.h" +#include "test/util/include/api_asserts.h" + +namespace { + +void ExpectFailureOrtStatus(OrtStatus* status_ptr, OrtErrorCode expected_code, const char* expected_message) { + Ort::Status status{status_ptr}; + ASSERT_NE(status_ptr, nullptr) << "Expected a failure status, but the API returned nullptr (OK)."; + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), expected_code); + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr(expected_message)); +} + +struct EpContextReadCallbackState { + bool called = false; + std::string file_name; + std::vector payload; +}; + +OrtStatus* ORT_API_CALL EpContextReadCallback(void* state, const char* file_name, OrtAllocator* allocator, + void** buffer, size_t* data_size) { + auto* read_state = static_cast(state); + read_state->called = true; + read_state->file_name = file_name; + + *buffer = nullptr; + *data_size = read_state->payload.size(); + + if (read_state->payload.empty()) { + return nullptr; + } + + OrtStatus* status = Ort::GetApi().AllocatorAlloc(allocator, read_state->payload.size(), buffer); + if (status != nullptr) { + return status; + } + + std::memcpy(*buffer, read_state->payload.data(), read_state->payload.size()); + return nullptr; +} + +struct EpContextWriteCallbackState { + bool called = false; + std::string file_name; + std::vector payload; +}; + +OrtStatus* ORT_API_CALL EpContextWriteCallback(void* state, const char* file_name, const void* buffer, + size_t buffer_size) { + auto* write_state = static_cast(state); + write_state->called = true; + write_state->file_name = file_name; + write_state->payload.clear(); + if (buffer_size != 0) { + if (buffer == nullptr) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, + "EpContextWriteCallback received a null buffer for non-empty data"); + } + + const char* buffer_bytes = static_cast(buffer); + write_state->payload.assign(buffer_bytes, buffer_bytes + buffer_size); + } + + return nullptr; +} + +} // namespace + +TEST(EpContextDataApiTest, ReadFuncIsReturnedByEpApi) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{ + false, + {}, + {'e', 'p', 'c', 't', 'x'}, + }; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + OrtReadNamedBufferFunc read_func = nullptr; + void* callback_state_out = nullptr; + ep_context_config.GetReadFunc(read_func, callback_state_out); + ASSERT_EQ(read_func, EpContextReadCallback); + ASSERT_EQ(callback_state_out, &callback_state); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = nullptr; + size_t buffer_size = 0; + ASSERT_ORTSTATUS_OK(read_func(callback_state_out, "context.bin", allocator, &buffer, &buffer_size)); + auto release_buffer = gsl::finally([&]() { + if (buffer != nullptr) { + allocator.Free(buffer); + } + }); + + ASSERT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "context.bin"); + ASSERT_EQ(buffer_size, callback_state.payload.size()); + EXPECT_TRUE(std::equal(callback_state.payload.begin(), callback_state.payload.end(), + static_cast(buffer))); +} + +TEST(EpContextDataApiTest, ApiRejectsInvalidArguments) { + const auto& ort_api = Ort::GetApi(); + + auto* get_config = Ort::Experimental::Get_OrtEpApi_SessionOptions_GetEpContextConfig_SinceV28_FnOrThrow(&ort_api); + auto* release_config_func = + Ort::Experimental::Get_OrtEpApi_ReleaseEpContextConfig_SinceV28_FnOrThrow(&ort_api); + auto* get_read_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + auto* get_write_func = + Ort::Experimental::Get_OrtEpApi_EpContextConfig_GetEpContextDataWriteFunc_SinceV28_FnOrThrow(&ort_api); + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + Ort::SessionOptions session_options; + OrtEpContextConfig* ep_context_config = nullptr; + ExpectFailureOrtStatus(get_config(nullptr, &ep_context_config), ORT_INVALID_ARGUMENT, "OrtSessionOptions is NULL"); + ExpectFailureOrtStatus(get_config(session_options, nullptr), ORT_INVALID_ARGUMENT, + "Output OrtEpContextConfig is NULL"); + + ExpectFailureOrtStatus(set_read_func(nullptr, EpContextReadCallback, nullptr), ORT_INVALID_ARGUMENT, + "'options' parameter must not be NULL"); + + ASSERT_ORTSTATUS_OK(get_config(session_options, &ep_context_config)); + auto release_config = gsl::finally([&]() { release_config_func(ep_context_config); }); + + OrtReadNamedBufferFunc read_func = nullptr; + OrtWriteNamedBufferFunc write_func = nullptr; + void* state = nullptr; + ExpectFailureOrtStatus(get_read_func(nullptr, &read_func, &state), ORT_INVALID_ARGUMENT, + "OrtEpContextConfig is NULL"); + ExpectFailureOrtStatus(get_read_func(ep_context_config, nullptr, &state), ORT_INVALID_ARGUMENT, + "Output read_func is NULL"); + ExpectFailureOrtStatus(get_read_func(ep_context_config, &read_func, nullptr), ORT_INVALID_ARGUMENT, + "Output state is NULL"); + ExpectFailureOrtStatus(get_write_func(nullptr, &write_func, &state), ORT_INVALID_ARGUMENT, + "OrtEpContextConfig is NULL"); + ExpectFailureOrtStatus(get_write_func(ep_context_config, nullptr, &state), ORT_INVALID_ARGUMENT, + "Output write_func is NULL"); + ExpectFailureOrtStatus(get_write_func(ep_context_config, &write_func, nullptr), ORT_INVALID_ARGUMENT, + "Output state is NULL"); + +#if !defined(ORT_MINIMAL_BUILD) + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataApiRejectsInvalidArguments"}; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + ExpectFailureOrtStatus(set_write_func(nullptr, EpContextWriteCallback, nullptr), ORT_INVALID_ARGUMENT, + "OrtModelCompilationOptions is NULL"); + // A null write_func is allowed: it clears any previously set callback (covered by WriteFuncCanBeCleared), so it is + // not rejected here. +#endif // !defined(ORT_MINIMAL_BUILD) +} + +TEST(EpContextDataApiTest, AccessorsReturnNullWhenCallbacksUnset) { + Ort::SessionOptions session_options; + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + + OrtReadNamedBufferFunc read_func = EpContextReadCallback; + OrtWriteNamedBufferFunc write_func = EpContextWriteCallback; + void* state = reinterpret_cast(0x1); + + ep_context_config.GetReadFunc(read_func, state); + EXPECT_EQ(read_func, nullptr); + EXPECT_EQ(state, nullptr); + + state = reinterpret_cast(0x1); + ep_context_config.GetWriteFunc(write_func, state); + EXPECT_EQ(write_func, nullptr); + EXPECT_EQ(state, nullptr); +} + +TEST(EpContextDataApiTest, ConfigReturnsConfiguredCallbacks) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; + ep_context_config.GetReadFunc(read_func, read_state); + EXPECT_EQ(read_func, EpContextReadCallback); + EXPECT_EQ(read_state, &callback_state); + + OrtWriteNamedBufferFunc write_func = nullptr; + void* write_state = nullptr; + ep_context_config.GetWriteFunc(write_func, write_state); + EXPECT_EQ(write_func, nullptr); + EXPECT_EQ(write_state, nullptr); +} + +TEST(EpContextDataApiTest, ReadFuncCanBeCleared) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + ASSERT_ORTSTATUS_OK(set_read_func(session_options, nullptr, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + OrtReadNamedBufferFunc read_func = EpContextReadCallback; + void* read_state = reinterpret_cast(0x1); + ep_context_config.GetReadFunc(read_func, read_state); + EXPECT_EQ(read_func, nullptr); + EXPECT_EQ(read_state, nullptr); +} + +#if !defined(ORT_MINIMAL_BUILD) +TEST(EpContextDataApiTest, WriteFuncCanBeSetOnModelCompilationOptions) { + const auto& ort_api = Ort::GetApi(); + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataWriteFuncCanBeSetOnModelCompilationOptions"}; + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + + EpContextWriteCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, EpContextWriteCallback, &callback_state)); + + const std::vector payload{'b', 'i', 'n', 'a', 'r', 'y'}; + ASSERT_ORTSTATUS_OK(EpContextWriteCallback(&callback_state, "engine.bin", payload.data(), payload.size())); + + ASSERT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "engine.bin"); + EXPECT_EQ(callback_state.payload, payload); +} + +TEST(EpContextDataApiTest, WriteFuncCanBeCleared) { + const auto& ort_api = Ort::GetApi(); + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataWriteFuncCanBeCleared"}; + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + + EpContextWriteCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, EpContextWriteCallback, &callback_state)); + + // A null write_func clears the previously set callback (symmetric with the read setter) and must be accepted + // rather than rejected with ORT_INVALID_ARGUMENT. + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, nullptr, &callback_state)); +} + +TEST(EpContextDataApiTest, WriteFuncCanBeUsedWithEpContextBinaryInformation) { + const auto& ort_api = Ort::GetApi(); + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpContextDataWriteFuncCanBeUsedWithEpContextBinaryInformation"}; + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compilation_options{env, session_options}; + + auto* set_write_func = + Ort::Experimental::Get_OrtCompileApi_ModelCompilationOptions_SetEpContextDataWriteFunc_SinceV28_FnOrThrow( + &ort_api); + + // The EPContext write callback and the EPContext binary information may be configured together; neither call + // rejects the other. + ASSERT_NO_THROW(compilation_options.SetEpContextBinaryInformation(ORT_TSTR("ep_context_dir/"), + ORT_TSTR("compiled_model.onnx"))); + + EpContextWriteCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_write_func(compilation_options, EpContextWriteCallback, &callback_state)); + + const std::vector payload{'c', 't', 'x'}; + ASSERT_ORTSTATUS_OK(EpContextWriteCallback(&callback_state, "logical_context.bin", payload.data(), payload.size())); + + ASSERT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "logical_context.bin"); + EXPECT_EQ(callback_state.payload, payload); +} +#endif // !defined(ORT_MINIMAL_BUILD) + +TEST(EpContextDataApiTest, ReturnedReadFuncAllowsEmptyPayloads) { + const auto& ort_api = Ort::GetApi(); + Ort::SessionOptions session_options; + + auto* set_read_func = + Ort::Experimental::Get_OrtApi_SessionOptions_SetEpContextDataReadFunc_SinceV28_FnOrThrow(&ort_api); + + EpContextReadCallbackState callback_state{}; + ASSERT_ORTSTATUS_OK(set_read_func(session_options, EpContextReadCallback, &callback_state)); + + Ort::Experimental::EpContextConfig ep_context_config{session_options}; + OrtReadNamedBufferFunc read_func = nullptr; + void* read_state = nullptr; + ep_context_config.GetReadFunc(read_func, read_state); + ASSERT_EQ(read_func, EpContextReadCallback); + ASSERT_EQ(read_state, &callback_state); + + Ort::AllocatorWithDefaultOptions allocator; + void* buffer = reinterpret_cast(0x1); + size_t buffer_size = 1; + ASSERT_ORTSTATUS_OK(read_func(read_state, "empty.bin", allocator, &buffer, &buffer_size)); + + EXPECT_TRUE(callback_state.called); + EXPECT_EQ(callback_state.file_name, "empty.bin"); + EXPECT_EQ(buffer, nullptr); + EXPECT_EQ(buffer_size, 0U); +} From c65d9b64456a5e52464261c5e5e4bac17402b3f2 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Fri, 26 Jun 2026 21:54:10 -0700 Subject: [PATCH 18/19] Normalize bool tensor raw_data to {0, 1} on unpack (#29238) ### Description Bool initializers supplied via `TensorProto` `raw_data` are copied verbatim by `UnpackTensor`, so their bytes are not guaranteed to be the canonical `{0, 1}` (the `int32_data` path normalizes via `static_cast`, but the `raw_data` path did not). Kernels across the codebase assume bool tensors hold `{0, 1}`. The CUDA `Compress` kernel is concretely affected: its output-sizing path sign-extends the condition bytes (`int8_t` -> `int32_t`) through `cub::DeviceScan::InclusiveSum`, while `_CompressKernel` selects elements using bool truthiness (`condition_data[div]`). For condition bytes outside `{0, 1}` the two interpretations disagree and the output is sized inconsistently with how elements are written. The CPU kernel uses truthiness for both sizing and selection and is unaffected. ### Changes - `UnpackTensor` (`tensorprotoutils.cc`): normalize `raw_data` bytes to `{0, 1}` after copy. The `UnpackTensorWithExternalData` specialization does the same for external data read through that path. - CUDA `Compress` `CastToInt32` (`compress_impl.cu`): normalize to `{0, 1}` (still returns `int32_t`, preserving the accumulator-widening intent of #9295) so the sizing path matches the kernel's write predicate, matching the CPU kernel and the CUDA `NonZero` `bool(x)` convention. This makes the CUDA `Compress` kernel correct independently of how its bool condition initializer was materialized. - Shared helper `utils::NormalizeBoolTensorIfNeeded(Tensor&)` (`tensorprotoutils.{h,cc}`): single normalization point, reused by `TensorProtoToTensor()` (external branch, after `MakeCpuTensorCopy`) and by the session-init external device-copy path. - `session_state_utils.cc::DeserializeTensorProto`: for **external** bool initializers loaded onto a non-CPU device, normalize the writable CPU staging copy before `CopyTensorFromCPUToDevice`. The `GetExtDataFromTensorProto` buffer may be a read-only mmap, so it is normalized via a writable copy rather than in place. - Unit tests in `tensorutils_test.cc` for bool `raw_data` with non-canonical bytes and for the `NormalizeBoolTensorIfNeeded` helper. A `Compress` `OpTester` test cannot reproduce the original bug because the test harness itself normalizes bool during input construction, so coverage is placed at the deserialization layer. Tests use only Status returns and gtest assertions, so they build and run in no-exception builds. ### Coverage / scope of bool normalization Fully covered: - In-proto `raw_data` bool initializers (all EPs), via `UnpackTensor`. - External bool initializers reaching `TensorProtoToTensor()`. - External bool initializers copied to a non-CPU device through the session-init device-copy path. Intentionally **not** normalized (by design): - The CPU zero-copy mmap path for external initializers (`GetExtDataFromTensorProto` returns a read-only/shared mapping that cannot be safely modified in place). - The custom external-data-loader path (`LoadExtDataToTensorFromTensorProto`), which loads directly into a device tensor. These remaining paths are safe for the concrete bug this PR targets because the CUDA `Compress` kernel is hardened in `compress_impl.cu` regardless of initializer storage. Other byte-comparing bool consumers fed by an external mmap/custom-loader initializer with non-canonical bytes are out of scope here. ### Motivation and Context `CastToInt32` was introduced in #9295 to widen the `cub::InclusiveSum` accumulator (an int8 overflow fix); it did not normalize the bool interpretation. The accumulator-width and bool-normalization concerns are independent. This change addresses the latter at the deserialization source and hardens the CUDA `Compress` kernel. --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/android.yml | 4 +- .../core/framework/session_state_utils.cc | 14 ++++ .../core/framework/tensorprotoutils.cc | 50 ++++++++++--- onnxruntime/core/framework/tensorprotoutils.h | 9 +++ .../test/framework/tensorutils_test.cc | 72 +++++++++++++++++++ 5 files changed, 137 insertions(+), 12 deletions(-) diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 954c3313faf25..66141c58e7a01 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -78,8 +78,8 @@ jobs: run: | set -e -x BINARY_SIZE_THRESHOLD_ARGS="" - echo "Binary size threshold in bytes: 1438720" - BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1438720" + echo "Binary size threshold in bytes: 1440768" + BINARY_SIZE_THRESHOLD_ARGS="--threshold_size_in_bytes 1440768" # Ensure ANDROID_NDK_HOME is available and get its real path if [ -z "$ANDROID_NDK_HOME" ]; then diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 80e7829a88e62..28477db5ec172 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -136,6 +136,20 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st deserialized_value, &prepacked_for_graph)); + const Tensor& cpu_staging_tensor = deserialized_value.Get(); + // Bool external initializers are copied verbatim and may carry bytes outside the canonical + // {0, 1} set. The CPU staging tensor above can be backed by a read-only mmap, so normalize into + // a writable CPU copy before copying to the device (see utils::NormalizeBoolTensorIfNeeded). + if (cpu_staging_tensor.IsDataType()) { + Tensor normalized_cpu_tensor; + ORT_RETURN_IF_ERROR(AllocateTensorOnDeviceOrMemory(/* use_device_allocator_for_initializers =*/true, + tensor_shape, type, + default_cpu_alloc, normalized_cpu_tensor)); + utils::MakeCpuTensorCopy(cpu_staging_tensor, normalized_cpu_tensor); + utils::NormalizeBoolTensorIfNeeded(normalized_cpu_tensor); + return CopyTensorFromCPUToDevice(data_transfer_mgr, normalized_cpu_tensor, std::move(tensor), ort_value); + } + return CopyTensorFromCPUToDevice(data_transfer_mgr, deserialized_value.Get(), std::move(tensor), ort_value); } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index e6df3a9923ad5..7b8f78136be3d 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -725,6 +725,17 @@ void ConvertRawDataInTensorProto(TensorProto& tensor) { SwapByteOrderInplace(element_size, span); } +// Bool tensors must hold canonical {0, 1} byte values. Data sourced from raw_data or external +// files is copied verbatim and may contain other non-zero bytes; normalize any non-zero byte to 1 +// so every consumer observes a single, consistent value. Operate on the byte representation to +// avoid loading a bool object that does not yet hold a valid value. +static void NormalizeBoolBytes(uint8_t* bool_bytes, size_t num_elements) { + static_assert(sizeof(bool) == 1, "Normalization assumes 1 byte per bool element"); + for (size_t i = 0; i < num_elements; ++i) { + bool_bytes[i] = bool_bytes[i] != 0 ? 1 : 0; + } +} + #if !defined(ORT_MINIMAL_BUILD) static Status UnpackTensorWithExternalDataImpl(const ONNX_NAMESPACE::TensorProto& tensor, @@ -752,6 +763,19 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, reinterpret_cast(p_data)); } +// UnpackTensorWithExternalData +// External data is copied verbatim and may contain bytes outside the canonical {0, 1} set, so +// normalize them (see NormalizeBoolBytes). +template <> +Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, + const std::filesystem::path& tensor_proto_dir, size_t expected_num_elements, + /*out*/ bool* p_data) { + ORT_RETURN_IF_ERROR(UnpackTensorWithExternalDataImpl(tensor, tensor_proto_dir, expected_num_elements, sizeof(bool), + reinterpret_cast(p_data))); + NormalizeBoolBytes(reinterpret_cast(p_data), expected_num_elements); + return Status::OK(); +} + #define DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(FOUR_BIT_TYPE, CalcPairFun) \ template <> \ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, \ @@ -800,7 +824,9 @@ INSTANTIATE_UNPACK_EXTERNAL_TENSOR(int32_t) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(int64_t) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(uint64_t) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(uint32_t) -INSTANTIATE_UNPACK_EXTERNAL_TENSOR(bool) +// bool is intentionally omitted: UnpackTensorWithExternalData is explicitly specialized +// above (to normalize bytes to {0, 1}), so an explicit instantiation here would have no effect +// and triggers -Werror,-Winstantiation-after-specialization. INSTANTIATE_UNPACK_EXTERNAL_TENSOR(MLFloat16) INSTANTIATE_UNPACK_EXTERNAL_TENSOR(BFloat16) @@ -908,15 +934,9 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d if (raw_data != nullptr) { ORT_RETURN_IF_ERROR(UnpackTensorWithRawData(raw_data, raw_data_len, expected_size, p_data)); - // raw_data is copied verbatim and may contain bytes outside the canonical {0, 1} set. - // Consumers rely on bool tensors holding {0, 1}; normalize any non-zero byte to 1 so every - // reader observes a single, consistent value. Operate on the byte representation to avoid - // loading a bool object that does not yet hold a valid value. - auto* bool_bytes = reinterpret_cast(p_data); - static_assert(sizeof(bool) == 1, "Normalization loop writes expected_size bytes assuming 1 byte per bool element"); - for (size_t i = 0; i < expected_size; ++i) { - bool_bytes[i] = bool_bytes[i] != 0 ? 1 : 0; - } + // raw_data is copied verbatim and may contain bytes outside the canonical {0, 1} set (see + // NormalizeBoolBytes). + NormalizeBoolBytes(reinterpret_cast(p_data), expected_size); return Status::OK(); } @@ -1891,6 +1911,9 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa ORT_RETURN_IF_ERROR(GetExtDataFromTensorProto(env, model_path, tensor_proto, ort_value)); const auto& ext_tensor = ort_value.Get(); MakeCpuTensorCopy(ext_tensor, tensor); + // MakeCpuTensorCopy memcpy's external bytes verbatim. Bool external initializers may carry + // bytes outside the canonical {0, 1} set, so normalize them here as well (see NormalizeBoolBytes). + NormalizeBoolTensorIfNeeded(tensor); return Status::OK(); } @@ -2201,6 +2224,13 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { } } +void NormalizeBoolTensorIfNeeded(Tensor& tensor) { + if (tensor.IsDataType()) { + NormalizeBoolBytes(reinterpret_cast(tensor.MutableDataRaw()), + narrow(tensor.Shape().Size())); + } +} + #if !defined(DISABLE_SPARSE_TENSORS) // Validates the external data declaration on a sub-tensor of a SparseTensorProto (values or diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 63f7b5e78e478..c07e9703ad384 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -283,6 +283,15 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n /// void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor); +/// +/// Normalizes the bytes of a CPU bool tensor to the canonical {0, 1} set (any non-zero byte -> 1). +/// Bool data sourced from raw_data or external files is copied verbatim and may contain other +/// non-zero bytes; normalizing ensures every consumer observes a single, consistent value. +/// No-op for non-bool tensors. The tensor must reside in writable CPU memory. +/// +/// The CPU tensor to normalize in place. +void NormalizeBoolTensorIfNeeded(Tensor& tensor); + #if !defined(DISABLE_SPARSE_TENSORS) /// // The function supports only COO format with 1D or 2D indices. Values shape is expected to be 1D. diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 34575f481d3cc..e15098d9c8c3c 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -6,6 +6,7 @@ #include "core/framework/endian_utils.h" #include "core/framework/prepacked_weights.h" #include "core/framework/prepacked_weights_container.h" +#include "core/framework/tensor.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "test/util/include/asserts.h" @@ -249,6 +250,41 @@ TEST(TensorProtoUtilsTest, UnpackBoolTensorWithRawDataNormalizesToZeroOne) { EXPECT_EQ(bytes[3], 1); } +// NormalizeBoolTensorIfNeeded normalizes a CPU bool tensor's bytes to {0, 1} in place. It backs the +// external-initializer device-copy path (session_state_utils.cc), where bool bytes that may be +// non-canonical are normalized in a writable CPU staging copy before being copied to the device. +TEST(TensorProtoUtilsTest, NormalizeBoolTensorIfNeededNormalizesToZeroOne) { + auto cpu_allocator = std::make_shared(); + + // Bool tensor: write non-canonical bytes, then normalize. + Tensor bool_tensor(DataTypeImpl::GetType(), TensorShape({4}), cpu_allocator); + unsigned char* bool_bytes = reinterpret_cast(bool_tensor.MutableDataRaw()); + bool_bytes[0] = 0x00; + bool_bytes[1] = 0x01; + bool_bytes[2] = 0x02; + bool_bytes[3] = 0xFF; + + NormalizeBoolTensorIfNeeded(bool_tensor); + + EXPECT_EQ(bool_bytes[0], 0); + EXPECT_EQ(bool_bytes[1], 1); + EXPECT_EQ(bool_bytes[2], 1); + EXPECT_EQ(bool_bytes[3], 1); + + // Non-bool tensor: bytes must be left untouched. + Tensor int32_tensor(DataTypeImpl::GetType(), TensorShape({3}), cpu_allocator); + int32_t* int32_data = int32_tensor.MutableData(); + int32_data[0] = 0; + int32_data[1] = 2; + int32_data[2] = 255; + + NormalizeBoolTensorIfNeeded(int32_tensor); + + EXPECT_EQ(int32_data[0], 0); + EXPECT_EQ(int32_data[1], 2); + EXPECT_EQ(int32_data[2], 255); +} + namespace { template std::vector CreateValues() { @@ -372,6 +408,42 @@ TEST(TensorProtoUtilsTest, UnpackTensorWithExternalData) { TestUnpackExternalTensor(TensorProto_DataType_BOOL, model_path); } +// A bool initializer supplied through external data is copied verbatim, so its bytes are not +// restricted to {0, 1}. UnpackTensor must normalize them so downstream consumers (which assume +// canonical bool values) all observe the same result regardless of how they read the byte. +TEST(TensorProtoUtilsTest, UnpackBoolTensorWithExternalDataNormalizesToZeroOne) { + std::filesystem::path model_path; + + // Bytes outside {0, 1}: 0x00 -> 0, 0x01 -> 1, 0x02 -> 1, 0xFF -> 1. + const unsigned char raw_bytes[] = {0x00, 0x01, 0x02, 0xFF}; + + std::basic_string filename(ORT_TSTR("bool_tensor_XXXXXX")); + FILE* fp; + CreateTestFile(fp, filename); + ASSERT_EQ(sizeof(raw_bytes), fwrite(raw_bytes, 1, sizeof(raw_bytes), fp)); + ASSERT_EQ(0, fclose(fp)); + std::unique_ptr file_deleter(const_cast(filename.c_str()), + DeleteFileFromDisk); + + TensorProto bool_tensor_proto; + onnx::StringStringEntryProto* location = bool_tensor_proto.mutable_external_data()->Add(); + location->set_key("location"); + location->set_value(ToUTF8String(filename)); + bool_tensor_proto.add_dims(4); + bool_tensor_proto.set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + bool_tensor_proto.set_data_type(TensorProto_DataType_BOOL); + + auto arr = std::make_unique(4); + auto status = utils::UnpackTensor(bool_tensor_proto, model_path, arr.get(), 4); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + const auto* bytes = reinterpret_cast(arr.get()); + EXPECT_EQ(bytes[0], 0); + EXPECT_EQ(bytes[1], 1); + EXPECT_EQ(bytes[2], 1); + EXPECT_EQ(bytes[3], 1); +} + template static NodeProto CreateConstantNode(const std::string& attrib_name, AttributeProto_AttributeType type, std::function add_data) { From c3a5222d2a80fdad084d128da45e7565ec074a91 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Jun 2026 21:55:35 -0700 Subject: [PATCH 19/19] =?UTF-8?q?[CUDA]=20Remove=20unused=20code=20in=20mo?= =?UTF-8?q?e=5Fkernels.cu=E2=80=8E=20(#29295)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This removes some unused code for MinLatency, which is not used in MoE or QMoE. --- .../cuda/llm/moe_gemm/moe_kernels.cu | 461 ------------------ .../cuda/llm/moe_gemm/moe_kernels.h | 39 -- 2 files changed, 500 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index 3185a9a86b231..9db229555ecbc 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -200,204 +200,6 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( } } -/** - * Takes the input maps and prepares the expanded maps for min latency - * @param num_active_experts_per_node: Number of active experts on current node - * @param experts_to_token_scores: The score of each token for each activated expert. 0 if the expert is not chosen by - * the token. Only the first num_active_experts_per_ rows are valid - * @param active_expert_global_ids: The global expert id for each activated expert - * Only the first num_active_experts_per_ values are valid - * @param expert_first_token_offset: Store the first token offset for each expert - */ -template -__device__ __forceinline__ void initTensor(T* value, int const tid, int const total_num, T const init_value) { - for (int i = tid; i < total_num; i += BLOCK_SIZE) { - value[i] = init_value; - } -} - -template -__device__ __forceinline__ void setLocalExperts(int* s_local_experts, T const* token_selected_experts, - int const total_num_experts, int const tid, int const start_expert, int const end_expert) { - for (int i = tid; i < total_num_experts; i += BLOCK_SIZE) { - int const expert = token_selected_experts[i]; - - // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node) - bool is_valid_expert = expert >= start_expert && expert < end_expert; - if (is_valid_expert) { - int local_expert_id = expert - start_expert; - if (s_local_experts[local_expert_id] == 0) { - s_local_experts[local_expert_id] = 1; // @TODO: Make sure that we allow duplicated write here - } - } - } - __syncthreads(); -} - -template -__device__ __forceinline__ void prefixSum(T* out, T* in, int const num, int const tid) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage tempStorage; - - T threadData = 0; - if (tid < num) { - threadData = in[tid]; - } - - BlockScan(tempStorage).InclusiveSum(threadData, threadData); - __syncthreads(); - - if (tid < num) { - out[tid] = threadData; - } - __syncthreads(); -} - -__device__ __forceinline__ void setActiveNum(int& num_active, int& num_active_offset_start, int& num_active_offset_end, - int const cluster_size, int const cluster_rank) { - int num_remainder = num_active % cluster_size; - int num_active_per_node = max(0, num_active - 1) / cluster_size; // num_active_per_node shouldn't be neg - if (cluster_rank < num_remainder) { - num_active = num_active_per_node + 1; - num_active_offset_start = cluster_rank * num_active; - } else { - num_active = num_active_per_node; - num_active_offset_start = cluster_rank * num_active_per_node + num_remainder; - } - num_active_offset_end = num_active_offset_start + num_active; -} - -template -__global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_per_node, float* experts_to_token_scores, - int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts, - float const* token_final_scales, int64_t const num_tokens, int const num_experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node, bool const smart_routing, int const cluster_rank, - int const cluster_size, int const num_experts_smem) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - // Use one block to process the min latency case - int tid = threadIdx.x; - // 0. init the global memory experts_to_token_scores [num_experts_per_node, num_token] - int const total_local_scales = num_experts_per_node * num_tokens; - initTensor(experts_to_token_scores, tid, total_local_scales, 0.0f); - initTensor(active_expert_global_ids, tid, num_experts_per_node, -1); - - __threadfence(); //@Todo: check do I need this fence for previous zero setting - - // 1. mask for the active expert: 1 stands for active - extern __shared__ int s_local_experts[]; - int* s_store_experts = s_local_experts + num_experts_smem; - initTensor(s_local_experts, tid, num_experts_smem, 0); - __syncthreads(); - - // 2. set the shared array s_local_experts[] - int const total_num_experts = num_tokens * num_experts_per_token; - setLocalExperts( - s_local_experts, token_selected_experts, total_num_experts, tid, start_expert, end_expert); - - // 3. perform prefix sum to acquire the store position and total active experts - //@TODO: Use cub first, might need to change it to self-defined api - prefixSum(s_store_experts, s_local_experts, num_experts_smem, tid); - - // 4. store the num of active experts - int num_active = s_store_experts[num_experts_smem - 1]; - int num_active_offset_start = 0; - int num_active_offset_end = 0; - - if (smart_routing) { - setActiveNum(num_active, num_active_offset_start, num_active_offset_end, cluster_size, cluster_rank); - } - - if (tid == 0) { - *num_active_experts_per_node = num_active; - } - - // 5. store the global expert id for each expert - if (smart_routing) { - for (int i = tid; i < num_experts_smem; i += BLOCK_SIZE) { - if (s_local_experts[i]) { - int offset = s_store_experts[i] - 1; - if (offset >= num_active_offset_start && offset < num_active_offset_end) { - active_expert_global_ids[offset - num_active_offset_start] = i; - } else { - s_local_experts[i] = 0; - } - } - } - __syncthreads(); // Need sync to update the s_local_experts - } else { - for (int i = tid; i < num_experts_smem; i += BLOCK_SIZE) { - if (s_local_experts[i]) { - int offset = s_store_experts[i] - 1; - active_expert_global_ids[offset] = i + start_expert; - } - } - } - - // 6. store the scale values - __threadfence(); //@Todo: check do I need this fence for previous zero setting - for (int i = tid; i < total_num_experts; i += BLOCK_SIZE) { - int const expert = token_selected_experts[i]; - - // If expert is not in the current node, set it to num_experts_per_node - // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node) - bool is_valid_expert = smart_routing ? s_local_experts[expert] : (expert >= start_expert && expert < end_expert); - - if (is_valid_expert) { - int token = i / num_experts_per_token; - float const scale = token_final_scales[i]; - int offset = s_store_experts[expert - start_expert] - 1 - num_active_offset_start; - experts_to_token_scores[offset * num_tokens + token] = scale; - } - } - // 7. set default value for redundant memory - for (int i_exp = num_active + tid; i_exp < num_experts_per_node; i_exp += BLOCK_SIZE) { - active_expert_global_ids[i_exp] = -1; - } - // 8. set expert_first_token_offset - for (int i_exp = tid; i_exp < num_experts_per_node + 1; i_exp += BLOCK_SIZE) { - if (i_exp < num_active) { - expert_first_token_offset[i_exp] = i_exp * num_tokens; - } else { - expert_first_token_offset[i_exp] = num_active * num_tokens; - } - } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} - -void buildMinLatencyActiveExpertMaps(int* num_active_experts_per_node, float* experts_to_token_scores, - int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts, - float const* token_final_scales, int64_t const num_tokens, int const experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node, int const cluster_rank, int const cluster_size, - int const num_experts_smem, cudaStream_t const stream) { - ORT_ENFORCE(num_experts_per_node == (end_expert - start_expert), - "num_experts_per_node must be equal to end_expert - start_expert"); - - ORT_ENFORCE(num_experts_per_node <= 256, "don't support num_experts_per_node > 256 cases"); - - int const threads = 256; - int const blocks = 1; - bool const smart_routing = cluster_size > 1; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = num_experts_smem * sizeof(int) * 2; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = onnxruntime::llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, buildMinLatencyActiveExpertMapsKernel, num_active_experts_per_node, - experts_to_token_scores, active_expert_global_ids, expert_first_token_offset, token_selected_experts, - token_final_scales, num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node, - smart_routing, cluster_rank, cluster_size, num_experts_smem); -} - template __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_selected_experts, int* const permuted_row_to_unpermuted_row, int* const unpermuted_row_to_permuted_row, @@ -1168,123 +970,6 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir #endif } -template -__global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int64_t const num_experts_per_node, T const* in1, T const* in2, - WeightType const* weights1, WeightType const* weights2, float const* alpha_scale_flat1, - float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* output1, OutputType* output2, - int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert) { - // First, compute the global tid. We only need 1 thread per expert. - int const expert = blockIdx.x * blockDim.x + threadIdx.x; - - if (expert >= num_experts_per_node) { - return; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - - // Note: expert is used to calculate the offset of the input and output - // local_expert is used to calculate the offset of the weight - auto const num_tokens_before_expert = expert * num_tokens; - bool const is_active_expert = expert < *num_active_experts_per; - int const local_expert = is_active_expert ? active_expert_global_ids[expert] - start_expert : -1; - auto const gemm_m = is_active_expert ? num_tokens : 0; - - // M and N transposed since we are using the #tokens as the N dimension - layout_info1.shape_info.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm1_n, gemm_m, gemm1_k); - layout_info2.shape_info.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm2_n, gemm_m, gemm2_k); - - if (alpha_scale_flat1) { - assert(alpha_scale_flat2); - if (is_active_expert) { - layout_info1.alpha_scale_ptr_array[expert] = alpha_scale_flat1 + local_expert; - layout_info2.alpha_scale_ptr_array[expert] = alpha_scale_flat2 + local_expert; - } else { - layout_info1.alpha_scale_ptr_array[expert] = nullptr; - layout_info2.alpha_scale_ptr_array[expert] = nullptr; - } - } - - if (quant_params.fp4.fc1.weight_block_scale) { - setupFP4BlockScalingFactors(layout_info1, expert, - gemm_m, gemm1_n, gemm1_k, fp4_act_flat1, quant_params.fp4.fc1.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc1 uses the same A input for all experts and the scaling factor B offsets from - // the local expert index - if (is_active_expert) { - layout_info1.fpX_block_scaling_factors_A[expert] = fp4_act_flat1; - layout_info1.fpX_block_scaling_factors_B[expert] = quant_params.fp4.fc1.weight_block_scale + getOffsetWeightSF( - local_expert, gemm1_n, gemm1_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info1.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info1.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - - if (quant_params.fp4.fc2.weight_block_scale) { - setupFP4BlockScalingFactors(layout_info2, expert, - gemm_m, gemm2_n, gemm2_k, fp4_act_flat2, quant_params.fp4.fc2.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc2 scaling factor B offsets by the local expert index - if (is_active_expert) { - layout_info2.fpX_block_scaling_factors_B[expert] = quant_params.fp4.fc2.weight_block_scale + getOffsetWeightSF( - local_expert, gemm2_n, gemm2_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info2.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info2.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif - - assert(gemm_m <= INT32_MAX); - assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); - assert(gemm1_k > 0 && gemm1_k <= INT32_MAX); - assert(gemm2_n > 0 && gemm2_n <= INT32_MAX); - assert(gemm2_k > 0 && gemm2_k <= INT32_MAX); - computeTmaWarpSpecializedInputStrides(layout_info1, gemm_m, gemm1_n, gemm1_k, expert, - cutlass::gemm::collective::detail::int4_group_size); - computeTmaWarpSpecializedInputStrides(layout_info2, gemm_m, gemm2_n, gemm2_k, expert, - cutlass::gemm::collective::detail::int4_group_size); - - if (is_active_expert) { - // Note: under low latency mode, we use the same input for all experts - // so for gemm1, the inputs are the same, - // for gemm2, we use the input generated by gemm1 - layout_info1.ptr_a[expert] = in1; - layout_info2.ptr_a[expert] = safe_inc_ptr(in2, expert * num_tokens * gemm2_k); - - // Each expert's weight matrix is a constant size NxK, get the matrix at index `expert` - layout_info1.ptr_b[expert] = safe_inc_ptr(weights1, local_expert * (gemm1_n * gemm2_k)); - layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k)); - - assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); - layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n); - - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - // The output prior to this contains N elements per token, with `num_tokens` tokens - layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n); - } - } else { - layout_info1.ptr_a[expert] = nullptr; - layout_info2.ptr_a[expert] = nullptr; - layout_info1.ptr_b[expert] = nullptr; - layout_info2.ptr_b[expert] = nullptr; - - layout_info1.default_epilogue.ptr_d[expert] = nullptr; - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info2.default_epilogue.ptr_d[expert] = nullptr; - } - } -} - // ========================== Permutation things ======================================= // Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. @@ -2886,68 +2571,6 @@ CutlassMoeFCRunner: return std::make_pair(layout_info1, layout_info2); } -template -std::pair -CutlassMoeFCRunner::computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, T const* input1, T const* input2, - WeightType const* weights1, WeightType const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* output1, - UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert, cudaStream_t stream) { - ORT_ENFORCE(!use_w4afp8, "W4AFP8 is not supported in low latency mode"); - - // Always nullptr - layout_info1.ptr_c = nullptr; - layout_info1.stride_c = nullptr; - layout_info2.ptr_c = nullptr; - layout_info2.stride_c = nullptr; - - auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale - : use_wfp4afp8 ? (quant_params.fp8_mxfp4.fc1.global_scale - ? quant_params.fp8_mxfp4.fc1.global_scale - : quant_params.mxfp8_mxfp4.fc1.global_scale) - : fp8_dequant1; - auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale - : use_wfp4afp8 ? (quant_params.fp8_mxfp4.fc2.global_scale - ? quant_params.fp8_mxfp4.fc2.global_scale - : quant_params.mxfp8_mxfp4.fc2.global_scale) - : fp8_dequant2; - if (!alpha_scale_flat1) { - layout_info1.alpha_scale_ptr_array = nullptr; - } - if (!alpha_scale_flat2) { - layout_info2.alpha_scale_ptr_array = nullptr; - } - - layout_info1.int4_groupwise_params.enabled = false; - layout_info2.int4_groupwise_params.enabled = false; - - int const threads = std::min(1024, num_experts); - int const blocks = (num_experts + threads - 1) / threads; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = onnxruntime::llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, - computeStridesTmaWarpSpecializedLowLatencyKernel, layout_info1, - layout_info2, num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts, input1, input2, weights1, weights2, - alpha_scale_flat1, alpha_scale_flat2, fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, bias1, bias2, output1, - output2, num_active_experts_per, active_expert_global_ids, start_expert); - - return std::make_pair(layout_info1, layout_info2); -} - template std::pair CutlassMoeFCRunner::setupTmaWarpSpecializedInputs( @@ -3083,78 +2706,6 @@ __global__ void populateRandomBufferKernel(void* buffer_void, size_t size) { buffer[tid * elem_per_thread + i] = curand4(&state); } -template -__global__ void prepareMinLatencyBuffer(int* num_active_experts_per_node, int* active_expert_global_ids, - int64_t* expert_first_token_offset, int const num_tokens, int const num_experts_per_token, - int const num_experts_per_node) { - int tid = threadIdx.x; - int bid = blockIdx.x; - - // 0. set offset - num_active_experts_per_node += bid; - active_expert_global_ids += bid * num_experts_per_node; - expert_first_token_offset += bid * (num_experts_per_node + 1); - - // 1. set the num_active_experts_per_node - int num_active = max(1, (int)(bid * ((float)num_experts_per_node / NUM_ROUTING_SAMPLES))); - *num_active_experts_per_node = num_active; - - // 2. generate random active experts - extern __shared__ float s_buf[]; - float* expert_refs = s_buf; - int* expert_refs_idx = reinterpret_cast(expert_refs + num_experts_per_node); - - curandState_t local_state; - curand_init(bid, tid, 0, &local_state); - for (int i = tid; i < num_experts_per_node; i += BLOCK_SIZE) { - expert_refs[i] = (float)curand_uniform(&local_state); - expert_refs_idx[i] = (int)i; - } - __syncthreads(); - - float thread_key[1]; - int thread_value[1]; - thread_key[0] = std::numeric_limits::max(); - thread_value[0] = num_experts_per_node; - - if (tid < num_experts_per_node) { - thread_key[0] = expert_refs[tid]; - thread_value[0] = expert_refs_idx[tid]; - } - - using BlockRadixSort = cub::BlockRadixSort; - using BlockRadixSortValue = cub::BlockRadixSort; - - union TempStorage { - typename BlockRadixSort::TempStorage key_value; - typename BlockRadixSortValue::TempStorage value; - }; - __shared__ union TempStorage temp_storage; - - BlockRadixSort(temp_storage.key_value).Sort(thread_key, thread_value); - __syncthreads(); - - if (tid > num_active) { - thread_value[0] = std::numeric_limits::max(); - } - BlockRadixSortValue(temp_storage.value).Sort(thread_value); - __syncthreads(); - - // 3. set the active_expert_global_ids and expert_first_token_offset - for (int i = tid; i < num_experts_per_node; i += BLOCK_SIZE) { - if (i < num_active) { - active_expert_global_ids[i] = thread_value[0]; - expert_first_token_offset[i] = i * num_tokens; - } else { - active_expert_global_ids[i] = -1; - expert_first_token_offset[i] = num_active * num_tokens; - } - } - if (tid == 0) { - expert_first_token_offset[num_experts_per_node] = num_active * num_tokens; - } -} - void populateRandomBuffer(void* buffer_void, size_t size, cudaStream_t stream) { // Each thread initialises 128 bytes ORT_ENFORCE(size % 128 == 0, "Unexpected size alignment"); @@ -3292,10 +2843,6 @@ std::map> GemmProfilerBackend::getProfile size_t const blocked_expert_counts_cumsum_size = blocked_expert_counts_size; size_t const blocked_row_to_unpermuted_row_size = num_experts_per_node * maxM * sizeof(int); - // The follow buffers are used in min_latency_mode - size_t num_active_experts_per_node_size = 0; - size_t active_expert_global_ids_size = 0; - size_t map_offset = 0; std::map> out_map; @@ -3316,8 +2863,6 @@ std::map> GemmProfilerBackend::getProfile ADD(blocked_expert_counts_cumsum); ADD(blocked_row_to_unpermuted_row); ADD(token_topk_unpermuted_scales); - ADD(num_active_experts_per_node); - ADD(active_expert_global_ids); ADD(input); ADD(output); ADD(intermediate); @@ -3358,8 +2903,6 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha GET_WS_PTR(int*, blocked_expert_counts); GET_WS_PTR(int*, blocked_expert_counts_cumsum); GET_WS_PTR(int*, blocked_row_to_unpermuted_row); - GET_WS_PTR(int*, num_active_experts_per_node); - GET_WS_PTR(int*, active_expert_global_ids); #undef GET_WS_PTR_BASE #undef GET_WS_PTR @@ -3473,8 +3016,6 @@ void GemmProfilerBackend::prepareTmaWsInputs( GET_WS_PTR(void*, gemm_workspace); GET_WS_PTR(float*, alpha_scale_ptr_array); GET_WS_PTR(TmaWarpSpecializedGroupedGemmInput::ElementSF*, fp4_act_scale_flat); - GET_WS_PTR(int*, num_active_experts_per_node); - GET_WS_PTR(int*, active_expert_global_ids); #undef GET_WS_PTR @@ -3573,8 +3114,6 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac GET_WS_PTR(float const*, token_topk_unpermuted_scales); auto const* token_topk_permuted_scales = token_topk_unpermuted_scales; - GET_WS_PTR_OFFSET(int*, num_active_experts_per_node, mSampleIndex); - GET_WS_PTR_OFFSET(int*, active_expert_global_ids, (mSampleIndex * mNumExpertsPerNode)); GET_WS_PTR(void const*, input); GET_WS_PTR(void*, output); GET_WS_PTR(void*, intermediate); diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h index 9f780c313e1fd..bf9ab2e684c5a 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h @@ -305,16 +305,6 @@ class CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) = 0; - virtual std::pair - computeStridesTmaWarpSpecializedLowLatencyDispatch(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, void const* input1, void const* input2, - void const* weights1, void const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, - int const* active_expert_global_ids, int start_expert, cudaStream_t stream) = 0; - virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; @@ -521,25 +511,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { reinterpret_cast(gemm2_output), stream); } - std::pair - computeStridesTmaWarpSpecializedLowLatencyDispatch(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, void const* input1, void const* input2, - void const* weights1, void const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, - int const* active_expert_global_ids, int start_expert, cudaStream_t stream) override { - return Self::computeStridesTmaWarpSpecializedLowLatency(layout_info1, layout_info2, num_tokens, gemm1_n, - gemm1_k, gemm2_n, gemm2_k, num_experts, reinterpret_cast(input1), - reinterpret_cast(input2), reinterpret_cast(weights1), - reinterpret_cast(weights2), fp8_dequant1, fp8_dequant2, fc1_fp4_act_flat, - fc2_fp4_act_flat, quant_params, reinterpret_cast(bias1), - reinterpret_cast(bias2), reinterpret_cast(output1), - reinterpret_cast(output2), num_active_experts_per, active_expert_global_ids, - start_expert, stream); - } - private: std::pair setupTmaWarpSpecializedInputs( int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, bool use_fused_gated_activation, @@ -560,16 +531,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); - static std::pair - computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, T const* input1, T const* input2, - WeightType const* weights1, WeightType const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* output1, - UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert, cudaStream_t stream); std::map> getWorkspaceDeviceBufferSizes(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type,