From b77b31239621c45fe04856e8fcfad396bf84be85 Mon Sep 17 00:00:00 2001 From: adrastogi Date: Wed, 24 Jun 2026 14:47:00 -0700 Subject: [PATCH 01/10] Relax CompileModel validation to accept zero-input OrtModel graphs (#28771) ### Description Relax the input-validation in OrtApi::CompileModel to accept OrtModel instances with zero graph inputs. Previously, ModelCompilationOptions::Check() rejected such models with "OrtModel graph must have at least one input and one output defined." The check now requires only at least one graph output; the zero-input case is legal. Tests in test_model_builder_api.cc are restructured: - The old CompileFromModelWithEmptyInputsOutputs_Fails is renamed to CompileFromModelWithEmptyOutputs_Fails and reshaped to provide 1 input + 0 outputs, isolating the output-only check. - A new regression test CompileFromModelWithEmptyInputs_Succeeds builds a 0-input model with a RandomNormal node and verifies compilation succeeds. ### Motivation and Context Fixes #28135 The original check was too restrictive and impacts callers (e.g., WebNN/Chromium needs to call CompileModel on such models in a separate compiler process (and then load the compiled artifact via CreateSessionFromArray in the GPU process)). --- .../core/session/model_compilation_options.cc | 7 +- .../test/shared_lib/test_model_builder_api.cc | 83 +++++++++++++++++-- 2 files changed, 82 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index a393bb42fe2cb..a17802bdd7573 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -283,9 +283,12 @@ Status ModelCompilationOptions::Check() const { "OrtModel has no graph. Call AddGraphToModel before compilation."); } - if (input_model_->graph->GetNumInputs() == 0 || input_model_->graph->GetNumOutputs() == 0) { + // A model with zero graph inputs is legal (e.g., a graph composed of zero-input + // generator ops like RandomNormal that produces output without external input). + // We still require at least one graph output for the compiled model to be meaningful. + if (input_model_->graph->GetNumOutputs() == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "OrtModel graph must have at least one input and one output defined."); + "OrtModel graph must have at least one output defined."); } if (input_model_->domain_to_version.empty()) { diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index a1f2f9102f027..91fc61f19e0df 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -887,11 +887,20 @@ TEST(ModelEditorCompileAPITest, CompileFromModelWithNoGraph_Fails) { EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("graph")); } -// Test validation: model with empty inputs/outputs -TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { - // Create a model with a graph that has no inputs or outputs +// Test validation: model with no outputs (one input but zero outputs). +// 0 outputs is still rejected because compilation produces an output model that +// would have no consumers for any computed values. +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyOutputs_Fails) { Ort::Graph graph; - // Don't set inputs or outputs + + // Provide a single input but no outputs, to isolate the output-count check. + std::vector graph_inputs; + std::vector dims({4}); + TensorTypeAndShapeInfo tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims); + auto type_info = TypeInfo::CreateTensorInfo(tensor_info.GetConst()); + graph_inputs.emplace_back("X", type_info.GetConst()); + graph.SetInputs(graph_inputs); + // Intentionally do not call SetOutputs. std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; Model model(opsets); @@ -909,8 +918,70 @@ TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputsOutputs_Fails) { compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with empty inputs/outputs"; - EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("input")); + EXPECT_FALSE(status.IsOK()) << "Expected CompileModel to fail for model with no outputs"; + EXPECT_THAT(status.GetErrorMessage(), ::testing::HasSubstr("at least one output")); +} + +// Test: model with zero graph inputs is now accepted by CompileModel. +// Mirrors what CreateSessionFromModel already accepts (e.g., a graph composed of +// zero-input generator ops like RandomNormal that produces output without external input). +// Regression test for https://github.com/microsoft/onnxruntime/issues/28135. +// +// Scope: this test exercises only the ORT-side validation in ModelCompilationOptions::Check(). +// EP-specific validation (e.g., whether the WebNN EP's partitioner accepts a 0-input subgraph) +// is owned by the respective EP and is not covered here. +TEST(ModelEditorCompileAPITest, CompileFromModelWithEmptyInputs_Succeeds) { + Ort::Graph graph; + + // Zero graph inputs; one graph output produced by a RandomNormal node. + // RandomNormal takes 0 inputs and produces a tensor with shape specified via attribute. + // Use RandomNormal rather than Constant because Constant nodes are folded into initializers + // at load time (see graph.cc Graph::LoadFromModelEditorApiModel) and would not exercise + // the true 0-input producer path. + std::vector output_dims = {2, 3}; + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + + std::vector graph_outputs; + graph_outputs.emplace_back("Y", output_type_info.GetConst()); + graph.SetOutputs(graph_outputs); + // Intentionally do not call SetInputs (zero graph inputs). + + std::vector attributes; + std::vector shape_attr_value = {2, 3}; + attributes.push_back(OpAttr("shape", shape_attr_value.data(), + static_cast(shape_attr_value.size()), + OrtOpAttrType::ORT_OP_ATTR_INTS)); + + Node node("RandomNormal", onnxruntime::kOnnxDomain, "RandomNormal1", + /*input_names*/ {}, /*output_names*/ {"Y"}, attributes); + graph.AddNode(node); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + Model model(opsets); + model.AddGraph(graph); + + Ort::SessionOptions session_options; + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + + compile_options.SetInputModel(static_cast(model)); + compile_options.SetEpContextEmbedMode(true); + + std::unique_ptr allocator = std::make_unique(); + void* output_buffer = nullptr; + size_t output_size = 0; + compile_options.SetOutputModelBuffer(allocator.get(), &output_buffer, &output_size); + + // Compile should succeed. No specific compiling EP is required; the default + // kGenerateModel action emits an output model even when no EPContext nodes are produced. + ASSERT_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + + // Fatal: a compile that returns OK but produces no artifact bytes is a regression. + ASSERT_NE(output_buffer, nullptr); + ASSERT_GT(output_size, 0u); + + allocator->Free(output_buffer); } // Test: model can be reused after compilation. From eade9ea991bb1cc52864bd27548d93fc2b1460be Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Wed, 24 Jun 2026 15:08:38 -0700 Subject: [PATCH 02/10] Clamp derived sequence lengths and KV-cache index in CUDA GroupQueryAttention (#29240) ### Description The CUDA `GroupQueryAttention` kernel derives a KV-cache append offset from the `seqlens_k` input (`past_seq_lens = (seqlens_k + 1) - sequence_length`). On the CUDA EP `seqlens_k` is device-resident (only `total_sequence_length` is a CPU input), so the host-side range validation in the operator/helper is skipped. The device kernel `UnpackRoPEAppend` then guarded the cache store with only a one-sided upper bound (`cache_s < max_seqlen`), so an out-of-range `seqlens_k` could produce a negative offset that is sign-extended into the cache-index arithmetic. The CPU operator already validates `seqlens_k` host-side; this change brings the CUDA path to parity by guarding on the device. ### Changes - `group_query_attention_impl.cu` (`GetSequenceLengths`): clamp the negative case at the source so both `total_seq_lens` and the append offset `past_seq_lens` stay non-negative for all downstream consumers. - `group_query_attention_qkv.cuh` (`UnpackRoPEAppend`): make the KV-cache store bound two-sided (`cache_s >= 0 && cache_s < max_seqlen`), mirroring the existing position-index guard a few lines above. This also covers the fast-decode path, where `past_seq_lens` points directly at the raw input and bypasses `GetSequenceLengths`. - Added `NegativeSeqlensK_CacheAppend_NoOOB_CUDA` regression test exercising the KV-cache append path with an out-of-range `seqlens_k` (CUDA-guarded; skips when CUDA EP is unavailable). ### Notes - The two-sided guard matches the pattern introduced for the rotary position index in #27597. - CPU is unaffected (already validated host-side); WebGPU relies on the CPU-validated `total_sequence_length`. The CUDA implementation is shared with ROCm via hipify. - The regression is a device-memory write best observed under `compute-sanitizer`; the test asserts the run completes with finite outputs. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cuda/bert/group_query_attention_impl.cu | 9 +- .../cuda/bert/group_query_attention_qkv.cuh | 4 +- .../group_query_attention_op_test.cc | 103 ++++++++++++++++++ 3 files changed, 113 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 6a2d95f089f2b..aa84ee05c2bd1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -641,13 +641,18 @@ __global__ void GetSequenceLengths(const int* total_seq_lens_minus_one, const bool is_first_prompt) { int i = threadIdx.x + blockIdx.x * blockDim.x; if (i < batch_size) { - const int total_len = total_seq_lens_minus_one[i] + 1; + // total_seq_lens_minus_one is the seqlens_k input and is not range-checked on the device. + // Clamp the negative case at the source so the derived lengths below stay non-negative and + // cannot flow as negative offsets into KV-cache or attention index computations. + const int seqlens_k = total_seq_lens_minus_one[i]; + const int total_len = (seqlens_k > 0 ? seqlens_k : 0) + 1; total_seq_lens[i] = total_len; if (is_first_prompt) { past_seq_lens[i] = 0; padded_seq_lens[i] = sequence_length; } else { - past_seq_lens[i] = total_len - sequence_length; + const int past_len = total_len - sequence_length; + past_seq_lens[i] = past_len > 0 ? past_len : 0; padded_seq_lens[i] = 0; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 0c62aef11d53a..1a5dc6e0fca6b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -231,7 +231,9 @@ __global__ void UnpackRoPEAppend( } else { // Store K or V into the KV cache at index (past_seqlen + s) const int cache_s = past_seq_lens[b] + s; - if (cache_s < max_seqlen) { + // Two-sided bound: the lower check mirrors the position guard above and prevents a + // negative offset from being sign-extended into the cache index arithmetic below. + if (cache_s >= 0 && cache_s < max_seqlen) { void* cache_ptr = (head_type == KEY) ? k_cache : v_cache; if (cache_ptr != nullptr) { int64_t cache_idx; diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 645564d01abc0..c6f6e0affc163 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -1569,6 +1569,109 @@ TEST(GroupQueryAttentionTest, SharedKV_EmptyKV_WithPast_Rotary_Prompt_CUDA) { ExpectOutputsMatch(cuda_output, cpu_output, 0.05f, "SharedKV_Rotary_Prompt_CUDA_vs_CPU"); } +// CUDA: out-of-range (negative) seqlens_k must not drive an out-of-bounds KV-cache write. +// On the CUDA EP seqlens_k is device-resident, so the host-side range check in the operator is +// skipped and the derived append offset is clamped on the device instead. With sequence_length > 1 +// the non-fast-decode path is taken, exercising both the derived-length clamp and the cache-store +// bound. The run must complete and yield finite outputs. This is a memory-safety regression that is +// most precisely observed under compute-sanitizer, where the pre-clamp code reported an invalid +// device write at this site. +TEST(GroupQueryAttentionTest, NegativeSeqlensK_CacheAppend_NoOOB_CUDA) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + + constexpr int batch_size = 1; + constexpr int sequence_length = 2; // > 1 forces the non-fast-decode path + constexpr int past_seq_len = 4; + constexpr int num_heads = 2; + constexpr int kv_num_heads = 1; + constexpr int head_size = 16; // must be a multiple of 16 for rotary + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int total_sequence_length = past_seq_len + sequence_length; + constexpr int present_seq_len = total_sequence_length; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + std::vector query_data(batch_size * sequence_length * hidden_size); + std::vector key_data(batch_size * sequence_length * kv_hidden_size); + std::vector value_data(batch_size * sequence_length * kv_hidden_size); + std::vector past_key_data(batch_size * kv_num_heads * past_seq_len * head_size); + std::vector past_value_data(batch_size * kv_num_heads * past_seq_len * head_size); + for (size_t i = 0; i < query_data.size(); ++i) query_data[i] = 0.05f * static_cast(i % 7 + 1); + for (size_t i = 0; i < key_data.size(); ++i) key_data[i] = 0.04f * static_cast(i % 5 + 1); + for (size_t i = 0; i < value_data.size(); ++i) value_data[i] = 0.03f * static_cast(i % 3 + 1); + for (size_t i = 0; i < past_key_data.size(); ++i) past_key_data[i] = 0.02f * static_cast(i % 11 + 1); + for (size_t i = 0; i < past_value_data.size(); ++i) past_value_data[i] = 0.01f * static_cast(i % 13 + 1); + + tester.AddInput("query", {batch_size, sequence_length, hidden_size}, ToFloat16(query_data)); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, ToFloat16(key_data)); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, ToFloat16(value_data)); + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, head_size}, ToFloat16(past_key_data)); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, head_size}, + ToFloat16(past_value_data)); + + // seqlens_k is negative, so the derived past length, (max(seqlens_k, 0) + 1) - sequence_length, is + // negative (here 0 + 1 - 2 = -1). The device-side derivation must neutralize this so the cache append + // for the new tokens stays within the present buffer instead of indexing before its start. + tester.AddInput("seqlens_k", {batch_size}, {-1}); + // Marked as an initializer so shape inference can read the value at graph-build time and size + // present_kv to max(past_seq_len, total_sequence_length), matching the declared present outputs below. + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true); + + const int max_seq_len = total_sequence_length + 8; + const int half_rotary = head_size / 2; + std::vector cos_cache(max_seq_len * half_rotary); + std::vector sin_cache(max_seq_len * half_rotary); + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int d = 0; d < half_rotary; ++d) { + const float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / static_cast(head_size)); + cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq); + sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq); + } + } + tester.AddInput("cos_cache", {max_seq_len, half_rotary}, ToFloat16(cos_cache)); + tester.AddInput("sin_cache", {max_seq_len, half_rotary}, ToFloat16(sin_cache)); + + // Valid position_ids so the rotary index path is well-formed and only the cache-store bound is stressed. + std::vector position_ids(batch_size * sequence_length); + for (int s = 0; s < sequence_length; ++s) { + position_ids[s] = static_cast(past_seq_len + s); + } + tester.AddInput("position_ids", {batch_size, sequence_length}, position_ids); + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, MLFloat16(0.0f))); + const int present_size = batch_size * kv_num_heads * present_seq_len * head_size; + tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(present_size, MLFloat16(0.0f))); + tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, head_size}, + std::vector(present_size, MLFloat16(0.0f))); + + // The malformed seqlens_k drives the derived past length negative, which is the condition under test. + // That leaves the KV length under-specified for the query, so the attention is degenerate and its + // outputs may be non-finite; this is expected and intentionally not asserted. The regression point is + // that the cache append and attention complete without indexing outside their buffers (which a + // sanitizer build would otherwise flag), so only the output shape is verified. + tester.SetOutputTolerance(1e6f); + tester.SetCustomOutputVerifier([](const std::vector& fetches, + const std::string& /*provider*/) { + ASSERT_FALSE(fetches.empty()); + ASSERT_TRUE(fetches[0].IsTensor()); + EXPECT_EQ(fetches[0].Get().Shape().Size(), static_cast(output_size)); + }); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // --------------------------------------------------------------------------- // Quantized KV cache tests for CPU GroupQueryAttention // --------------------------------------------------------------------------- From 996cea1e829b9f577fa8ffa37045d26f053b0960 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 24 Jun 2026 15:55:42 -0700 Subject: [PATCH 03/10] Add flash attention for non-quantized CPU GroupQueryAttention (#28962) ## Summary Adds an FP32 flash attention path for the CPU `com.microsoft.GroupQueryAttention` (GQA) contrib op, mirroring the existing quantized-KV flash attention path. The new tiled, online-softmax kernel avoids materializing the full `[S, T]` attention score matrix. It is restricted to prefill / chunked-prefill (`sequence_length > 1`); single-token decode falls back to the naive path. With causal early-termination it is faster than the naive path across all measured prefill lengths while using a fraction of the memory. ## Key changes - **New MLAS kernel** `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (`MlasFlashAttentionGQA`): - Tiled QK / softmax / SV with online-softmax (running max/sum rescaling). - GQA head grouping (`num_heads % kv_num_heads == 0`), causal masking, local window, additive attention bias, and packed-QKV input. - **Causal early-termination**: during prefill, KV blocks that fall entirely in the causally masked upper triangle are skipped (`break` once `ir >= past_seqlen + q_idx + row_size_q`), avoiding the wasted QK/SV GEMMs over roughly half of the square prefill attention matrix. - Per-batch invocation for ragged / shared-buffer `seqlens_k`. - **MLAS API** `onnxruntime/core/mlas/inc/mlas.h`: new `MlasFlashAttentionGQAArgs` struct and `MlasFlashAttentionGQA` declaration. - **Dispatch** `onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h`: new `ApplyAttentionFlash` that concatenates new K/V into the FP32 present cache and invokes the kernel. The per-thread scratch buffer size is computed with `SafeInt` to guard against `size_t` overflow on large/malformed shapes before allocation. - **Wiring** `onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc`: float-only flash dispatch, active only for prefill (`sequence_length > 1`) and when `softcap == 0`, no smooth softmax, no head sink, no QK output; falls back to the naive path otherwise. The existing `ORT_GQA_DISABLE_FLASH_ATTENTION` env var disables it. - **CMake** `cmake/onnxruntime_mlas.cmake`: register the new source file. - **Docs** `docs/contrib_ops/cpu/gqa.md`: document the non-quantized flash attention path, activation conditions, causal early-termination, file list, and FP32 flash-vs-naive benchmark results. - **Benchmark** `onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py`: add an FP32 (non-quantized) mode (`--fp32`) for operator-level flash-vs-naive comparison. ### Why prefill-only (`sequence_length > 1`) Single-token decode (`sequence_length == 1`) produces only a `[1, total_sequence_length]` score row per head, so there is nothing to tile away and the extra online-softmax bookkeeping makes the flash kernel slower and noisier than naive in practice. Restricting the flash path to prefill keeps the consistent prefill win without regressing decode. Because decode is excluded, the two-phase flash-decoding kernels are unreachable and have been removed for a smaller, simpler implementation. `float16` continues to use the naive path (the kernel is float-only, matching the quantized flash constraint). ## Performance Operator-level, AMD EPYC 7763 (16 physical cores), threads=8, FP32 KV cache, `B=1, num_heads=16, kv_num_heads=8, head_size=128`. Flash is faster than naive across all measured prefill lengths (and single-threaded as well, 1.4-1.8x), confirming the gain is algorithmic - the causal early-termination removes the wasted upper-triangle work that previously made flash slower than naive at short sequences. | Prefill Seq Length | Naive (ms) | Flash (ms) | Speedup | |---:|---:|---:|---:| | 512 | 5.8-8.4 | 4.2-5.3 | 1.4-1.6x | | 1024 | 25-29 | 13-18 | 1.6-2.0x | | 2048 | 87-118 | 52-65 | 1.5-2.0x | | 4096 | 365-380 | 213-234 | 1.6-1.7x | The flash path's primary structural benefit is memory: it never allocates the full O(N x S x T) attention matrix (~1 GB at S=4096, N=16) and instead uses an O(S x Bc) per-thread tile. ## Testing - **C++ op tests**: `onnxruntime_provider_test --gtest_filter="GroupQueryAttentionTest.*"` - 38 passed (12 GPU/WebGPU skipped) with flash on (default) and with `ORT_GQA_DISABLE_FLASH_ATTENTION=1`. - **Flash vs. naive parity** (FP32): output of the flash path matches the naive path (max abs diff ~1e-7) across prefill (block-aligned and non-aligned `S`), MHA and GQA head ratios, and local window. Decode now uses the naive path on both sides (diff 0). - **Python parity** (`test_gqa_cpu.py`, flash vs. naive reference): focused FP32 sweep of 600 prompt configurations covering all head sizes (32-256), GQA ratios `(6,6)/(6,3)/(9,9)/(9,3)`, batches `1/3/5`, causal/local window, attention bias, position ids, packed QKV, and with/without KV buffer - all passed. The official `test_gqa_cpu.py` suite passes. Two correctness bugs were found and fixed via the parity sweep while developing this path: 1. Attention-bias batch stride ignored head broadcasting for `[batch, 1, S, T]` bias. 2. Query batch stride was hardcoded to `num_heads * S * H`, which is incorrect for packed-QKV input (correct stride is `(num_heads + 2 * kv_num_heads) * S * H`). --- cmake/onnxruntime_mlas.cmake | 1 + docs/contrib_ops/cpu/gqa.md | 101 +++++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 284 ++++++++++++++++ .../cpu/bert/group_query_attention.cc | 25 ++ onnxruntime/core/mlas/inc/mlas.h | 51 +++ onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 310 ++++++++++++++++++ .../transformers/benchmark_gqa_cpu_flash.py | 197 ++++++++--- 7 files changed, 911 insertions(+), 58 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/flashattn_gqa.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b8ab7142b6b35..d55a4d49fb455 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -56,6 +56,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/flashattn_qkv.cpp + ${MLAS_SRC_DIR}/flashattn_gqa.cpp ${MLAS_SRC_DIR}/qkv_quant.cpp ${MLAS_SRC_DIR}/cast.cpp ${MLAS_SRC_DIR}/layernorm.cpp diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index e5a211c9fd11a..8b81fdba8f1a6 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -17,7 +17,12 @@ Quantized KV-cache GEMM helpers are implemented in MLAS: - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx2.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_avx512vnni.cpp` - `onnxruntime/core/mlas/lib/qkv_quant_kernel_neon.cpp` -- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (flash attention tiled kernel) +- `onnxruntime/core/mlas/lib/flashattn_qkv.cpp` (quantized-KV flash attention tiled kernel) + +The non-quantized flash attention tiled kernel is implemented in MLAS: + +- `onnxruntime/core/mlas/lib/flashattn_gqa.cpp` (FP32-KV flash attention tiled kernel) +- `onnxruntime/core/mlas/inc/mlas.h` (`MlasFlashAttentionGQA` declaration and `MlasFlashAttentionGQAArgs`) The operator schema itself is defined in: @@ -48,12 +53,14 @@ At a high level, the CPU kernel executes GroupQueryAttention in these stages: The non-quantized and quantized paths share the surrounding validation, masking, softmax, and output flow. Their main difference is how the K/V cache is stored and read during QK and SV GEMMs. -The quantized path has two execution strategies: +Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, and online-softmax structure. The quantized path additionally provides a two-phase flash-decoding strategy for single-token decode; the non-quantized FP32 path is limited to prefill (`sequence_length > 1`) and uses the naive path for decode. + +The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). ## Supported Cache Modes @@ -144,9 +151,9 @@ For quantized V cache, the CPU path calls `MlasSVGemm` with: As with QK GEMM, the default MLAS contract preserves the FP32 left-hand operand and dequantizes only the cached V values on the fly. -## Flash Attention Path +## Quantized Flash Attention Path -The flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. +The quantized flash attention path (`MlasFlashAttentionQuantizedKV`) processes K/V in blocks with online softmax, fusing QK, masking, softmax, and SV into a single tiled loop. This avoids the O(S×T) memory allocation for the full attention matrix. ### Algorithm @@ -204,6 +211,58 @@ The partials buffer is allocated alongside the per-thread scratch in a single al - Per-thread scratch: `scores[Bc]` (one float per KV block element) - Partials: `batch × num_heads × kv_chunks × (2 + H)` floats (m, l, and partial output per chunk) +## Non-Quantized Flash Attention Path + +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, and masking structure. Unlike the quantized path, it is limited to prefill / chunked-prefill (`sequence_length > 1`); single-token decode (`sequence_length == 1`) uses the naive path, which is why there is no flash-decoding variant here. + +### Differences from the Quantized Path + +- **Cache element type**: The present K/V cache is FP32, laid out as BNSH (`[batch, kv_num_heads, seqlen_present, head_size]`). There is no quantize-on-write or dequantize-on-read step. +- **QK GEMM**: Uses the single-threaded SGEMM primitive `MlasSgemmOperation(CblasNoTrans, CblasTrans, ...)` on an FP32 K block instead of `MlasQKGemm`. +- **SV accumulate**: Uses `MlasSgemmOperation(CblasNoTrans, CblasNoTrans, ..., beta)` with `beta = 0` for the first KV block and `beta = 1` afterwards (accumulate) instead of `MlasSVGemm`. +- **Cache concat**: New K/V tokens are appended into the FP32 present cache with `ConcatStateChunkGQA` before the tiled loop runs. + +### Algorithm + +For each (batch, head, q_block) tile: + +1. **QK GEMM** — `MlasSgemmOperation` of the query tile against a block slice of the FP32 K cache (Bc rows at a time) +1b. **Attention bias** — Add the corresponding tile of the bias tensor (if present) to QK scores +2. **Causal + local window masking** — Set masked positions to −∞ before softmax +3. **Online softmax** — Track running max `m` and sum `l`, rescale accumulated output with `exp(m_old − m_new)` +4. **SV accumulate** — `MlasSgemmOperation(..., beta)` accumulates `softmax(QK_block) × V_block` into the output tile +5. **Finalize** — Normalize accumulated output by `1/l` after all KV blocks are processed + +#### Causal early-termination + +During prefill, every KV block whose start index is at or beyond the largest global query +position in the current q_block is fully causally masked and contributes nothing. The kernel +computes a per-q_block bound +`kv_causal_limit = past_seqlen + q_idx + row_size_q` and breaks out of the KV loop once +`ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle +QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the +main reason the FP32 flash path is faster than naive even at short sequence lengths +(see the benchmark results below). + +### Activation Conditions + +The non-quantized flash path is selected when ALL of the following hold: + +- The kernel specialization is `float` (FP16 uses the naive path) +- `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) +- `sequence_length > 1` (prefill / chunked-prefill; single-token decode uses the naive path) +- No softcap +- No smooth softmax +- No head sink +- No output QK capture +- `present_key` and `present_value` are provided + +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Sizes and Threading + +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. Because this path is prefill-only, it does not include the quantized path's two-phase flash-decoding strategy for single-token decode. + ## MLAS Dispatch Paths MLAS selects the best available quantized KV-cache GEMM implementation through the platform dispatch table. @@ -428,7 +487,39 @@ Flash decoding IS active (batch×heads=4 < threads=8, KV partitioned across idle | 4096 (N=32) | +2131 | +87 | 24.5x | **Summary**: The flash path's primary benefit for prefill is **memory reduction** — avoiding the full O(N×S×T) attention matrix. For S=4096 with 16 heads, the naive path allocates ~1 GB for attention scores while the flash path uses ~80 MB regardless of sequence length. The prefill latency speedup (1.2–2.7x at kernel level, 1.2–1.9x at operator level) comes from improved cache locality. For decode, the tiled kernel provides 1.2–1.8x kernel-level speedup from fused single-pass KV access; at operator level the gain is visible for T≥1024 but masked by KV concat overhead at shorter sequences. When flash decoding is active (batch×heads < threads), KV partitioning across idle threads yields an additional 2–5x speedup for long sequences. +### Non-Quantized (FP32) Flash Attention vs Naive benchmark results + +Measured on an AMD EPYC 7763 (32 logical / 16 physical cores), threads=8, FP32 KV cache, +`B=1, num_heads=16, kv_num_heads=8, head_size=128`. Operator-level, measured with: +```bash +python onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py \ + --fp32 --prompt_only --warmup 10 --repeats 30 +``` + +#### Latency — Prefill (S = T, prompt phase) + +| Seq Length | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 512 | 5.8\u20138.4 | 4.2\u20135.3 | 1.4\u20131.6x | +| 1024 | 25\u201329 | 13\u201318 | 1.6\u20132.0x | +| 2048 | 87\u2013118 | 52\u201365 | 1.5\u20132.0x | +| 4096 | 365\u2013380 | 213\u2013234 | 1.6\u20131.7x | + +The FP32 flash path is faster than naive across all measured prefill lengths. With the causal +early-termination described above, roughly half of the QK/SV work (the causally masked +upper triangle of the square prefill attention matrix) is skipped entirely, which more than +offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/output rescale). +The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is +algorithmic rather than purely from threading. + +#### Decode (S = 1, token generation) + +Single-token decode (`sequence_length == 1`) is **not** handled by the FP32 flash path; it falls +back to the naive path. Decode produces only a `[1, total_sequence_length]` score row per head, +so there is nothing to tile away, and the extra online-softmax bookkeeping made the flash kernel +slower and noisier in practice. Restricting the flash path to prefill (`sequence_length > 1`) keeps +the consistent prefill win without regressing decode. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 59313cf527c91..413483756ed5c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -909,6 +909,290 @@ class GQAAttentionBase { return Status::OK(); } + // Non-quantized flash attention path. Only supports T = float. + // Concatenates new K/V into the FP32 present cache, then runs the tiled + // online-softmax kernel MlasFlashAttentionGQA (QK^T + softmax + S*V fused). + Status ApplyAttentionFlash( + const float* Q, // Q data [B, N, S, H] BNSH + const float* K, // K data [B, N_kv, L, H] or nullptr for packed_qkv + const float* V, // V data [B, N_kv, L, H] or nullptr for packed_qkv + const Tensor* attention_bias, // additive bias [B|1, N|1, S, T] or nullptr + const Tensor* past_key, // past K (float) + const Tensor* past_value, // past V (float) + Tensor* output, // output [B, S, N*H] float + Tensor* present_key, // present K (float) + Tensor* present_value, // present V (float) + const Tensor* seqlens_k, + GroupQueryAttentionParameters& parameters, + AllocatorPtr allocator, + OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int head_size = parameters.head_size; + const int hidden_size = parameters.hidden_size; + const bool packed_qkv = parameters.is_packed_qkv; + + auto* tp = context->GetOperatorThreadPool(); + + int seqlen_past_kv_cache = 0; + if (past_key != nullptr && past_value != nullptr) { + seqlen_past_kv_cache = static_cast(past_key->Shape().GetDims()[2]); + } + int seqlen_present_kv_cache = present_key != nullptr + ? static_cast(present_key->Shape().GetDims()[2]) + : parameters.total_sequence_length; + + if (kv_sequence_length == 0) { + ORT_ENFORCE(parameters.total_sequence_length <= seqlen_past_kv_cache, + "total_seqlen (", parameters.total_sequence_length, ") exceeds past buffer size (", + seqlen_past_kv_cache, ") in shared KV mode"); + } + + ORT_RETURN_IF(present_key == nullptr || present_value == nullptr, + "present_key and present_value must be provided for flash attention"); + + const float* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; + float* present_key_data = present_key->MutableData(); + const float* past_value_data = past_value != nullptr ? past_value->Data() : nullptr; + float* present_value_data = present_value->MutableData(); + + bool past_present_share_buffer = (past_key_data == present_key_data) && + (past_value_data == present_value_data); + + const int32_t* seqlens_k_data = seqlens_k->Data(); + + // Attention bias setup + const float* attention_bias_data = nullptr; + int attention_bias_seqlen_stride = 0; + bool attention_bias_broadcast_batch = true; + bool attention_bias_broadcast_head = true; + if (attention_bias != nullptr) { + attention_bias_data = attention_bias->Data(); + auto bias_shape = attention_bias->Shape().GetDims(); + attention_bias_seqlen_stride = static_cast(bias_shape[3]); + attention_bias_broadcast_batch = (bias_shape[0] == 1); + attention_bias_broadcast_head = (bias_shape[1] == 1); + } + + // K/V base pointers (FP32, new tokens) + const float* k_base = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; + const float* v_base = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; + + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); + const size_t kv_input_chunk_length = SafeInt(kv_sequence_length) * head_size; + const size_t past_buff_chunk_length = SafeInt(seqlen_past_kv_cache) * head_size; + const size_t present_buff_chunk_length = SafeInt(seqlen_present_kv_cache) * head_size; + + // ---- Phase 1: Concat new K/V into present cache ---- + // We must do this first so the flash attention kernel can read the full present cache. + if (present_key_data && !past_present_share_buffer) { + memset(present_key_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + memset(present_value_data, 0, + SafeInt(batch_size) * kv_num_heads_ * present_buff_chunk_length * sizeof(float)); + } + + // Concat K and V caches (parallelize over batch * kv_num_heads) + { + const size_t concat_loop_len = batch_size * kv_num_heads_; + TensorOpCost concat_cost; + concat_cost.compute_cycles = static_cast(kv_sequence_length * head_size); + concat_cost.bytes_loaded = static_cast((past_buff_chunk_length + kv_input_chunk_length) * sizeof(float)); + concat_cost.bytes_stored = static_cast(present_buff_chunk_length * sizeof(float)); + + ThreadPool::TryParallelFor(tp, concat_loop_len, concat_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t kv_idx = begin; kv_idx != end; ++kv_idx) { + const size_t batch_index = kv_idx / kv_num_heads_; + const size_t kv_head_index = kv_idx % kv_num_heads_; + const size_t total_seqlen = SafeInt(seqlens_k_data[batch_index]) + 1; + + size_t past_seqlen; + if (past_key == nullptr) { + past_seqlen = 0; + } else if (kv_sequence_length == 0) { + past_seqlen = total_seqlen; + } else if (is_prompt) { + past_seqlen = 0; + } else { + past_seqlen = total_seqlen - sequence_length; + } + const size_t past_chunk_length = past_seqlen * head_size; + + // Concat K + const float* k_new; + if (packed_qkv) { + k_new = k_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + k_new = k_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_key_data, k_new, present_key_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + + // Concat V + const float* v_new; + if (packed_qkv) { + v_new = v_base + packed_batch_stride * batch_index + + kv_input_chunk_length * kv_head_index; + } else { + v_new = v_base + kv_input_chunk_length * kv_idx; + } + ConcatStateChunkGQA(past_value_data, v_new, present_value_data, + present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, + past_present_share_buffer, kv_idx); + } + }); + } + + // ---- Phase 2: Flash Attention with FP32 KV cache ---- + // Compute L2-aware block sizes (same formula as MHA flash attention). + const auto& env = Env::Default(); + int l2_cache_size = env.GetL2CacheSize(); + + int kv_block_size = l2_cache_size / (static_cast(sizeof(float)) * 4 * (head_size + head_size)); + kv_block_size = std::max(kv_block_size, 1); + int q_block_size = std::min(kv_block_size, 2 * head_size); + + // The flash kernel uses a single (past_seqlen, total_seqlen) pair for all batch items. + // When batch items have different seqlens_k (ragged), fall back to per-batch invocation + // so each batch item gets its own correct causal offset. + int max_total_seqlen = 0; + int min_total_seqlen = std::numeric_limits::max(); + int common_past_seqlen = 0; + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + max_total_seqlen = std::max(max_total_seqlen, total_sl); + min_total_seqlen = std::min(min_total_seqlen, total_sl); + } + const bool ragged_seqlens = (max_total_seqlen != min_total_seqlen); + + if (ragged_seqlens) { + common_past_seqlen = -1; // sentinel: per-batch + } else if (past_key == nullptr || is_prompt) { + common_past_seqlen = 0; + } else if (kv_sequence_length == 0) { + // Shared buffer mode: each batch item has its own past_seqlen. + common_past_seqlen = -1; // sentinel: per-batch + } else { + common_past_seqlen = max_total_seqlen - sequence_length; + } + + // Cap block sizes + kv_block_size = std::min(kv_block_size, max_total_seqlen); + q_block_size = std::min(q_block_size, sequence_length); + + int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + thread_count = std::max(thread_count, 1); + + // Per-thread scratch: l + m + scores[q_block_size * kv_block_size] + temp_output[q_block_size * head_size] + const size_t buffer_size_per_thread = + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output + sizeof(float); + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count; + auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); + BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + + const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; + + // If all batch items share the same past_seqlen, use the unified flash kernel. + // Otherwise, fall back to per-batch invocation. + if (common_past_seqlen >= 0) { + MlasFlashAttentionGQAArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = max_total_seqlen; + args.head_size = head_size; + args.past_seqlen = common_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = kv_block_size; + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + args.query = Q; + args.q_batch_stride = packed_qkv + ? static_cast(packed_batch_stride) + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.k_cache = present_key_data; + args.v_cache = present_value_data; + args.output = output->MutableData(); + args.attention_bias = attention_bias_data; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + + MlasFlashAttentionGQA(&args, tp); + } else { + // Per-batch handling for variable past_seqlen (shared KV buffer mode or ragged seqlens) + for (int b = 0; b < batch_size; ++b) { + int total_sl = seqlens_k_data[b] + 1; + int batch_past_seqlen = (past_key == nullptr || is_prompt) + ? 0 + : std::max(0, total_sl - sequence_length); + + MlasFlashAttentionGQAArgs args; + args.batch_size = 1; + args.num_heads = num_heads_; + args.kv_num_heads = kv_num_heads_; + args.sequence_length = sequence_length; + args.total_seqlen = total_sl; + args.head_size = head_size; + args.past_seqlen = batch_past_seqlen; + args.local_window_size = local_window_size_; + args.seqlen_present_kv = seqlen_present_kv_cache; + args.q_block_size = q_block_size; + args.kv_block_size = std::min(kv_block_size, total_sl); + args.scale = scale; + args.thread_count = thread_count; + args.buffer = reinterpret_cast(flash_buffer_alloc); + args.buffer_size_per_thread = buffer_size_per_thread; + + // Offset Q and output for this batch + const ptrdiff_t q_batch_stride_elems = packed_batch_stride > 0 + ? packed_batch_stride + : static_cast(SafeInt(num_heads_) * sequence_length * head_size); + args.query = Q + static_cast(b) * static_cast(q_batch_stride_elems); + args.q_batch_stride = SafeInt(num_heads_) * sequence_length * head_size; + args.k_cache = present_key_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.v_cache = present_value_data + + static_cast(b) * kv_num_heads_ * present_buff_chunk_length; + args.output = output->MutableData() + + static_cast(b) * sequence_length * hidden_size; + + // Slice attention bias for this batch (the kernel sees batch_size=1, so batch_idx=0 inside). + // Bias shape is [batch|1, num_heads|1, S, T]; the batch stride uses the actual head + // extent (1 when the head dim is broadcast). + const float* batch_bias = attention_bias_data; + if (attention_bias_data != nullptr && !attention_bias_broadcast_batch) { + const size_t bias_head_extent = attention_bias_broadcast_head ? 1 : static_cast(num_heads_); + batch_bias += static_cast(b) * bias_head_extent * sequence_length * attention_bias_seqlen_stride; + } + args.attention_bias = batch_bias; + args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; + args.attention_bias_broadcast_batch = true; // batch offset handled above + args.attention_bias_broadcast_head = attention_bias_broadcast_head; + + MlasFlashAttentionGQA(&args, tp); + } + } + + return Status::OK(); + } + private: // Helper function to compute the attention probs. It does 2 things: // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index e36bdb2de263a..692a5759d1adc 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -343,6 +343,31 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Compute the attention score and apply the score to V const T* k_data = packed_qkv ? nullptr : k_rotary; const T* v_data = packed_qkv ? nullptr : V.Get().Data(); + + // Non-quantized flash attention path (float only). Uses the tiled online-softmax + // kernel to avoid materializing the full attention score matrix. Falls back to the + // naive path when an unsupported feature is requested (softcap, smooth softmax, + // head sink, or QK output). + if constexpr (std::is_same_v) { + // Restrict the flash path to prefill / chunked-prefill (query length > 1). Single-token + // decode (sequence_length == 1) has no flash benefit: the naive score matrix is only + // [1, total_sequence_length] per head, so there is nothing to tile away, and the extra + // online-softmax bookkeeping makes it slower in practice. + const bool use_flash = !disable_gqa_flash_ && + parameters.sequence_length > 1 && + softcap_ == 0.0f && + !use_smooth_softmax_ && + head_sink_data == nullptr && + output_qk == nullptr && + present_k != nullptr && present_v != nullptr; + if (use_flash) { + return ApplyAttentionFlash(q_rotary, k_data, v_data, + attention_bias, past_key, past_value, + output, present_k, present_v, seqlens_k, + parameters, allocator, context); + } + } + return ApplyAttention(q_rotary, k_data, v_data, head_sink_data, attention_bias, past_key, past_value, output, present_k, present_v, output_qk, seqlens_k, parameters, allocator, context); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 811bad15ebbab..2410fcc83e7cd 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2297,6 +2297,57 @@ MlasFlashAttention( MLAS_THREADPOOL* ThreadPool ); +// +// Flash Attention for non-quantized (FP32) GroupQueryAttention KV cache. +// +// Adapts the online-softmax tiled algorithm to operate on an FP32 present +// K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). +// Supports GQA head grouping, causal masking, local window attention, and +// additive attention bias. Intended for prefill / chunked-prefill +// (sequence_length > 1). +// +struct MlasFlashAttentionGQAArgs { + int batch_size; + int num_heads; // number of query heads + int kv_num_heads; // number of key/value heads (num_heads % kv_num_heads == 0) + int sequence_length; // number of new query tokens (S) + int total_seqlen; // total tokens (past + new) for this invocation (T) + int head_size; // per-head size (H) + int past_seqlen; // causal offset (number of cached tokens before the new ones) + int local_window_size; // -1 disables local window masking + int seqlen_present_kv; // sequence dimension of the present K/V buffer + int q_block_size; // query tile size (Br) + int kv_block_size; // key/value tile size (Bc) + float scale; // QK scaling factor + int thread_count; // number of partitions / threads + float* buffer; // per-thread scratch + size_t buffer_size_per_thread; + + const float* query; // [batch, num_heads, sequence_length, head_size] BNSH + size_t q_batch_stride; // element stride between consecutive batches in `query` + // (num_heads*S*H for unpacked, (num_heads+2*kv_num_heads)*S*H for packed QKV) + const float* k_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + const float* v_cache; // [batch, kv_num_heads, seqlen_present, head_size] FP32 + float* output; // [batch, sequence_length, num_heads, head_size] BSNH + + const float* attention_bias; // [batch|1, num_heads|1, S, T] additive bias, or nullptr + int attention_bias_seqlen_stride; + bool attention_bias_broadcast_batch; + bool attention_bias_broadcast_head; +}; + +/** + * @brief FP32 Flash Attention for GroupQueryAttention with an FP32 KV cache. + * @param args Arguments + * @param ThreadPool Thread pool + */ +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +); + /** * @brief Enumeration of supported GELU algorithm variants. * diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp new file mode 100644 index 0000000000000..4d0ff65733a44 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -0,0 +1,310 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + flashattn_gqa.cpp + +Abstract: + + Flash Attention kernel for the non-quantized (FP32) GroupQueryAttention + KV cache. + + Adapts the online-softmax tiled algorithm from flashattn.cpp to operate on + an FP32 present K/V cache laid out as BNSH + ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head + grouping (num_heads % kv_num_heads == 0), causal masking, local window + attention, and additive attention bias. Intended for prefill / + chunked-prefill (sequence_length > 1). + + QK^T and S*V use the single-threaded SGEMM primitive MlasSgemmOperation; + the outer parallelism is provided by MlasExecuteThreaded. + +--*/ + +#include +#include +#include +#include + +#include "mlasi.h" + +void +MlasFlashAttentionGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t q_block_size = static_cast(args->q_block_size); + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t sequence_length = static_cast(args->sequence_length); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: one per (batch, head, q_block) + const ptrdiff_t q_chunk_count = (sequence_length + q_block_size - 1) / q_block_size; + const ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + // Per-thread buffer layout: + // l[q_block_size] - running sum for online softmax + // m[q_block_size] - running max for online softmax + // scores[q_block_size * kv_block_size] - QK scores (S) + // temp_output[q_block_size * head_size] - accumulated output + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_ptr); + float* m = l + q_block_size; + float* scores = m + q_block_size; + float* temp_output = scores + q_block_size * kv_block_size; + + // Initialize running state + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + l[t] = 0.0f; + } + memset(temp_output, 0, static_cast(q_block_size * head_size) * sizeof(float)); + + const size_t row_size_q = static_cast(std::min(q_block_size, sequence_length - q_idx)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers. Layout: [batch, kv_num_heads, seqlen_present, head_size] + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, seq, head_size]. The batch stride is + // supplied separately (args->q_batch_stride) so the kernel works with both the + // standard BNSH layout and packed-QKV input where Q/K/V are interleaved per batch. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(sequence_length) * static_cast(head_size) + + static_cast(q_idx) * static_cast(head_size); + + // Causal early-termination bound: the largest global query position in this + // q_block is (past_seqlen + q_idx + row_size_q - 1), so it can attend to KV + // positions up to that index inclusive. Any KV block starting at or beyond + // (past_seqlen + q_idx + row_size_q) is fully causally masked for every row in + // the block, so it contributes nothing and can be skipped. This avoids the + // wasted QK/SV GEMMs over the causal upper triangle during prefill. + const ptrdiff_t kv_causal_limit = + past_seqlen + q_idx + static_cast(row_size_q); + + // Iterate over KV blocks + for (ptrdiff_t ir = 0; ir < total_seqlen; ir += kv_block_size) { + if (ir >= kv_causal_limit) { + break; + } + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Step 1: QK^T GEMM with FP32 K block + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasTrans, + row_size_q, // M + row_size_kv, // N + static_cast(head_size), // K + scale, // alpha + q_ptr, // A (FP32 query) + static_cast(head_size), // lda + k_block, // B (FP32 K block) + static_cast(head_size), // ldb + 0.0f, // beta + scores, // C (output scores) + row_size_kv // ldc + ); + + // Step 1b: Apply attention bias (additive) if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = + static_cast(sequence_length) * bias_seqlen_stride; + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch + // stride uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + // Add bias tile: bias[q_idx + irow, ir + jcol] + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + const float* bias_row = args->attention_bias + bias_offset + + (q_idx + irow) * bias_seqlen_stride + ir; + float* s_row = scores + irow * static_cast(row_size_kv); + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + s_row[jcol] += bias_row[jcol]; + } + } + } + + // Step 2: Apply causal mask and Step 3: Online softmax update + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float* p = scores + irow * static_cast(row_size_kv); + const ptrdiff_t global_q_pos = past_seqlen + q_idx + irow; + const ptrdiff_t causal_limit = global_q_pos + 1; // can attend to positions [0, causal_limit) + + // Apply causal masking + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + p[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + p[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Online softmax: find row max, update running max +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv); +#endif + + // If the entire row is masked (all scores are -inf), zero the scores + // so the S*V GEMM contributes nothing and skip the softmax state update. + if (rowmax == std::numeric_limits::lowest()) { + memset(p, 0, row_size_kv * sizeof(float)); + continue; + } + + float m_old = m[irow]; + m[irow] = std::max(m[irow], rowmax); + float m_diff = m_old - m[irow]; // <= 0 + + // Compute exp(score - m_new) for each element + float negmax = -m[irow]; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv, &negmax); +#endif + + // Rescale previous state + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + // Rescale accumulated output + float* out_row = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + out_row[icol] *= exp_diff; + } + } else { + l[irow] = rowsum; + } + } + + // Step 4: Accumulate O += S_exp * V_block + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasSgemmOperation( + CblasNoTrans, + CblasNoTrans, + row_size_q, // M + static_cast(head_size), // N + row_size_kv, // K + 1.0f, // alpha + scores, // A (exp softmax scores) + row_size_kv, // lda + v_block, // B (FP32 V block) + static_cast(head_size), // ldb + ir == 0 ? 0.0f : 1.0f, // beta (accumulate after first block) + temp_output, // C (accumulated output) + static_cast(head_size) // ldc + ); + } + + // Final: normalize output by l (softmax denominator) + // Output layout: [batch, sequence_length, num_heads, head_size] + float* output_row = args->output + + (static_cast(batch_idx) * static_cast(sequence_length) + + static_cast(q_idx)) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + const ptrdiff_t output_row_stride = num_heads * head_size; + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q); ++irow) { + float inv_l = (l[irow] > 0.0f) ? (1.0f / l[irow]) : 0.0f; + float* src = temp_output + irow * head_size; + for (ptrdiff_t icol = 0; icol < head_size; ++icol) { + output_row[icol] = src[icol] * inv_l; + } + output_row += output_row_stride; + } + } +} + +void +MLASCALL +MlasFlashAttentionGQA( + MlasFlashAttentionGQAArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); +} diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py index 77ac08cf50d6c..7dbcb16a75973 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_cpu_flash.py @@ -106,6 +106,70 @@ def create_quantized_gqa_graph( return model.SerializeToString() +def create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=None, +): + """Create an ONNX graph for GroupQueryAttention with a non-quantized FP32 KV cache.""" + if buffer_seq_len is None: + buffer_seq_len = seq_len + + hidden_size = num_heads * head_size + kv_hidden_size = kv_num_heads * head_size + + inputs = [ + "query", + "key", + "value", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + ] + + node = helper.make_node( + op_type="GroupQueryAttention", + inputs=inputs, + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=num_heads, + kv_num_heads=kv_num_heads, + domain="com.microsoft", + ) + + graph_input = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, seq_len, kv_hidden_size]), + helper.make_tensor_value_info( + "past_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "past_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + graph_output = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size]), + helper.make_tensor_value_info( + "present_key", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + helper.make_tensor_value_info( + "present_value", TensorProto.FLOAT, [batch_size, kv_num_heads, buffer_seq_len, head_size] + ), + ] + + graph = helper.make_graph([node], "BenchGQA", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + def benchmark_gqa( batch_size, seq_len, @@ -117,6 +181,7 @@ def benchmark_gqa( past_seq_len=0, warmup=5, repeats=20, + non_quantized=False, ): """Benchmark a single GQA configuration. Returns elapsed time in ms.""" hidden_size = num_heads * head_size @@ -126,54 +191,76 @@ def benchmark_gqa( total_seqlen = past_seq_len + seq_len buffer_seq_len = total_seqlen - onnx_model_str = create_quantized_gqa_graph( - batch_size, - seq_len, - num_heads, - kv_num_heads, - head_size, - quant_type, - bit_width, - buffer_seq_len=buffer_seq_len, - ) - sess_options = SessionOptions() sess_options.intra_op_num_threads = 8 - sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) - # Generate inputs np.random.seed(42) query = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, hidden_size)).astype(np.float32) key = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) value = np.random.uniform(-0.5, 0.5, (batch_size, seq_len, kv_hidden_size)).astype(np.float32) - - cache_dtype = np.uint8 if bit_width == 4 else np.int8 - past_k = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - past_v = np.random.randint( - 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 - ).view(cache_dtype) - seqlens_k = np.array([total_seqlen - 1] * batch_size, dtype=np.int32) total_seq = np.array([total_seqlen], dtype=np.int32) - per_channel = quant_type == "PER_CHANNEL" - scale_size = kv_num_heads * head_size if per_channel else 1 - k_scale = np.full(scale_size, 0.01, dtype=np.float32) - v_scale = np.full(scale_size, 0.01, dtype=np.float32) - - feeds = { - "query": query, - "key": key, - "value": value, - "past_key": past_k, - "past_value": past_v, - "seqlens_k": seqlens_k, - "total_sequence_length": total_seq, - "k_scale": k_scale, - "v_scale": v_scale, - } + if non_quantized: + onnx_model_str = create_fp32_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + past_k = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + past_v = np.random.uniform(-0.5, 0.5, (batch_size, kv_num_heads, buffer_seq_len, head_size)).astype(np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + } + else: + onnx_model_str = create_quantized_gqa_graph( + batch_size, + seq_len, + num_heads, + kv_num_heads, + head_size, + quant_type, + bit_width, + buffer_seq_len=buffer_seq_len, + ) + sess = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) + + cache_dtype = np.uint8 if bit_width == 4 else np.int8 + past_k = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + past_v = np.random.randint( + 0, 255, (batch_size, kv_num_heads, buffer_seq_len, packed_head_size), dtype=np.uint8 + ).view(cache_dtype) + + per_channel = quant_type == "PER_CHANNEL" + scale_size = kv_num_heads * head_size if per_channel else 1 + k_scale = np.full(scale_size, 0.01, dtype=np.float32) + v_scale = np.full(scale_size, 0.01, dtype=np.float32) + + feeds = { + "query": query, + "key": key, + "value": value, + "past_key": past_k, + "past_value": past_v, + "seqlens_k": seqlens_k, + "total_sequence_length": total_seq, + "k_scale": k_scale, + "v_scale": v_scale, + } # Warmup for _ in range(warmup): @@ -242,20 +329,21 @@ def run_benchmarks(args): "past_seq_len": 2048, } ) - # INT4 prefill - configs.append( - { - "label": "Prefill S=2048 INT4", - "batch_size": 1, - "seq_len": 2048, - "num_heads": 16, - "kv_num_heads": 8, - "head_size": 128, - "quant_type": "PER_TENSOR", - "bit_width": 4, - "past_seq_len": 0, - } - ) + # INT4 prefill (quantized mode only) + if not args.fp32: + configs.append( + { + "label": "Prefill S=2048 INT4", + "batch_size": 1, + "seq_len": 2048, + "num_heads": 16, + "kv_num_heads": 8, + "head_size": 128, + "quant_type": "PER_TENSOR", + "bit_width": 4, + "past_seq_len": 0, + } + ) warmup = args.warmup repeats = args.repeats @@ -263,13 +351,15 @@ def run_benchmarks(args): # Save and restore env var to avoid side effects on callers saved_env = os.environ.get("ORT_GQA_DISABLE_FLASH_ATTENTION") + kv_mode = "FP32 (non-quantized)" if args.fp32 else "INT8/INT4 quantized" print("\nBenchmark: CPU GroupQueryAttention — Flash vs Naive") - print(f"Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") + print(f"KV cache: {kv_mode}, Threads: {8}, Warmup: {warmup}, Repeats: {repeats}") print(f"{'Config':<25} {'Naive (ms)':>12} {'Flash (ms)':>12} {'Speedup':>10}") print("-" * 62) for cfg in configs: label = cfg.pop("label") + cfg["non_quantized"] = args.fp32 # Flash path (default) os.environ.pop("ORT_GQA_DISABLE_FLASH_ATTENTION", None) @@ -296,5 +386,6 @@ def run_benchmarks(args): parser.add_argument("--repeats", type=int, default=20, help="Measurement iterations") parser.add_argument("--decode_only", action="store_true", help="Only run decode benchmarks") parser.add_argument("--prompt_only", action="store_true", help="Only run prompt benchmarks") + parser.add_argument("--fp32", action="store_true", help="Use non-quantized FP32 KV cache instead of quantized") args = parser.parse_args() run_benchmarks(args) From 4ed5a4addef306170c9b39c4f094aff0cb16ba7c Mon Sep 17 00:00:00 2001 From: Vineeth Chelur Date: Thu, 25 Jun 2026 05:34:30 +0530 Subject: [PATCH 04/10] Update cpuinfo to include cpuinfo_deinitialize(), fix QNN ETW logging, GQA underflow, and ep_weight_sharing_ctx_gen build (#28245) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This PR contains three commits: **Commit 1: Miscellaneous fixes** - Downgrade QNN ETW profiling mismatch logs from ERROR to VERBOSE to reduce excessive telemetry noise (~1 billion events/week across Windows devices) - Add bounds checking in GQA attention to prevent `size_t` underflow when `seqlens_k` contains invalid data (fixes #27170) - Build `ep_weight_sharing_ctx_gen` for TensorRT, OpenVINO, and VitisAI in addition to QNN **Commit 2: Bump cpuinfo and add `cpuinfo_deinitialize()` integration** Applications that dynamically load and unload the onnxruntime DLL leave orphaned heap allocations from cpuinfo when the library is unloaded mid-process. These are flagged as memory leaks by App Verifier, Valgrind, AddressSanitizer, and LeakSanitizer. This commit bumps `pytorch/cpuinfo` to a version that implements `cpuinfo_deinitialize()` ([pytorch/cpuinfo#387](https://github.com/pytorch/cpuinfo/pull/387)) and adds ORT integration: - `CPUIDInfo::ShutDown()` calls `cpuinfo_deinitialize()` to free heap-allocated globals - `DllMain` calls `ShutdownCpuInfo()` on `DLL_PROCESS_DETACH` - In memleak-check builds, shutdown also runs during process termination - `InstanceCreated` atomic guard prevents singleton creation during DLL unload **Commit 3: Update to official cpuinfo merged fix** After [pytorch/cpuinfo#387](https://github.com/pytorch/cpuinfo/pull/387) merged upstream, updated the dependency to point to `pytorch/cpuinfo` main (`4628dc06`). Patch changes: - **Removed** `win_arm_fp16_detection_fallback.patch` — upstreamed via [pytorch/cpuinfo#348](https://github.com/pytorch/cpuinfo/pull/348) - **Updated** `patch_vcpkg_arm64ec_support.patch` — regenerated for new cpuinfo; still needed ([pytorch/cpuinfo#324](https://github.com/pytorch/cpuinfo/pull/324) not yet merged) - **Updated** `patch_cpuinfo_h_for_arm64ec.patch` — retained, not yet upstream - **Regenerated** `fix_missing_sysfs_fallback.patch` — updated context lines for new cpuinfo code ### Motivation and Context - https://github.com/pytorch/cpuinfo/issues/150 - https://github.com/microsoft/onnxruntime/issues/16117 - https://github.com/microsoft/onnxruntime/issues/23762 --- cmake/deps.txt | 2 +- .../external/onnxruntime_external_deps.cmake | 4 +- cmake/onnxruntime_unittests.cmake | 4 +- .../cpuinfo/fix_missing_sysfs_fallback.patch | 58 ++++++++++++++++--- .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 4 +- .../win_arm_fp16_detection_fallback.patch | 19 ------ .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 4 +- cmake/vcpkg-ports/cpuinfo/portfile.cmake | 7 +-- .../win_arm_fp16_detection_fallback.patch | 19 ------ onnxruntime/core/common/cpuid_info.cc | 9 +++ onnxruntime/core/common/cpuid_info.h | 1 + onnxruntime/core/platform/posix/env.cc | 5 ++ .../qnn/builder/qnn_backend_manager.cc | 4 +- 13 files changed, 78 insertions(+), 62 deletions(-) delete mode 100644 cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch delete mode 100644 cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch diff --git a/cmake/deps.txt b/cmake/deps.txt index e303ccd9f8a98..d6a5b71221dc8 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -50,7 +50,7 @@ protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/downlo psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2 pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v3.0.2.zip;a064e663b4d7a337ac291d1bef7337ef4e60a1ae -pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/403d652dca4c1046e8145950b1c0997a9f748b57.zip;30b2a07fe4bae8574f89176e56274cacdd6d135b +pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/4628dc060ce4e82345dc166bbac875609db4ff69.zip;e58d4b47c16a982111c897e669ae4f1821a393d7 re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 1a1e4921a41e6..ed3b0aa8192a7 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -371,9 +371,7 @@ if (CPUINFO_SUPPORTED) PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch && # https://github.com/pytorch/cpuinfo/pull/324 - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch && - # https://github.com/pytorch/cpuinfo/pull/348 - ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/win_arm_fp16_detection_fallback.patch + ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch FIND_PACKAGE_ARGS NAMES cpuinfo ) elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 23eccb22476df..c32aa7f4ae75a 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1609,8 +1609,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() - - if(onnxruntime_USE_QNN) + # Build ep_weight_sharing_ctx_gen for all supported EPs (QNN, TensorRT, OpenVINO, VitisAI) + if(onnxruntime_USE_QNN OR onnxruntime_USE_TENSORRT OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_VITISAI) #qnn ctx generator set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) set(ep_weight_sharing_ctx_gen_src_patterns diff --git a/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch b/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch index 005cd458fdd2b..47a1054e25107 100644 --- a/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch +++ b/cmake/patches/cpuinfo/fix_missing_sysfs_fallback.patch @@ -1,10 +1,19 @@ diff --git a/src/linux/processors.c b/src/linux/processors.c -index 47bee76..d0c5569 100644 +index fd040a3..2ca8ec4 100644 --- a/src/linux/processors.c +++ b/src/linux/processors.c -@@ -2,0 +3 @@ +@@ -3,6 +3,7 @@ + #include + #include + #include +#include -@@ -291,0 +293,22 @@ + + #if !defined(__ANDROID__) + /* +@@ -289,6 +290,28 @@ static bool max_processor_number_parser(uint32_t processor_list_start, uint32_t + return true; + } + +static uint32_t cpuinfo_linux_get_max_processor_from_sysconf( + uint32_t max_processors_count, + const char* processor_list_name) { @@ -27,13 +36,31 @@ index 47bee76..d0c5569 100644 + return max_processor; +} + -@@ -301 +324 @@ + uint32_t cpuinfo_linux_get_max_possible_processor(uint32_t max_processors_count) { + uint32_t max_possible_processor = 0; + if (!cpuinfo_linux_parse_cpulist( +@@ -298,7 +321,7 @@ uint32_t cpuinfo_linux_get_max_possible_processor(uint32_t max_processors_count) + #else + cpuinfo_log_warning("failed to parse the list of possible processors in %s", POSSIBLE_CPULIST_FILENAME); + #endif - return UINT32_MAX; + return cpuinfo_linux_get_max_processor_from_sysconf(max_processors_count, POSSIBLE_CPULIST_FILENAME); -@@ -323 +346 @@ + } + if (max_possible_processor >= max_processors_count) { + cpuinfo_log_warning( +@@ -320,7 +343,7 @@ uint32_t cpuinfo_linux_get_max_present_processor(uint32_t max_processors_count) + #else + cpuinfo_log_warning("failed to parse the list of present processors in %s", PRESENT_CPULIST_FILENAME); + #endif - return UINT32_MAX; + return cpuinfo_linux_get_max_processor_from_sysconf(max_processors_count, PRESENT_CPULIST_FILENAME); -@@ -357,0 +381,31 @@ + } + if (max_present_processor >= max_processors_count) { + cpuinfo_log_warning( +@@ -355,6 +378,37 @@ static bool detect_processor_parser(uint32_t processor_list_start, uint32_t proc + return true; + } + +static bool cpuinfo_linux_detect_processors_from_sysconf( + uint32_t max_processors_count, + uint32_t* processor0_flags, @@ -65,7 +92,13 @@ index 47bee76..d0c5569 100644 + return true; +} + -@@ -373 +427,6 @@ + bool cpuinfo_linux_detect_possible_processors( + uint32_t max_processors_count, + uint32_t* processor0_flags, +@@ -370,7 +424,12 @@ bool cpuinfo_linux_detect_possible_processors( + return true; + } else { + cpuinfo_log_warning("failed to parse the list of possible processors in %s", POSSIBLE_CPULIST_FILENAME); - return false; + return cpuinfo_linux_detect_processors_from_sysconf( + max_processors_count, @@ -73,7 +106,13 @@ index 47bee76..d0c5569 100644 + processor_struct_size, + possible_flag, + POSSIBLE_CPULIST_FILENAME); -@@ -392 +451,6 @@ + } + } + +@@ -389,7 +448,12 @@ bool cpuinfo_linux_detect_present_processors( + return true; + } else { + cpuinfo_log_warning("failed to parse the list of present processors in %s", PRESENT_CPULIST_FILENAME); - return false; + return cpuinfo_linux_detect_processors_from_sysconf( + max_processors_count, @@ -81,3 +120,6 @@ index 47bee76..d0c5569 100644 + processor_struct_size, + present_flag, + PRESENT_CPULIST_FILENAME); + } + } + diff --git a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch index af0f039b6c2a3..18ed80f7944f8 100644 --- a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 +index 072c987..e43d6ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am @@ -7,7 +7,7 @@ index aedc983..dab589e 100644 IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") +ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID for non-VS generators (e.g. Ninja) with MSVC. + IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") + SET(CPUINFO_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") diff --git a/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch b/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch deleted file mode 100644 index 44ac0f13f5466..0000000000000 --- a/cmake/patches/cpuinfo/win_arm_fp16_detection_fallback.patch +++ /dev/null @@ -1,19 +0,0 @@ -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index 5c0a5f3..a07fbe4 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -249,6 +249,14 @@ static void set_cpuinfo_isa_fields(void) { - // guarantee that, but it holds in practice. - cpuinfo_isa.rdm = dotprod; - -+ // PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE may not be available in older -+ // Windows versions. If fp16arith was not detected with -+ // IsProcessorFeaturePresent(PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE), fall -+ // back to using the value of dotprod. -+ if (!cpuinfo_isa.fp16arith) { -+ cpuinfo_isa.fp16arith = dotprod; -+ } -+ - /* Windows API reports all or nothing for cryptographic instructions. */ - const bool crypto = IsProcessorFeaturePresent(PF_ARM_V8_CRYPTO_INSTRUCTIONS_AVAILABLE) != 0; - cpuinfo_isa.aes = crypto; diff --git a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch index af0f039b6c2a3..18ed80f7944f8 100644 --- a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch +++ b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index aedc983..dab589e 100644 +index 072c987..e43d6ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am @@ -7,7 +7,7 @@ index aedc983..dab589e 100644 IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") +ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") -+ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID for non-VS generators (e.g. Ninja) with MSVC. + IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") + SET(CPUINFO_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index 67bd18e61cc28..9140a233e2ccd 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -6,13 +6,12 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO pytorch/cpuinfo - REF 403d652dca4c1046e8145950b1c0997a9f748b57 - SHA512 f7cd6dc44bd1120af610cae1337ed4c0f557ba78d2de9c73fed350fa3dfe9512643a1619ae55f5a540c6316a87d641856cca27297bb8766e48f39b7b7a59da1f - HEAD_REF master + REF 4628dc060ce4e82345dc166bbac875609db4ff69 + SHA512 db7a93279f2f6daaf825fbd8552935d8ed671d276b65ad614e11f722b6a6848e663850d65180d33b554d67ef1a36aae842feb368699f90be8f21172a1af1924e + HEAD_REF main PATCHES patch_cpuinfo_h_for_arm64ec.patch patch_vcpkg_arm64ec_support.patch # https://github.com/pytorch/cpuinfo/pull/324 - win_arm_fp16_detection_fallback.patch # https://github.com/pytorch/cpuinfo/pull/348 ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch b/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch deleted file mode 100644 index 44ac0f13f5466..0000000000000 --- a/cmake/vcpkg-ports/cpuinfo/win_arm_fp16_detection_fallback.patch +++ /dev/null @@ -1,19 +0,0 @@ -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index 5c0a5f3..a07fbe4 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -249,6 +249,14 @@ static void set_cpuinfo_isa_fields(void) { - // guarantee that, but it holds in practice. - cpuinfo_isa.rdm = dotprod; - -+ // PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE may not be available in older -+ // Windows versions. If fp16arith was not detected with -+ // IsProcessorFeaturePresent(PF_ARM_V82_FP16_INSTRUCTIONS_AVAILABLE), fall -+ // back to using the value of dotprod. -+ if (!cpuinfo_isa.fp16arith) { -+ cpuinfo_isa.fp16arith = dotprod; -+ } -+ - /* Windows API reports all or nothing for cryptographic instructions. */ - const bool crypto = IsProcessorFeaturePresent(PF_ARM_V8_CRYPTO_INSTRUCTIONS_AVAILABLE) != 0; - cpuinfo_isa.aes = crypto; diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index ebf3cc9f50be6..ec5c1386e8336 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -405,4 +405,13 @@ CPUIDInfo::CPUIDInfo() { #endif #endif // defined(CPUIDINFO_ARCH_RISCV64) } + +CPUIDInfo::~CPUIDInfo() { +#if defined(CPUINFO_SUPPORTED) + if (pytorch_cpuinfo_init_) { + cpuinfo_deinitialize(); + pytorch_cpuinfo_init_ = false; + } +#endif +} } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index bf502c645c9eb..6eed234332f46 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -110,6 +110,7 @@ class CPUIDInfo { static void LogEarlyWarning(std::string_view message); CPUIDInfo(); + ~CPUIDInfo(); void VendorInfoInit(); diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 0270bf9d4d79c..c34d8b3dbf696 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -657,6 +657,11 @@ class PosixEnv : public Env { } } } + ~PosixEnv() { + if (cpuinfo_available_) { + cpuinfo_deinitialize(); + } + } bool cpuinfo_available_{false}; #endif // ORT_USE_CPUINFO }; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5758ff3ad2847..f586fc8e117a6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1871,13 +1871,13 @@ Status QnnBackendManager::ExtractBackendProfilingInfo(qnn::profile::ProfilingInf // ETW disabled previously, but enabled now if (ProfilingLevel::INVALID == profiling_level_etw_ && tracelogging_provider_ep_enabled) { - LOGS(*logger_, ERROR) << "ETW disabled previously, but enabled now. Can't do the switch! Won't output any profiling."; + LOGS(*logger_, WARNING) << "ETW disabled previously, but enabled now. Can't do the switch! Won't output any profiling."; return Status::OK(); } // ETW enabled previously, but disabled now if (ProfilingLevel::INVALID != profiling_level_etw_ && !tracelogging_provider_ep_enabled) { - LOGS(*logger_, ERROR) << "ETW enabled previously, but disabled now. Can't do the switch! Won't output any profiling."; + LOGS(*logger_, WARNING) << "ETW enabled previously, but disabled now. Can't do the switch! Won't output any profiling."; return Status::OK(); } From 3b022ecb71f4756d899e06380afe1c2d2850edc1 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 24 Jun 2026 17:20:56 -0700 Subject: [PATCH 05/10] [CUDA] Support user compute stream with CUDA graph in CUDA plugin EP (#29221) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Description The CUDA plugin EP previously rejected combining a user-provided compute stream (`user_compute_stream`) with CUDA graph capture (`enable_cuda_graph`), returning `ORT_INVALID_ARGUMENT`. This PR removes that restriction so the two options can be used together: when both are set, graph capture and replay run on the user-owned stream (the same stream the kernels are issued to), matching the bundled (non-plugin) CUDA EP behavior. Several supporting fixes make capture on a shared stream stable and Memcpy-free. ## Summary of Changes ### Allow user stream + CUDA graph | File | Change | |------|--------| | [onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc](onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc) | Remove the validation that rejected `user_compute_stream` + `enable_cuda_graph` together. | | [onnxruntime/core/providers/cuda/plugin/cuda_ep.cc](onnxruntime/core/providers/cuda/plugin/cuda_ep.cc) | `PerThreadContext` accepts an optional external graph stream. When both options are set it captures/replays on the user stream and does **not** create or destroy it (the user owns its lifetime); otherwise it owns a dedicated graph stream as before. | ### Stable, Memcpy-free CUDA graph capture | File | Change | |------|--------| | [onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h](onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h) | Route kernel scratch/workspace allocations through the EP allocator (BFC arena) instead of raw `cudaMallocAsync`/`cudaMalloc`. After warmup the arena reaches steady state, so the capture run serves scratch from already-reserved chunks and the device free-memory footprint stays stable — required for correct capture. Matches the built-in CUDA EP. | | [onnxruntime/core/providers/cuda/tensor/shape_op.cc](onnxruntime/core/providers/cuda/tensor/shape_op.cc) | Add an adapter-based `Shape` kernel under `#ifdef BUILD_CUDA_EP_AS_PLUGIN` with identical semantics to the CPU `Shape`. Registering `Shape` on the EP keeps it off the CPU EP and avoids the Memcpy nodes that would otherwise break CUDA graph capture. | | [cmake/onnxruntime_providers_cuda_plugin.cmake](cmake/onnxruntime_providers_cuda_plugin.cmake) | Stop excluding `shape_op.cc` from the plugin build so the adapter-based `Shape` kernel is compiled in. | ### Null-allocator fallback in PrePack (plugin boundary) In the plugin build the `AllocatorPtr` passed to `PrePack` can arrive null across the library boundary. Each kernel now falls back to its own default-memory allocator (`Info().GetAllocator(OrtMemTypeDefault)`), which is always valid. - [onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc](onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc) - [onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc](onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc) - [onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc](onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc) ### Misc - [onnxruntime/core/framework/session_state.cc](onnxruntime/core/framework/session_state.cc) — wrap a long line (no behavior change). ## Testing - New test: [onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc](onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc) covering: 1. Session creation succeeds with both `user_compute_stream` and `enable_cuda_graph` set (regression for the removed validation). 2. Capture + replay on the user stream produce correct results. 3. Replay after an in-place input update on the user stream is correct. - Tests are gated on `ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP` and skip gracefully when no CUDA device or plugin library is available. ## Motivation and Context Users that drive ORT from their own CUDA stream (e.g. to interleave ORT inference with their own kernels) previously could not also benefit from CUDA graph capture on the plugin EP. This change brings the plugin EP to parity with the bundled CUDA EP for that workflow. ## Checklist - [x] Tests added/updated - [x] No breaking changes (relaxes a previously rejected option combination) - [ ] Documentation updated (if applicable) --- .github/workflows/linux_cuda_plugin_ci.yml | 24 ++ cmake/onnxruntime_providers_cuda_plugin.cmake | 9 +- .../arena_allocator_migration_design.md | 13 +- .../cuda_graph_for_cuda_plugin.md | 50 +++- docs/cuda_plugin_ep/cuda_plugin_ep_design.md | 15 +- include/onnxruntime/ep/adapter/allocator.h | 24 +- onnxruntime/core/framework/session_state.cc | 11 +- .../core/providers/cuda/llm/attention.cc | 7 +- .../core/providers/cuda/plugin/cuda_ep.cc | 64 ++++- .../providers/cuda/plugin/cuda_ep_factory.cc | 9 +- .../cuda/plugin/cuda_kernel_adapter.h | 100 ++++---- .../core/providers/cuda/tensor/shape_op.cc | 75 ++++++ .../cuda_plugin_user_stream_graph_test.cc | 240 ++++++++++++++++++ 13 files changed, 542 insertions(+), 99 deletions(-) create mode 100644 onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc diff --git a/.github/workflows/linux_cuda_plugin_ci.yml b/.github/workflows/linux_cuda_plugin_ci.yml index a9197b3732dd8..e88c6beff5280 100644 --- a/.github/workflows/linux_cuda_plugin_ci.yml +++ b/.github/workflows/linux_cuda_plugin_ci.yml @@ -141,3 +141,27 @@ jobs: cd /onnxruntime_src/onnxruntime/test/python/transformers python test_cuda_plugin_ep.py " + + # --- Run the CUDA plugin EP C++ GoogleTest binary --- + # onnxruntime_provider_test is built into the artifact and links the plugin tests + # (gated by ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP). The user-stream + CUDA graph test + # registers the plugin .so via GetSharedLibraryFileName("onnxruntime_providers_cuda_plugin"), + # which returns the platform-specific filename without a directory component. Run from + # /build/Release/Release so that filename resolves to the plugin .so built there. + - name: Run CUDA Plugin EP C++ Tests + run: | + docker run --rm --gpus all \ + -v ${{ github.workspace }}:/onnxruntime_src \ + -v ${{ runner.temp }}/Release:/build/Release \ + -e NVIDIA_VISIBLE_DEVICES=all \ + ${{ steps.build_docker_image_step.outputs.full-image-name }} \ + bash -c " + set -ex + export PATH=/opt/python/cp312-cp312/bin:\$PATH + # Make libcudart.so.13 (and the plugin's CUDA deps) findable; see note above. + export LD_LIBRARY_PATH=/build/Release/Release:/usr/local/cuda-13.0/lib64:\${LD_LIBRARY_PATH:-} + + cd /build/Release/Release + ls -la onnxruntime_provider_test libonnxruntime_providers_cuda_plugin.so + ./onnxruntime_provider_test --gtest_filter='CudaPluginUserStreamGraphTest.*' + " diff --git a/cmake/onnxruntime_providers_cuda_plugin.cmake b/cmake/onnxruntime_providers_cuda_plugin.cmake index 551e877d5f6d8..86e5579eb6761 100644 --- a/cmake/onnxruntime_providers_cuda_plugin.cmake +++ b/cmake/onnxruntime_providers_cuda_plugin.cmake @@ -88,10 +88,11 @@ list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/sequence_op\\.cc$") # in the CPU provider and is not linked into the plugin. list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/size\\.cc$") -# Permanently excluded — pure CPU ops, handled by GetCpuPreferredNodes. -# shape_op.cc inherits from onnxruntime::OpKernel (framework) -# which cannot convert to ep::adapter::OpKernel in the plugin build. -list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/tensor/shape_op\\.cc$") +# shape_op.cc is INCLUDED in the plugin build. It provides an adapter-based +# Shape kernel under #ifdef BUILD_CUDA_EP_AS_PLUGIN (the CPU onnxruntime::Shape +# class, which derives from the framework OpKernel, is only used in the +# non-plugin build). Registering Shape on the EP keeps it off the CPU EP and +# avoids Memcpy nodes that would otherwise break CUDA Graph capture. # Exclude contrib training ops (shrunken_gather depends on provider_api.h in header). list(FILTER CUDA_PLUGIN_EP_CC_SRCS EXCLUDE REGEX ".*/contrib_ops/cuda/tensor/shrunken_gather\\.cc$") diff --git a/docs/cuda_plugin_ep/arena_allocator_migration_design.md b/docs/cuda_plugin_ep/arena_allocator_migration_design.md index 285aa3e60ed5c..f082b444e10b0 100644 --- a/docs/cuda_plugin_ep/arena_allocator_migration_design.md +++ b/docs/cuda_plugin_ep/arena_allocator_migration_design.md @@ -62,7 +62,18 @@ if (!factory.arena_allocator_) { **Stream-aware allocation.** `ArenaImpl::AllocOnStream(size, stream)` tracks which chunks are assigned to which stream. `ResetChunksUsingStream(stream_impl)` is called from `OrtSyncStreamImpl::OnSessionRunEnd` to release chunk-to-stream assignments when a session run completes. -**Read-only allocator bypasses arena.** The factory creates a plain `CustomAllocator` (no arena) for `OrtReadOnlyAllocator` (initializers), since initializer memory doesn't benefit from arena allocation. +**Kernel-side consumption of the arena.** Migrated CUDA kernels obtain scratch/workspace memory from this arena through `CudaKernel::GetScratchBuffer`, which calls `Info().GetAllocator(OrtMemTypeDefault)`. Inside the plugin build that allocator is exposed to internal code as an `IAllocatorWrappingOrtAllocator` (`include/onnxruntime/ep/adapter/allocator.h`), which implements `IsStreamAware()`/`AllocOnStream()` by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` (ORT ≥ 1.23), falling back to plain `Alloc` otherwise. The plugin `GetScratchBuffer` deliberately passes a **null stream** to the arena rather than forwarding the kernel's compute stream. A plugin kernel only has the raw `cudaStream_t` (via `KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena persists in each chunk (`chunk->stream`) and later dereferences through the EP stream API (`SyncStream_GetImpl`/`SyncStream_GetSyncId`). Synthesizing a temporary framework `Stream` wrapper over the raw handle would be unsafe: it would dangle once `GetScratchBuffer` returns while the arena still holds the pointer, and it would be type-confused (a framework `Stream*` reinterpreted as an `OrtSyncStream*` that ORT never created for this stream). With a null stream the arena tracks scratch chunks as freely reusable (the same semantics as a plain non-stream-aware BFC arena). This is still what keeps scratch allocations served from already-reserved chunks during CUDA graph capture — capture stability comes from chunk reuse, not from stream tagging — and it is safe for the CUDA graph path, which runs on a single unified stream. + +#### Scratch buffer stream tagging — limitation and future work + +A common review question is: *"Passing a null stream to the scratch allocator looks wrong — won't it cause a synchronization issue? Shouldn't the scratch buffer use the same stream as the kernel?"* The short answer is that, for the path this code targets, it is correct and safe. The longer answer clarifies what the `stream` argument actually does and why forwarding the real stream is not currently possible. + +- **The `stream` argument is bookkeeping, not execution.** The stream passed to a stream-aware arena's `AllocOnStream()` is only metadata the arena uses to decide whether a *freed* chunk may be reused on a *different* stream without an intervening synchronization. It does **not** change where the kernel runs: the returned buffer is always consumed by the kernel on its real compute stream. So a null tag does not move work onto the default stream or skip any required sync — it only relaxes cross-stream chunk reuse. +- **Why null is safe here.** The scratch routing targets serialized runs and the CUDA graph path, which runs on a single **unified stream** when graph capture and a user compute stream are combined. On a single stream, alloc -> use -> free -> reuse are implicitly ordered by the stream itself, so there is never a second stream that could reuse a chunk while the first is still using it. A null-tagged ("freely reusable") chunk behaves exactly like a plain non-stream-aware BFC arena chunk, which is the correct behavior for one stream. Because null-tagged chunks are not safe for overlapping runs on different CUDA streams, the CUDA plugin EP does not advertise concurrent `Session::Run()` support until scratch chunks can be properly stream-tagged. +- **Why we cannot forward the real stream today (C-API limitation).** The stream-aware arena needs the framework `OrtSyncStream*` (`struct OrtSyncStream : public onnxruntime::Stream` in `core/framework/plugin_ep_stream.h`) — the ORT-core wrapper it stored in `chunk->stream`. A plugin kernel only has the raw `cudaStream_t`. `CudaSyncStream::FromCudaStream()` can recover the plugin-side `CudaSyncStream` (an `OrtSyncStreamImpl`), but that is a *different* object from the ORT-core `OrtSyncStream*` the arena expects; passing it (or a stack-allocated shim over the raw handle) would be both dangling and type-confused. +- **Future work.** To properly stream-tag scratch chunks — which only becomes necessary if this path is extended to support concurrent multi-stream runs sharing one arena — ORT needs new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels at dispatch time (e.g. via `KernelContext`). Until then, the null-stream tag is the correct and intentional choice. The matching code comment lives in `CudaKernel::GetScratchBuffer` (`onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h`). + + ### 2.2 How ORT Core Calls the Factory diff --git a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md index 8092a15e26973..1cac9464430dc 100644 --- a/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md +++ b/docs/cuda_plugin_ep/cuda_graph_for_cuda_plugin.md @@ -34,6 +34,7 @@ Session::Run() **Key design choices:** - Each thread gets its own dedicated graph `cudaStream_t`, `CudaGraphManager`, and capture bookkeeping for the EP instance. `CudaSyncStream::InitHandlesWithExternalStream()` wraps the thread's graph stream so graph capture sees the same stream as kernels. The manager stores captured `cudaGraphExec_t` executables keyed by annotation ID, allowing multiple graphs (e.g., different input shapes) for that thread. +- When a `user_compute_stream` is supplied together with graph capture, the per-thread context adopts that user-owned stream as its graph stream instead of creating one, so capture/replay run on the same stream the caller drives. The context records that it does not own the stream and never destroys it. See [User Compute Stream + CUDA Graph](#user-compute-stream--cuda-graph). - Warm-up runs (default: 2) allow memory allocations to stabilize before capture begins. - Graph annotation IDs are parsed from `OrtRunOptions` key `"gpu_graph_id"`. ID `-1` skips capture; `0` is the default. @@ -53,25 +54,50 @@ Session::Run() Legacy aliases `ep.cuda.enable_cuda_graph` and `enable_cuda_graph` are also supported. For the warm-up count, `ep.cuda.min_num_runs_before_cuda_graph_capture` is also accepted. +The provider option `user_compute_stream` (a `cudaStream_t` passed as a pointer) may be combined with `enable_cuda_graph`. See [User Compute Stream + CUDA Graph](#user-compute-stream--cuda-graph). + --- +## User Compute Stream + CUDA Graph + +A caller can supply its own CUDA stream through the `user_compute_stream` provider option and enable CUDA graph capture at the same time. This combination was previously rejected with `ORT_INVALID_ARGUMENT`; it is now supported and matches the bundled (non-plugin) CUDA EP. + +When both options are set: + +- `CudaEpFactory::CreateEpImpl` no longer rejects the pair. Setting `user_compute_stream` still forces unified-stream mode (matching the bundled EP). +- `CudaEp::CreateSyncStreamForDeviceImpl` wraps the user stream via `InitHandlesWithUserStream()`, attaching full cuBLAS/cuDNN/cuBLASLt handles to it. +- `CudaEp::GetPerThreadContext()` builds the thread's `PerThreadContext` around the user stream (`external_graph_stream`) instead of creating an EP-owned graph stream. Capture and replay therefore run on the same stream the kernels are issued to. +- The context records `owns_graph_stream = false`, so it tears down captured graph execs on destruction but never calls `cudaStreamDestroy` on the user-owned stream. Stream lifetime stays with the caller. + +Because the user supplies one stream, this mode is inherently single-stream; the per-thread graph isolation still applies if the same session is driven from multiple threads, but each thread must drive its own captures on the stream it provides. + +### `user_compute_stream` is not limited to the CUDA graph case + +A natural question when reading `GetPerThreadContext()` is why `use_external_stream` is gated on `has_user_compute_stream && enable_cuda_graph` — does that restrict a user compute stream to graph-enabled runs? It does not. + +- A user compute stream is honored for kernels in **both** graph and non-graph runs. That happens in `CudaEp::CreateSyncStreamForDeviceImpl`, whose first branch wraps `config_.user_compute_stream` via `InitHandlesWithUserStream()` **independently of `enable_cuda_graph`**. +- The `enable_cuda_graph` term in `use_external_stream` only governs the `PerThreadContext`'s *graph stream*. `PerThreadContext` is a graph-capture-only object: `GetPerThreadContext()` is reached exclusively from the graph path (the `enable_cuda_graph` branch of `CreateSyncStreamForDeviceImpl`, `OnRunStart`/`OnRunEnd`, `IsGraphCaptured`, `ReplayGraph`). With graph disabled, no `PerThreadContext` is ever constructed, so its stream-ownership flag is irrelevant. +- The flag therefore answers a narrower question — *"should the per-thread capture/replay graph stream adopt (and not destroy) the user's stream?"* — which is only meaningful when a graph is actually being captured. + ## Implementation Summary ### Files Changed | File | Change | |------|--------| -| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implemented graph capture callbacks (`OnRunStartImpl`, `OnRunEndImpl`, `IsGraphCaptureEnabledImpl`, `IsGraphCapturedImpl`, `ReplayGraphImpl`, `IsConcurrentRunSupportedImpl`), updated `CreateSyncStreamForDeviceImpl` to use the current thread's graph stream when graph capture is enabled, added per-thread graph state, preserved `sync_stream` synchronization, and added a `cudaMemGetInfo` defensive allocation check | +| `onnxruntime/core/providers/cuda/plugin/cuda_ep.cc` | Implemented graph capture callbacks (`OnRunStartImpl`, `OnRunEndImpl`, `IsGraphCaptureEnabledImpl`, `IsGraphCapturedImpl`, `ReplayGraphImpl`, `IsConcurrentRunSupportedImpl`), updated `CreateSyncStreamForDeviceImpl` to wrap a `user_compute_stream` or otherwise use the current thread's graph stream when graph capture is enabled, made `PerThreadContext` adopt the user stream as its (non-owned) graph stream when `user_compute_stream` + `enable_cuda_graph` are combined, added per-thread graph state, preserved `sync_stream` synchronization, and added a `cudaMemGetInfo` defensive allocation check | | `onnxruntime/core/providers/cuda/plugin/cuda_ep.h` | Added `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` config fields, graph callback declarations, and a per-thread graph context cache | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.cc` | **NEW** — Complete `CudaGraphSet` and `CudaGraphManager` implementation | | `onnxruntime/core/providers/cuda/plugin/cuda_graph_plugin.h` | **NEW** — Header for graph manager types and constants | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.cc` | Added `InitHandlesWithExternalStream()`, updated destructor for `owns_stream_` | | `onnxruntime/core/providers/cuda/plugin/cuda_stream_plugin.h` | Added `InitHandlesWithExternalStream()` declaration, `owns_stream_` member | -| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture` | +| `onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc` | Added config parsing for `enable_cuda_graph` and `min_num_runs_before_cuda_graph_capture`; removed the validation that rejected `user_compute_stream` + `enable_cuda_graph` (the combination is now supported) | +| `onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h` | `CudaKernel::GetScratchBuffer` now allocates through `Info().GetAllocator()` (the EP arena) with a null stream, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call, so scratch allocations are served from already-reserved arena chunks during capture | +| `include/onnxruntime/ep/adapter/allocator.h` | Implemented `IAllocatorWrappingOrtAllocator::IsStreamAware`/`AllocOnStream` (previously `ORT_NOT_IMPLEMENTED`) so plugin adapters can forward stream-aware allocations when a framework stream is available; `GetScratchBuffer` still passes a null stream until plugin kernels can receive a stable framework `OrtSyncStream*` | | `include/onnxruntime/core/session/onnxruntime_ep_c_api.h` | Added `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` callbacks and `OrtGraphCaptureNodeAssignmentPolicy` enum to `OrtEp` | | `include/onnxruntime/core/framework/execution_provider.h` | Added `GetGraphCaptureNodeAssignmentPolicy()` virtual to `IExecutionProvider` | | `onnxruntime/core/session/inference_session.cc` | Replaced hard-coded EP name list with policy-driven graph capture validation loop; added bounded recursion via `RunImpl()` with `kMaxGraphCaptureWarmupRuns`; graph-enabled runs now reacquire stream collections through ORT core's thread-affine pool across internal warm-up/capture recursion | -| `onnxruntime/core/framework/session_state.cc` | Sharded the `DeviceStreamCollection` cache by caller thread using per-thread lifetime tokens, so stream wrappers are only reused on the creating thread | +| `onnxruntime/core/framework/session_state.cc` | Sharded the `DeviceStreamCollection` cache by caller thread using per-thread lifetime tokens, so stream wrappers are only reused on the creating thread; added a fallback in the PrePack loop to resolve the kernel's default-memory allocator (`Info().GetAllocator()`) when the device-keyed initializer-allocator lookup returns null for a separately-registered plugin EP | | `onnxruntime/core/framework/session_state.h` | Added thread-affine stream pool bucket state for `DeviceStreamCollection` reuse | | `onnxruntime/core/session/inference_session.h` | Added `RunImpl()` private method and `kMaxGraphCaptureWarmupRuns` constant | | `onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc` | Added version-gated `IsGraphCaptureEnabled`, `IsGraphCaptured`, `ReplayGraph`, `GetGraphCaptureNodeAssignmentPolicy` bridge implementations | @@ -83,7 +109,8 @@ Legacy aliases `ep.cuda.enable_cuda_graph` and `enable_cuda_graph` are also supp - **Thread safety**: Mutable graph state and graph streams are stored per thread. ORT core's `DeviceStreamCollection` cache is also thread-affine, so graph-enabled runs can recycle stream wrappers without exposing them to a different thread. - **Scope**: Capture/replay pipeline plus allocator compatibility. Arena integration is complete — see the [Arena Allocator Integration](#arena-allocator-integration) section. - **Callback assignment**: `IsGraphCaptureEnabled` and `GetGraphCaptureNodeAssignmentPolicy` are always set. `OnRunStart`, `OnRunEnd` are conditional on `enable_cuda_graph`. `IsGraphCaptured` and `ReplayGraph` are always set (return false/error when disabled). -- **Stream management**: `CreateSyncStreamForDevice` remains unconditional — it branches internally to use the current thread's graph stream (via `InitHandlesWithExternalStream`) when graph capture is enabled, or creates an owned stream when disabled. +- **Stream management**: `CreateSyncStreamForDevice` remains unconditional — it branches internally: it wraps a user-provided `user_compute_stream` (via `InitHandlesWithUserStream`) when one is set, otherwise uses the current thread's graph stream (via `InitHandlesWithExternalStream`) when graph capture is enabled, or creates an owned stream when both are disabled. +- **User compute stream + CUDA graph**: These options can now be combined (previously rejected at factory creation). When both are set, `CudaEp::GetPerThreadContext()` builds the `PerThreadContext` around the user's stream (`external_graph_stream`) so capture and replay run on the same stream the kernels use, and the context never destroys the user-owned stream (`owns_graph_stream = false`). - **Run-end synchronization**: `OnRunEndImpl` honors the `sync_stream` flag without double-synchronizing replayed graphs, preserving the normal EP completion contract. - **Stream collection reuse**: ORT core now recycles `DeviceStreamCollection` objects into a thread-affine session pool keyed by a per-thread lifetime token. Warm-up, capture, replay, and later user-visible `Run()` calls on the same thread can reuse the same stream wrappers, while dead-thread buckets are pruned before they can be reused by another thread. - **Per-thread context lifecycle**: Thread-local caches hold the strong `PerThreadContext` references, so CUDA streams and captured graph executables are released when the owning thread exits. The EP tracks weak references to those cache maps to remove stale entries during EP destruction without keeping the contexts alive. @@ -101,6 +128,7 @@ CUDA graph capture requires that all memory allocations happen during warmup, no **Arena integration details (now implemented):** - Default CUDA device allocations come from the plugin-hosted arena (`CudaArenaAllocator`). During warmup runs, the arena grows to accommodate all needed chunks; during capture and replay, the same chunks are reused without `cudaMalloc` calls. +- Kernel scratch/workspace allocations (`CudaKernel::GetScratchBuffer`) also flow through the EP arena via `Info().GetAllocator()`, rather than issuing a fresh `cudaMallocAsync`/`cudaMalloc` per call. After warmup the arena has reached its steady-state working set, so the capture run serves every scratch request from an already-reserved chunk and the device free-memory footprint stays stable across the capture window. This is what makes the `cudaMemGetInfo` allocation-during-capture detector pass for graphs that use scratch buffers, and it matches the bundled CUDA EP (which also obtains scratch from `Info().GetAllocator()`). `GetScratchBuffer` passes a **null stream** to the arena. This is *not* a synchronization bug: the `stream` argument is only bookkeeping metadata the stream-aware arena uses to decide when a freed chunk may be reused on a *different* stream without a sync - it does not change where the kernel runs (the buffer is still consumed on the real compute stream). In a serialized run (and within one graph-capture run), alloc/free/reuse are implicitly ordered on that stream, so a null-tagged ("freely reusable") chunk is correct and safe. It is also currently the only safe option, because a plugin kernel only has the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` the stream-aware arena persists per chunk and later dereferences through the EP stream API; note that the ORT-core `OrtSyncStream` (`struct OrtSyncStream : public onnxruntime::Stream`) is a different object from the plugin's `CudaSyncStream` (an `OrtSyncStreamImpl`). Synthesizing a temporary `Stream*` over the raw handle would dangle after `GetScratchBuffer` returns and be type-confused, so scratch chunks are tracked with a null stream (freely reusable, like a plain BFC arena). Capture stability comes from chunk reuse, not stream tagging. Properly stream-tagging scratch chunks (required before this path can support concurrent multi-stream runs) is **future work** that requires new C-API surface to expose the framework `OrtSyncStream*` to plugin kernels — see [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). - When `arena.use_cuda_mempool=1` is configured, CUDA device allocations come from `CudaMempoolOrtAllocator`, which wraps `cudaMallocFromPoolAsync`/`cudaFreeAsync`. These async allocation/free operations are CUDA-graph-safe since CUDA 11.4+ and become part of the captured graph topology. - Pinned allocations are also arena-backed, but remain non-stream-aware. - The graph stream created by `CudaEp::PerThreadContext` flows through `CudaSyncStream::InitHandlesWithExternalStream()` so stream-aware arena allocation uses the same `cudaStream_t` during warm-up, capture, and replay. @@ -109,14 +137,12 @@ CUDA graph capture requires that all memory allocations happen during warmup, no ### Concurrent Run Support -Concurrent `Session::Run()` is supported with CUDA graph enabled: +Concurrent `Session::Run()` is intentionally **not** advertised by the CUDA plugin EP while migrated kernels route scratch/workspace allocations through the EP arena with a null stream tag. -- `CudaEp::PerThreadContext` owns the graph stream, graph manager, warm-up run counts, and memory watermark for the current thread. -- The current thread's cache owns the `PerThreadContext`; new threads get independent contexts, and exited threads release their contexts automatically. -- `CreateSyncStreamForDeviceImpl()` wraps the current thread's graph stream, so warm-up, capture, and replay all use the same stream for that thread. -- `CudaGraphManager::CaptureBegin()` uses `cudaStreamCaptureModeThreadLocal`, allowing overlapping capture scopes on different threads. -- ORT core recycles graph-enabled `DeviceStreamCollection` objects into a thread-affine session pool, so internal warm-up/capture recursion and later top-level `Run()` calls on the same thread reuse the same stream wrappers without cross-thread leakage. -- `IsGraphCaptured()` and `ReplayGraph()` resolve the current thread's graph context. If a new thread runs a graph-enabled session for the first time, that thread performs its own warm-up and capture before replaying. +- `CudaEp::PerThreadContext` still owns graph stream, graph manager, warm-up run counts, and memory watermark state per thread. This keeps graph bookkeeping thread-local and avoids sharing captured graph executables across threads. +- However, plugin kernels currently receive only the raw `cudaStream_t` (`KernelContext::GetGPUComputeStream`), not the framework `OrtSyncStream*` that the stream-aware arena stores in each chunk and later uses for safe cross-stream reuse checks. +- Because `CudaKernel::GetScratchBuffer` cannot safely provide that framework stream, it passes a null stream tag. Null-tagged scratch chunks are freely reusable, which is safe for serialized runs and single-unified-stream graph capture but unsafe for overlapping runs on different CUDA streams. +- Therefore `CudaEp::IsConcurrentRunSupportedImpl()` returns false. Re-enabling concurrent multi-stream runs is future work and requires new C-API surface to expose a stable framework stream (or equivalent sync id) to plugin kernels so scratch chunks can be properly stream-tagged. ## Verification @@ -128,7 +154,9 @@ Concurrent `Session::Run()` is supported with CUDA graph enabled: - `test_cuda_graph_with_mempool` — graph capture with `arena.use_cuda_mempool=1` - `test_cuda_graph_annotation_id` — multiple graphs via `gpu_graph_id` run config - `test_cuda_graph_add_model` — graph capture with Add op (arena-backed) +4. `onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc` is a C++ test (gated by `ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP`) covering `user_compute_stream` combined with `enable_cuda_graph`: it verifies session creation succeeds with both options set (regression for the removed validation), capture + replay on the user stream produce correct results, and replay after an in-place input update on the user stream is correct. ## Future Work 1. **Profiling integration**: CUDA graph replay currently bypasses the CUDA plugin EP profiler path because the CUDA plugin EP does not yet implement `OrtEp::CreateProfiler`. Wiring graph replay into that path is future work. +2. **Stream-tagged scratch allocations**: `CudaKernel::GetScratchBuffer` passes a null stream to the EP arena because plugin kernels cannot currently obtain the framework `OrtSyncStream*` the stream-aware arena needs (they only have the raw `cudaStream_t`). This is correct and safe for serialized runs and within one graph-capture run, but it is why the EP does not advertise concurrent `Session::Run()` support. Supporting concurrent multi-stream runs that share one arena would require new C-API surface to expose the framework `OrtSyncStream*` (or its sync-id) to plugin kernels so scratch chunks can be properly stream-tagged. See [arena_allocator_migration_design.md](arena_allocator_migration_design.md) ("Scratch buffer stream tagging — limitation and future work"). diff --git a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md index 2f61da90233b4..438fb8606fc09 100644 --- a/docs/cuda_plugin_ep/cuda_plugin_ep_design.md +++ b/docs/cuda_plugin_ep/cuda_plugin_ep_design.md @@ -459,9 +459,18 @@ The NHWC rollout is effectively in a "runtime enabled, cleanup remaining" state: | 2 | Cache the shim provider pointer in the adapter `OpKernelInfo` | Implemented; fixes the observed NHWC runtime crash | | 3 | Consolidate allowlists, improve internal-domain diagnostics, and strengthen structural NHWC assertions | Recommended follow-up work | +#### 5.3.2 Allocator Resolution for Kernels (Scratch and PrePack) + +Migrated kernels need a valid device allocator in two places: scratch/workspace buffers during `Compute()`, and one-time weight conversion or packing during `PrePack()`. Both now resolve the allocator the same way the bundled CUDA EP does, through the kernel's own `OpKernelInfo`. + +- **Scratch buffers.** `CudaKernel::GetScratchBuffer` allocates through `Info().GetAllocator(OrtMemTypeDefault)` (the EP arena) with a null stream tag, instead of issuing a raw `cudaMallocAsync`/`cudaMalloc` per call. The adapter `OpKernelInfo::GetAllocator` resolves the EP's default-memory (device) allocator and is always valid for a migrated kernel, so no plugin-only scratch path is needed. Routing through the arena is also what keeps the device free-memory footprint stable during CUDA graph capture (see [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#arena-allocator-integration)). The null stream tag is intentional: plugin kernels only have the raw `cudaStream_t`, not the framework `OrtSyncStream*` that the stream-aware arena persists in chunks for safe cross-stream reuse. +- **PrePack.** The framework prepack loop (`SessionState::PrepackConstantInitializedTensors`) resolves the allocator with `GetInitializerAllocator(kernel->Info().GetDevice(OrtMemTypeDefault))`, a session map keyed by device. For a plugin EP registered as a separate library, that device-keyed lookup can miss and return null. The loop now falls back to `kernel->Info().GetAllocator(OrtMemTypeDefault)` when the lookup is null, so every `PrePack` implementation receives a valid allocator at the single framework call site. This replaces the earlier approach of adding a per-kernel `if (!alloc) alloc = Info().GetAllocator(...)` guard to each prepacking op (which only covered the few ops that were touched and risked missing future ones). The fallback is behavior-neutral for in-tree EPs, whose device-keyed lookup already succeeds, and it does **not** force `is_packed`/`prepacked_weights` handling \u2014 ops such as `QMoE` and `MatMulNBits` still set `is_packed = true` and populate prepacked weights normally. + +The enabling adapter change is in [`include/onnxruntime/ep/adapter/allocator.h`](../../include/onnxruntime/ep/adapter/allocator.h): `IAllocatorWrappingOrtAllocator` now implements `IsStreamAware()`/`AllocOnStream()` (previously `ORT_NOT_IMPLEMENTED`) by forwarding to the underlying `OrtAllocator`'s `AllocOnStream` when it is available (ORT >= 1.23), falling back to plain `Alloc` otherwise. `GetScratchBuffer` does not use that stream-aware path yet because the plugin kernel layer cannot safely provide the framework `OrtSyncStream*`; stream-tagged scratch allocation is future work and is documented in [arena_allocator_migration_design.md](arena_allocator_migration_design.md#scratch-buffer-stream-tagging--limitation-and-future-work). + ### 5.4 CUDA Graph Support -CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and concurrent `Session::Run()` support. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and concurrent run details — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. +CUDA Graph capture/replay is fully implemented for the plugin EP, including arena integration (both default BFC arena and CUDA native mempool), multi-graph via annotation IDs with different input shapes, and combining a caller-supplied `user_compute_stream` with capture/replay. Concurrent `Session::Run()` is intentionally not advertised while scratch allocations are null-stream-tagged; supporting concurrent multi-stream runs requires future C-API work to expose a stable framework stream or sync id to plugin kernels. The full design — plugin-side implementation, per-thread isolation, arena integration, capture flow, and user-stream mode — is in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md). This section documents only the framework-level and C API changes that affect the broader ORT architecture. #### 5.4.1 OrtEp C API Extensions (v1.26) @@ -488,6 +497,10 @@ Session-level changes in `inference_session.cc`: - **Bounded recursion**: After each normal run when graph capture is enabled, the session recursively calls `RunImpl()` (bounded by `kMaxGraphCaptureWarmupRuns = 8`) until the graph is captured. From the user's perspective, a single `Run()` call handles the entire warm-up + capture sequence. - **Stream collection lifetime**: ORT core now caches `DeviceStreamCollection` objects in thread-affine session buckets keyed by a per-thread lifetime token. Graph-enabled runs recycle and reacquire stream wrappers only on the creating thread, which preserves warm-up/capture reuse without cross-thread leakage. +#### 5.4.3 User Compute Stream with CUDA Graph + +A caller-provided `user_compute_stream` may be combined with `enable_cuda_graph` (the factory previously rejected this pair). When both are set, `CudaEp::GetPerThreadContext()` builds the per-thread graph context around the user-owned stream rather than an EP-owned one, so capture and replay run on the same stream the kernels are issued to (matching the bundled CUDA EP). The context marks the stream as not owned and never destroys it. Details are in [cuda_graph_for_cuda_plugin.md](cuda_graph_for_cuda_plugin.md#user-compute-stream--cuda-graph). + --- ## 6. EP Adapter Layer (`include/onnxruntime/ep/adapter/`) diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h index 0db30f39b3f57..1fb78f81fce19 100644 --- a/include/onnxruntime/ep/adapter/allocator.h +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -41,21 +41,19 @@ class IAllocatorWrappingOrtAllocator final : public IAllocator { } bool IsStreamAware() const override { - return false; - - // TODO: Enable once AllocOnStream() is implemented. - // static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; - // const OrtAllocator* raw = ort_allocator_; - // return raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr; + static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; + const OrtAllocator* raw = ort_allocator_; + return raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr; } - void* AllocOnStream(size_t /*size*/, Stream* /*stream*/) override { - // TODO: Implement AllocOnStream(). - // The internal `onnxruntime::IAllocator::AllocOnStream` signature takes an internal `onnxruntime::Stream*` - // argument, while the public `::OrtAllocator::AllocOnStream` signature takes an `::OrtSyncStream*` argument. - // We need to properly map from one to the other. - // `::OrtSyncStream*` should be treated as an opaque type from the plugin EP's perspective. - ORT_NOT_IMPLEMENTED("IAllocatorWrappingOrtAllocator::AllocOnStream is not implemented yet."); + void* AllocOnStream(size_t size, Stream* stream) override { + static constexpr uint32_t kOrtAllocatorAllocOnStreamMinVersion = 23; + OrtAllocator* raw = ort_allocator_; + if (raw->version >= kOrtAllocatorAllocOnStreamMinVersion && raw->AllocOnStream != nullptr) { + return raw->AllocOnStream(raw, size, reinterpret_cast(stream)); + } + + return raw->Alloc(raw, size); } void GetStats(AllocatorStats* stats) override { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 6ef2319c1d3f4..ad92ddd797d3a 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -603,7 +603,16 @@ Status SessionState::PrepackConstantInitializedTensors( // within this session. Or if the weight is not present on disk, // we store the newly minted pre-packed data. - AllocatorPtr session_initializer_alloc = GetInitializerAllocator(kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + AllocatorPtr session_initializer_alloc = GetInitializerAllocator( + kernel->Info().GetDevice(OrtMemType::OrtMemTypeDefault)); + // A plugin EP registered as a separate library may not have an initializer + // allocator registered under the kernel's device key, so the lookup above can + // return null. Fall back to the kernel's own default-memory allocator (resolved + // through the EP), which is always valid. This keeps PrePack implementations from + // each having to special-case a null allocator at the library boundary. + if (!session_initializer_alloc) { + session_initializer_alloc = kernel->Info().GetAllocator(OrtMemType::OrtMemTypeDefault); + } PrePackedWeights weights_to_be_filled_in; // The reason we invoke PrePack() before looking into the container for any pre-packed weight // cached by another instance of the same op_type (for the same constant initializer) is because diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index edafbfb3ede65..dc53b02141207 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1361,9 +1361,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // cross-attention case; MEA handles it via the causal_from_top_left flag and Unified // Unfused uses past_kv_length=0. (When an external cache is present — nonpad_kv_seqlen — // the required frontier IS bottom-right, so Flash is eligible; see below.) - const bool causal_cross_no_past = parameters.is_causal && - parameters.q_sequence_length != parameters.total_sequence_length && - parameters.past_sequence_length == 0; + [[maybe_unused]] const bool causal_cross_no_past = + parameters.is_causal && + parameters.q_sequence_length != parameters.total_sequence_length && + parameters.past_sequence_length == 0; // is_causal=1 + nonpad_kv_seqlen (external KV cache) without past_key defines a // bottom-right causal frontier per onnx/onnx#8068: query in-block index i attends key j diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index d611366c8ad7e..13f7bfd7a40cf 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -83,20 +83,33 @@ void DestroyCudaStreamForDevice(cudaStream_t stream, int device_id) { } // namespace struct CudaEp::PerThreadContext { - explicit PerThreadContext(int device_id) + // When use_external_stream is true (user_compute_stream combined with CUDA graph), capture and + // replay happen on that user-owned stream so they see the same stream as the kernels; the + // context neither creates nor destroys it. Ownership is derived from the caller's intent rather + // than from external_stream being non-null, because a user may legitimately select the CUDA + // default stream (cudaStream_t(0), i.e. nullptr) as the compute stream — that is still an + // external, user-owned stream and must not be destroyed by the context. When use_external_stream + // is false the context creates and owns a dedicated graph stream. + explicit PerThreadContext(int device_id, bool use_external_stream = false, + cudaStream_t external_stream = nullptr) : device_id(device_id), - graph_stream(CreateCudaStreamForDevice(device_id)), + owns_graph_stream(!use_external_stream), + graph_stream(use_external_stream ? external_stream + : CreateCudaStreamForDevice(device_id)), cuda_graph(graph_stream) { } ~PerThreadContext() { // Destroy captured graph execs before destroying the stream they replay on. cuda_graph.Reset(); - DestroyCudaStreamForDevice(graph_stream, device_id); + if (owns_graph_stream) { + DestroyCudaStreamForDevice(graph_stream, device_id); + } graph_stream = nullptr; } int device_id; + bool owns_graph_stream; cudaStream_t graph_stream = nullptr; CudaGraphManager cuda_graph; size_t pre_capture_free_mem = 0; @@ -391,8 +404,14 @@ OrtStatus* ORT_API_CALL CudaEp::CreateSyncStreamForDeviceImpl( auto cuda_stream = std::make_unique(ep->factory_, device_id, this_ptr); - if (ep->config_.has_user_compute_stream && ep->config_.user_compute_stream != nullptr) { - // Wrap the user-provided external CUDA stream with full cuBLAS/cuDNN handles. + if (ep->config_.has_user_compute_stream) { + // A user-provided compute stream is honored for kernels regardless of whether CUDA graph + // capture is enabled - this branch is taken in both graph and non-graph runs. Use the caller's + // intent flag rather than checking the handle for non-null: cudaStream_t(0) / nullptr is the + // valid CUDA default stream and can be selected explicitly by the user. Wrap the external CUDA + // stream with full cuBLAS/cuDNN handles. When CUDA graph capture is also enabled, + // capture/replay run on this same user stream (see GetPerThreadContext), so kernels and graph + // capture share one stream. RETURN_IF_ERROR(cuda_stream->InitHandlesWithUserStream( static_cast(ep->config_.user_compute_stream))); } else if (ep->config_.enable_cuda_graph) { @@ -439,15 +458,20 @@ OrtStatus* ORT_API_CALL CudaEp::SyncImpl(OrtEp* this_ptr) noexcept { /*static*/ OrtStatus* ORT_API_CALL CudaEp::IsConcurrentRunSupportedImpl( OrtEp* this_ptr, bool* is_supported) noexcept { + ORT_UNUSED_PARAMETER(this_ptr); + if (is_supported == nullptr) { return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "is_supported must not be null."); } - auto* ep = static_cast(this_ptr); - // When a unified stream is in use (either from user_compute_stream, external - // allocator, or explicit use_ep_level_unified_stream), all operations share a - // single stream so concurrent runs are not safe. - *is_supported = !ep->config_.use_ep_level_unified_stream; + // Plugin kernels currently expose only the raw cudaStream_t to GetScratchBuffer(), not the + // framework OrtSyncStream* that the stream-aware arena needs to tag scratch chunks by stream. + // Scratch chunks are therefore allocated with a null stream tag and can be reused freely. That is + // safe when runs are serialized, but it is not safe to advertise concurrent Session::Run(): two + // runs on different CUDA streams could reuse the same scratch chunk while earlier work is still + // in flight. Re-enable concurrent runs only after the plugin kernel layer can pass a stable + // framework stream (or equivalent sync id) to the arena. + *is_supported = false; return nullptr; } @@ -467,7 +491,25 @@ CudaEp::PerThreadContext& CudaEp::GetPerThreadContext() const { return *cached_context_it->second; } - auto context = std::make_shared(config_.device_id); + // NOTE: `enable_cuda_graph` in this condition does NOT restrict using a user compute stream to + // the graph case. A user compute stream is honored for kernels in BOTH graph and non-graph runs + // — that happens in CreateSyncStreamForDeviceImpl(), which wraps config_.user_compute_stream + // independently of enable_cuda_graph. This flag only governs the PerThreadContext's *graph + // stream*, and PerThreadContext is a graph-capture-only object: GetPerThreadContext() is reached + // exclusively from the graph path (CreateSyncStreamForDeviceImpl's enable_cuda_graph branch, + // OnRunStart/OnRunEnd, IsGraphCaptured, ReplayGraph). With graph disabled, no PerThreadContext is + // ever constructed, so its stream ownership is irrelevant. + // + // When a user compute stream IS combined with CUDA graph capture, capture/replay must run on the + // user's stream (the same stream the kernels are issued to) rather than a separate EP-owned + // stream. The user owns the stream's lifetime, so the context must not destroy it. Derive this + // from the caller's intent (has_user_compute_stream && enable_cuda_graph), not from whether the + // handle is null: a user may explicitly choose the CUDA default stream (nullptr), which is still + // an external stream that the context must not own/destroy. + const bool use_external_stream = config_.has_user_compute_stream && config_.enable_cuda_graph; + cudaStream_t external_stream = + use_external_stream ? static_cast(config_.user_compute_stream) : nullptr; + auto context = std::make_shared(config_.device_id, use_external_stream, external_stream); PerThreadContext& context_ref = *context; { std::lock_guard lock(per_thread_contexts_mutex_); diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc index cb021034662e8..d445d8bab033c 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep_factory.cc @@ -609,12 +609,9 @@ OrtStatus* ORT_API_CALL CudaEpFactory::CreateEpImpl( "CUDA plugin EP does not support using both user_compute_stream and external allocator simultaneously."); } - // Validate: user_compute_stream and cuda graph cannot both be active. - if (config.has_user_compute_stream && config.enable_cuda_graph) { - return factory->ort_api_.CreateStatus( - ORT_INVALID_ARGUMENT, - "CUDA plugin EP does not support using both user_compute_stream and enable_cuda_graph simultaneously."); - } + // user_compute_stream and enable_cuda_graph CAN be combined: when both are set, CUDA graph + // capture/replay runs on the user-provided stream (the same stream kernels are issued to), + // matching the bundled CUDA EP behavior. See CudaEp::GetPerThreadContext. // When user_compute_stream is set, force unified stream mode (matches bundled EP behavior). if (config.has_user_compute_stream) { diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 021acaf142435..f134c599d5b46 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -1023,56 +1023,60 @@ class CudaKernel : public OpKernel { template using IAllocatorUniquePtr = std::unique_ptr>; template - inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* s) const { + inline IAllocatorUniquePtr GetScratchBuffer(size_t cnt, void* /*stream*/) const { if (cnt == 0) return IAllocatorUniquePtr(nullptr, [](T*) {}); - size_t sz = 0; - if (!detail::TryBytesForCount(cnt, detail::SizeOf::value, sz)) { - ORT_THROW("CUDA scratch buffer allocation size overflow for ", cnt, " elements"); - } - void* p = nullptr; - cudaError_t alloc_result = cudaSuccess; - bool used_async_alloc = false; - if (s) { - // Note: stream-ordered allocations (cudaMallocAsync/cudaFreeAsync) rely on CUDA Memory Pools, - // which are not supported on NVIDIA GPUs with Multi-Instance GPU (MIG) enabled. - // On such instances, this will return cudaErrorNotSupported. - alloc_result = cudaMallocAsync(&p, sz, static_cast(s)); - used_async_alloc = (alloc_result == cudaSuccess); - if (!used_async_alloc && (alloc_result == cudaErrorNotSupported || alloc_result == cudaErrorInvalidValue)) { - cudaGetLastError(); // Clear the thread-local error state - alloc_result = cudaMalloc(&p, sz); - } - } else { - alloc_result = cudaMalloc(&p, sz); - } - - if (alloc_result != cudaSuccess) { - ORT_THROW("CUDA scratch buffer allocation failed for ", sz, " bytes: ", cudaGetErrorString(alloc_result)); - } - - return IAllocatorUniquePtr(static_cast(p), [s, used_async_alloc](T* ptr) { - if (ptr) { - // Guard: only attempt async free if the stream is still registered. - // CudaSyncStream::~CudaSyncStream guarantees UnregisterStream() is - // called before cudaStreamDestroy(), so a non-null lookup here means - // the raw cudaStream_t handle is still valid. - if (used_async_alloc && s && - cuda_plugin::CudaSyncStream::FromCudaStream(static_cast(s)) != nullptr) { - // As noted above, cudaFreeAsync may also return cudaErrorNotSupported on MIG-enabled instances. - cudaError_t free_result = cudaFreeAsync(ptr, static_cast(s)); - if (free_result == cudaSuccess) { - return; - } - cudaGetLastError(); // Clear any error set by cudaFreeAsync - } - - // Fall back to synchronous free if async free is unsupported or if the - // stream is no longer registered. cudaFree is valid for allocations - // returned by cudaMallocAsync and avoids using a stale stream handle. - cudaFree(ptr); - } - }); + // Route kernel scratch/workspace allocations through the EP allocator + // (a BFC arena by default) instead of raw cudaMallocAsync/cudaMalloc. + // + // The arena pre-reserves device memory and reuses freed chunks across runs. + // Once the model has executed `min_num_runs_before_cuda_graph_capture` + // warmup runs, the arena has grown to its steady-state working set, so the + // capture run serves every scratch allocation from an already-reserved chunk + // without issuing a fresh cudaMalloc. This keeps the device free-memory + // footprint stable across the capture window, which is required for correct + // CUDA graph capture/replay. + // + // The previous behavior (cudaMallocAsync/cudaMalloc allocated-and-freed per + // call) allocated new device memory on every run, including the capture run, + // so no amount of warmup could stabilize it and the + // "GPU memory was allocated during CUDA graph capture" detector would trip. + // This now matches the built-in (non-plugin) CUDA EP, which also obtains + // scratch from Info().GetAllocator() (see core/providers/cuda/cuda_kernel.h). + // The overflow check that the previous hand-rolled path performed is still + // enforced inside MakeUniquePtr via ValidatedCalcMemSizeForArray (it throws + // on cnt * sizeof(T) overflow). + // + // The compute stream is intentionally NOT forwarded to the allocator here. This is a + // bookkeeping decision, NOT a synchronization bug: the `stream` argument to a stream-aware + // arena is only metadata used to decide when a freed chunk may be reused on a *different* + // stream without an intervening sync. It does not change where the kernel runs - the returned + // buffer is still consumed by the kernel on the real compute stream. In a serialized run (and + // within one graph-capture run), alloc/free/reuse ordering is implicit on that stream, so there + // is no cross-stream chunk to race on. Tagging chunks with a null stream (freely reusable, the + // same semantics as a plain non-stream-aware BFC arena) is therefore correct and safe as long + // as the EP does not advertise concurrent Session::Run() support. + // + // It is also currently the only safe option, because of a C-API type constraint: a plugin + // kernel only has the raw cudaStream_t (KernelContext::GetGPUComputeStream), not the framework + // OrtSyncStream* that the stream-aware arena persists in each chunk (CudaArena stores + // `chunk->stream` and later dereferences it through the EP stream API, e.g. + // SyncStream_GetImpl/SyncStream_GetSyncId). Note that OrtSyncStream (the ORT-core wrapper, + // `struct OrtSyncStream : public onnxruntime::Stream`) is a DIFFERENT object from the plugin's + // CudaSyncStream (an OrtSyncStreamImpl); CudaSyncStream::FromCudaStream() recovers the latter, + // not the former. Wrapping the raw handle in a temporary framework Stream shim and passing it + // down would be unsafe on two counts: (1) the shim is stack-allocated and would dangle after + // this function returns while the arena still holds the pointer, and (2) it is type-confused — + // the arena would reinterpret a framework Stream* as an OrtSyncStream* that was never created + // by ORT for this stream. + // + // Properly stream-tagging scratch chunks (needed before this path can support concurrent + // multi-stream runs) requires new C-API surface to expose the framework OrtSyncStream* to + // plugin kernels. See docs/cuda_plugin_ep/arena_allocator_migration_design.md ("Scratch buffer + // stream tagging") for the limitation and future work. + return ::onnxruntime::IAllocator::MakeUniquePtr( + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), cnt, /*use_reserve*/ false, + /*stream*/ nullptr); } template inline IAllocatorUniquePtr GetTransientScratchBuffer(size_t cnt) const { diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index 230b0b495bfbf..0202789a8777d 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -2,12 +2,87 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" +#ifndef BUILD_CUDA_EP_AS_PLUGIN #include "core/providers/cpu/tensor/shape_op.h" +#endif #include "core/providers/cuda/cuda_fwd.h" +#ifdef BUILD_CUDA_EP_AS_PLUGIN +#include +#endif + namespace onnxruntime { namespace cuda { +#ifdef BUILD_CUDA_EP_AS_PLUGIN +// The bundled CUDA EP registers the CPU `onnxruntime::Shape` kernel (which +// derives from the framework `OpKernel`) and only marks its output as CPU +// memory. That class cannot be registered through the plugin EP's adapter +// kernel machinery, so the plugin build provides an adapter-based Shape kernel +// with identical semantics. Shape only reads the input's shape metadata (never +// its data) and writes the dims to a CPU output, so registering it on the CUDA +// EP keeps the node inside the device partition and avoids the device->host +// Memcpy node that the framework would otherwise insert to feed an isolated CPU +// Shape node -- a Memcpy that would prevent CUDA Graph capture. The output is +// still CPU memory, so a downstream device consumer may still need a copy; this +// removes the graph-breaking input-side Memcpy, it does not eliminate all copies. +class Shape final : public CudaKernel { + public: + explicit Shape(const OpKernelInfo& info) : CudaKernel(info) { + info.GetAttrOrDefault("start", &start_index_, 0); + + if (start_index_ != 0) { + // "start" is provided and is non-default (default is 0) + needs_slicing_ = true; + } + + if (info.GetAttr("end", &end_index_).IsOK()) { + needs_slicing_ = true; + } + } + + // Takes a tensor as input and outputs a 1D int64 tensor (on CPU memory) + // containing the shape of the input tensor. + Status ComputeInternal(OpKernelContext* context) const override { + const auto* input = context->Input(0); + const TensorShape& input_shape = input->Shape(); + + int64_t rank = static_cast(input_shape.NumDimensions()); + + if (!needs_slicing_) { // vanilla use of Shape (no slicing) + Tensor* output = context->Output(0, {rank}); + input_shape.CopyDims(output->MutableData(), static_cast(rank)); + } else { // slicing is needed + int64_t true_start = start_index_; + int64_t true_end = end_index_; + + // Deal with negative(s) and clamp + true_start = true_start < 0 ? true_start + rank : true_start; + true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start); + + true_end = true_end < 0 ? true_end + rank : true_end; + true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end); + + auto slice_length = true_end - true_start; + Tensor* output = context->Output(0, {slice_length < 0 ? 0 : slice_length}); + + if (slice_length > 0) { + input_shape.CopyDims(output->MutableData(), + static_cast(true_start), + static_cast(slice_length)); + } + } + + return Status::OK(); + } + + private: + bool needs_slicing_ = false; + int64_t start_index_ = 0; + int64_t end_index_ = std::numeric_limits::max(); +}; +#endif // BUILD_CUDA_EP_AS_PLUGIN + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Shape, kOnnxDomain, diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc new file mode 100644 index 0000000000000..d49faf3c90ea8 --- /dev/null +++ b/onnxruntime/test/providers/cuda/plugin/cuda_plugin_user_stream_graph_test.cc @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Tests that the CUDA plugin EP supports combining a user-provided compute stream +// (user_compute_stream) with CUDA graph capture/replay (enable_cuda_graph). +// +// Historically the plugin EP rejected this combination with ORT_INVALID_ARGUMENT. +// It now captures and replays the CUDA graph on the user-provided stream (the same +// stream the kernels are issued to), matching the bundled CUDA EP behavior. These +// tests verify: +// 1. Session creation succeeds with both options set (regression for the removed +// validation). +// 2. Capture + replay on the user stream produce correct results. +// 3. Replay after an in-place input update (on the user stream) is correct. + +#if defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/file_util.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { +namespace { + +constexpr const char* kCudaPluginEpRegistrationName = "CudaPluginUserStreamGraphTest"; +constexpr const char* kCudaPluginEpName = "CudaPluginExecutionProvider"; + +// Resolve the CUDA plugin EP shared library path. +std::filesystem::path GetCudaPluginLibraryPath() { + return GetSharedLibraryFileName(ORT_TSTR("onnxruntime_providers_cuda_plugin")); +} + +// RAII handle that registers/unregisters the CUDA plugin EP library. +class ScopedCudaPluginRegistration { + public: + ScopedCudaPluginRegistration(Ort::Env& env, const char* registration_name) + : env_(env), name_(registration_name) { + auto lib_path = GetCudaPluginLibraryPath(); + if (!std::filesystem::exists(lib_path)) { + available_ = false; + return; + } + env_.RegisterExecutionProviderLibrary(name_.c_str(), lib_path.c_str()); + available_ = true; + } + + ~ScopedCudaPluginRegistration() { + if (available_) { + try { + env_.UnregisterExecutionProviderLibrary(name_.c_str()); + } catch (...) { + } + } + } + + bool IsAvailable() const { return available_; } + + ScopedCudaPluginRegistration(const ScopedCudaPluginRegistration&) = delete; + ScopedCudaPluginRegistration& operator=(const ScopedCudaPluginRegistration&) = delete; + + private: + Ort::Env& env_; + std::string name_; + bool available_ = false; +}; + +// Find the CUDA plugin EP device after registration. +Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { + auto ep_devices = env.GetEpDevices(); + for (const auto& device : ep_devices) { + if (strcmp(device.EpName(), kCudaPluginEpName) == 0) { + return device; + } + } + return Ort::ConstEpDevice{nullptr}; +} + +} // namespace + +class CudaPluginUserStreamGraphTest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "No CUDA device available."; + } + + registration_ = std::make_unique( + *ort_env, kCudaPluginEpRegistrationName); + if (!registration_->IsAvailable()) { + GTEST_SKIP() << "CUDA plugin EP library not found."; + } + + cuda_device_ = FindCudaPluginDevice(*ort_env); + if (!cuda_device_) { + GTEST_SKIP() << "No CUDA plugin EP device found after registration."; + } + } + + void TearDown() override { + registration_.reset(); + cudaDeviceSynchronize(); + } + + // Build session options that select the plugin EP with CUDA graph capture enabled + // and the user-provided stream supplied as a pointer-sized address string. + Ort::SessionOptions CreateUserStreamGraphSessionOptions(cudaStream_t user_stream) { + Ort::SessionOptions so; + std::unordered_map provider_options = { + {"enable_cuda_graph", "1"}, + {"user_compute_stream", + std::to_string(reinterpret_cast(user_stream))}, + }; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, provider_options); + return so; + } + + std::unique_ptr registration_; + Ort::ConstEpDevice cuda_device_{nullptr}; +}; + +// Regression: creating a session with both user_compute_stream and enable_cuda_graph +// used to fail with ORT_INVALID_ARGUMENT. It must now succeed. +TEST_F(CudaPluginUserStreamGraphTest, SessionCreatesWithUserStreamAndCudaGraph) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + { + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + ASSERT_NO_THROW({ + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + (void)session; + }); + } + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +// Full capture + replay on the user stream, including replay after an in-place input +// update. mul_1.onnx computes Y = X * W with W = [1, 2, 3, 4, 5, 6] (shape 3x2). +TEST_F(CudaPluginUserStreamGraphTest, CaptureAndReplayOnUserStream) { + cudaStream_t user_stream = nullptr; + ASSERT_EQ(cudaSuccess, cudaStreamCreate(&user_stream)); + + Ort::SessionOptions so = CreateUserStreamGraphSessionOptions(user_stream); + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), so); + + // Device allocator backing the plugin EP's default memory. + auto device_memory_info = cuda_device_.GetMemoryInfo(OrtDeviceMemoryType_DEFAULT); + auto allocator = ort_env->GetSharedAllocator(device_memory_info); + ASSERT_NE(allocator, nullptr); + + constexpr size_t kNumElements = 6; + constexpr size_t kBytes = kNumElements * sizeof(float); + const std::array shape = {3, 2}; + const std::array w_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Pre-allocate device input/output buffers (required for CUDA graph IO binding). + void* input_gpu = allocator.Alloc(kBytes); + void* output_gpu = allocator.Alloc(kBytes); + ASSERT_NE(input_gpu, nullptr); + ASSERT_NE(output_gpu, nullptr); + + auto upload_input = [&](const std::array& host_values) { + ASSERT_EQ(cudaSuccess, + cudaMemcpyAsync(input_gpu, host_values.data(), kBytes, + cudaMemcpyHostToDevice, user_stream)); + }; + + auto read_output = [&](std::array& host_values) { + // Kernels run on the user stream; wait for them before copying the result back. + ASSERT_EQ(cudaSuccess, cudaStreamSynchronize(user_stream)); + ASSERT_EQ(cudaSuccess, + cudaMemcpy(host_values.data(), output_gpu, kBytes, cudaMemcpyDeviceToHost)); + }; + + Ort::Value input_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(input_gpu), kNumElements, + shape.data(), shape.size()); + Ort::Value output_tensor = Ort::Value::CreateTensor( + device_memory_info, reinterpret_cast(output_gpu), kNumElements, + shape.data(), shape.size()); + + Ort::IoBinding binding(session); + binding.BindInput("X", input_tensor); + binding.BindOutput("Y", output_tensor); + + // First run: warmup + capture + first replay on the user stream. + const std::array x0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + upload_input(x0); + session.Run(Ort::RunOptions{}, binding); + + std::array y{}; + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x0[i] * w_values[i]) << "capture mismatch at " << i; + } + + // Second run: pure graph replay (same inputs) on the user stream. + session.Run(Ort::RunOptions{}, binding); + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x0[i] * w_values[i]) << "replay mismatch at " << i; + } + + // Update the input in place on the user stream and replay again. + const std::array x1 = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; + upload_input(x1); + session.Run(Ort::RunOptions{}, binding); + read_output(y); + for (size_t i = 0; i < kNumElements; ++i) { + EXPECT_FLOAT_EQ(y[i], x1[i] * w_values[i]) << "updated-input replay mismatch at " << i; + } + + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); + allocator.Free(input_gpu); + allocator.Free(output_gpu); + + ASSERT_EQ(cudaSuccess, cudaStreamDestroy(user_stream)); +} + +} // namespace test +} // namespace onnxruntime + +#endif // ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP From d4b01e437367deba5a0b1bcc5fa457e82c2388a5 Mon Sep 17 00:00:00 2001 From: FuZoe <147491646+FuZoe@users.noreply.github.com> Date: Thu, 25 Jun 2026 10:10:08 +0800 Subject: [PATCH 06/10] Fix CPU Attention causal mask alignment (#29050) ## Summary - align CPU ONNX Attention causal masking with upper-left behavior for q_len=1, kv_len>1, no past - preserve the existing `nonpad_kv_seqlen` / TensorScatter single-query causal behavior - update Python attention reference causal mask to model ONNX upper-left alignment with an explicit past offset - add a regression test for issue #29020 Fixes #29020 ## Validation - `python -m py_compile onnxruntime/test/python/transformers/test_onnx_attention/common.py onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py` - `git diff --check` Notes: - `pytest onnxruntime/test/python/transformers/test_onnx_attention/test_tensorscatter_attention.py -k "cpu_fp32 and causal" -q` could not run locally because this Python environment does not have `onnx` / `onnxruntime` installed. - After the latest follow-up commit, an incremental rebuild of `onnxruntime_provider_test` was attempted but failed in MSBuild before compiling this change due to a local environment issue: duplicate `Path` / `PATH` environment keys when launching `CL.exe`. --- .../core/providers/cpu/llm/attention.cc | 10 ++- .../providers/cpu/llm/attention_op_test.cc | 73 ++++++++++++++++++- .../test_onnx_attention/common.py | 11 +-- .../test_onnx_attention/test_gqa.py | 2 + .../test_onnx_attention/test_mha.py | 9 ++- 5 files changed, 97 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 1b0f37feab9fb..150fe7b478c1e 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -347,7 +347,15 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, T* mask_data = nullptr; bool delete_mask_data = false; - bool causal = parameters.is_causal && parameters.q_sequence_length > 1; + // In the nonpad_kv_seqlen path, q_len=1 is external KV-cache decode with + // bottom-right alignment. The single query's causal frontier is the valid + // length, so nonpad masking alone leaves exactly all valid keys visible. + // Keep causal=false for that case to avoid applying the batch-shared + // upper-left overlay used by the no-nonpad path. + bool causal = parameters.is_causal && + (parameters.has_nonpad_kv_seqlen + ? parameters.q_sequence_length > 1 + : !(parameters.q_sequence_length == 1 && parameters.past_sequence_length > 0)); // When nonpad_kv_seqlen is present the causal frontier is offset-aware // (bottom-right) and per-batch, so it cannot be baked into the batch-shared mask // buffer here; it is applied per-batch in the main loop below. Skip the top-left diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index cf76ca0fa00f8..03a47664c632a 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -2332,6 +2332,45 @@ TEST(AttentionTest, Attention_Causal_NonPadKVSeqLen_Decode_BottomRight) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// q_len=1 with nonpad_kv_seqlen keeps bottom-right decode behavior: the single +// query attends every valid key. If the no-nonpad upper-left overlay is applied +// here, this returns 1.0 instead of the expected 1/6. +TEST(AttentionTest, Attention_Causal_NonPadKVSeqLen_SingleQueryKeepsBottomRight_CPU) { + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + test.AddAttribute("is_causal", static_cast(1)); + + constexpr int batch_size = 1; + constexpr int q_num_heads = 1; + constexpr int kv_num_heads = 1; + constexpr int q_sequence_length = 1; + constexpr int kv_sequence_length = 6; + constexpr int head_size = 8; + + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::vector v(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::fill_n(v.begin(), head_size, 1.0f); + + test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, q); + test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, k); + test.AddInput("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, v); + test.AddOptionalInputEdge(); // attn_mask + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + test.AddInput("nonpad_kv_seqlen", {batch_size}, {kv_sequence_length}); + + std::vector expected_y(batch_size * q_num_heads * q_sequence_length * head_size, + 1.0f / static_cast(kv_sequence_length)); + test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, head_size}, expected_y, false, 0, + 1e-4f); + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + // Continued / chunked prefill (S_q=2) into a partially-filled static cache. // nonpad=[4], S_q=2 -> offset = 4 - 2 = 2: query 0 attends keys {0,1,2}, query 1 // attends {0,1,2,3}. The old top-left alignment would mask everything past the @@ -3087,9 +3126,41 @@ TEST(AttentionTest, Attention4DSoftCapOutputQkRawLogits) { // ============================================================================ // Causal alignment tests: verify upper-left (no past) vs lower-right (with past) -// These are CUDA-only tests that validate the causal masking fix. +// These tests validate causal mask alignment across CPU and CUDA. // ============================================================================ +// Test: Causal + cross-attention (S_q=1, S_kv=6, no past) +// ONNX spec mandates upper-left alignment: q0 attends only to kv[0]. +// This covers GitHub issue #29020, where CPU skipped causal masking for S_q=1. +TEST(AttentionTest, Attention4DCausalSingleQueryCrossAttentionUpperLeft) { + int batch_size = 1; + int q_num_heads = 1; + int q_sequence_length = 1; + int head_size = 8; + int kv_sequence_length = 6; + int kv_num_heads = 1; + int v_head_size = 8; + int past_sequence_length = 0; + + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.0f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.0f); + std::vector v(batch_size * kv_num_heads * kv_sequence_length * v_head_size, 0.0f); + for (int i = 0; i < v_head_size; ++i) { + v[i] = 1.0f; + } + std::vector y(batch_size * q_num_heads * q_sequence_length * v_head_size, 1.0f); + + RunTest4D(batch_size, q_num_heads, q_sequence_length, head_size, kv_sequence_length, kv_num_heads, + v_head_size, past_sequence_length, + q, k, v, std::vector(), std::initializer_list(), + std::vector(), std::vector(), + 1, -1, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -1, + TensorType::kFloat, + y, std::vector(), std::vector(), std::vector(), + false, false, true // disable_cpu, disable_cuda, disable_dml + ); +} + // Test: Causal + cross-attention (S_q=3, S_kv=5, no past) // ONNX spec mandates upper-left alignment: q_i attends to kv[0..i]. // V is identity-like so output directly reveals which KV positions were attended. diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/common.py b/onnxruntime/test/python/transformers/test_onnx_attention/common.py index 349f68f0ac5ce..d1a5d833d12a4 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/common.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/common.py @@ -666,12 +666,11 @@ def attention_past_func( # ################################################################################################# -def construct_causal_mask(seqlen_q, seqlen_k, device): - """Construct a causal mask for attention.""" +def construct_causal_mask(seqlen_q, seqlen_k, device, past_seqlen=0): + """Construct a causal mask for ONNX Attention upper-left alignment.""" row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - # Causal: positions can only attend to earlier positions - return col_idx > row_idx + seqlen_k - seqlen_q + return col_idx > row_idx + past_seqlen def attention_ref( @@ -682,6 +681,7 @@ def attention_ref( attn_bias=None, causal=False, softcap=0.0, + past_seqlen=0, ): """ Reference implementation of scaled dot-product attention with GQA support. @@ -694,6 +694,7 @@ def attention_ref( attn_bias: Additive attention bias [broadcastable to batch, num_heads, seq_q, seq_k] causal: Whether to apply causal masking softcap: Softcap value for attention scores (0.0 = disabled) + past_seqlen: Number of past K/V tokens before q[0] for causal masking. Returns: output: Attention output [batch, seq_q, num_heads, head_size] @@ -724,7 +725,7 @@ def attention_ref( scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - causal_mask = construct_causal_mask(seqlen_q, seqlen_k, q.device) + causal_mask = construct_causal_mask(seqlen_q, seqlen_k, q.device, past_seqlen=past_seqlen) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index 542094a9ac4ee..eca86e429b597 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -289,6 +289,7 @@ def parity_check_gqa_past( key_padding_mask=key_padding_mask, causal=causal, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -548,6 +549,7 @@ def parity_check_gqa_past_with_padding( key_padding_mask=key_padding_mask, causal=config.is_causal == 1, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) # --- ONNX Runtime Path --- diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py index e2d9acbd0c500..16a00085f224f 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py @@ -298,6 +298,7 @@ def parity_check_mha_past( attn_bias=attn_bias_ref, causal=causal, softcap=config.softcap, + past_seqlen=config.past_kv_sequence_length, ) out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() @@ -2393,7 +2394,13 @@ def test_mha_past_asymmetric_v_head_size(self): full_k_bsnh = full_k_bnsh.transpose(1, 2) full_v_bsnh = full_v_bnsh.transpose(1, 2) - out_ref, _ = attention_ref(q=q, k=full_k_bsnh, v=full_v_bsnh, causal=True) + out_ref, _ = attention_ref( + q=q, + k=full_k_bsnh, + v=full_v_bsnh, + causal=True, + past_seqlen=config.past_kv_sequence_length, + ) # ORT path — should fall back to unfused (not crash in MEA) out_ort, present_k, present_v = attention_past_func( From 92b4c663496d575c3e22f1e5cfee922b35a6962e Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 26 Jun 2026 13:32:43 +0800 Subject: [PATCH 07/10] webgpu: Enable FlashAttention for batched GQA with right-padded prompts (#29247) ## Summary Lift WebGPU FlashAttention's `batch_size == 1` restriction so batched GQA with right-padded prompts (the common GenAI batched-prefill shape) takes the fused FlashAttention path instead of falling back to `ApplyAttention`. - **Per-batch seqlens in FlashAttention shaders.** Prefill, decode split-reduce, CopyKVCache, and the fused rotary-and-copyKV template now read `seqlens_k[batch_idx]` instead of hardcoding `seqlens_k[0]`. All `past_X = total_X - new_X` subtractions are clamped to avoid u32 underflow when a short batch's per-batch total is less than the batch-wide `sequence_length`. - **Indirect-dispatch sizing uses GQA's `total_sequence_length` input.** `CopyKVCache`, `SplitPackedQKVWithRotaryEmbeddingAndCopyKV`, and `FlashAttentionDecodeQKV` now take a new `total_sequence_length_input` binding (GQA input #6, GPU-resident under graph capture) for the indirect-dispatch grid sizing. This is the global max KV span across the batch by construction, replacing the previous `seqlens_k[0] + 1u` that under-dispatched whenever batch 0 wasn't the longest. Per-batch `seqlens_k[batch] + 1` still drives causal masking and K/V bounds inside the kernels. GQA now enforces `graph_capture_enabled -> past_present_share_buffer_` so the host-side `use_indirect_dispatch` predicate stays simple. - **Decoupled attention_bias stride from per-batch OOB.** `attention_bias` is still allocated to the global max `total_sequence_length`; only the causal-mask / softmax tile loops are gated by the per-batch total. The one-past-end fallback was tightened to clamp inside the same row (`offset_base + stride_total_seq - 1u`). - **Decode workgroup grid stays at global max.** `decode_qkv` keeps a workgroup grid sized to the global max tile count to keep `workgroup_idx` slicing consistent across batches, with neutral `(-inf, 0)` early-exit for tiles beyond a short batch's per-batch total so the `VxReduce` online softmax rescaling is not skewed. - **New `use_seqlen_k` template parameter** (separate from `use_indirect_dispatch` which still requires graph capture). It is enabled whenever `seqlen_k` is provided and (`graph_capture || batch_size_ > 1`). - **Rotary fix prerequisite** (`webgpu: fix GQA batched right-padded prefill with do_rotary`, 591df5b1): clamps `past_seqlen` to 0 in `RotaryEmbeddingProgram`, `FusedQKRotaryEmbeddingProgram`, and `split_packed_qkv_with_rotary_embedding`, which previously produced gibberish for the shorter batches. ## Motivation GenAI's batched prefill right-pads short prompts to the batch max and reports each batch's real length via `seqlens_k[b] = real_len[b] - 1`. The previous FlashAttention gate forced every batched call onto the slower `ApplyAttention` path, and the rotary shaders underflowed `u32` for any batch shorter than the batch-wide `sequence_length`, producing garbage Q/K positions and gibberish output text for the shorter batches. ## Test plan - [x] All `GroupQueryAttentionTest.WebGPU_*` op tests pass, including `BatchedRightPaddedRotaryPrefill` (FlashAttention path) and the new `BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU` covering a `real_lens` spread > tile_size - [x] phi4-prune three-prompt batched generation: coherent outputs on WebGPU matching CPU reference (3 prompts, 384 tokens, 173 tps) - [x] phi4-prune single-prompt generation regression: coherent - [x] phi4-graph-prune (graph capture enabled): `verify_model_correctness.py` 4/4 PASS; `verify_multi_gen.py` sequential + overlapping both PASS - [x] whisper-tiny-int4 transcription regression: 2/2 byte-exact with CPU - [x] Lintrunner clean on all changed files --- .../webgpu/bert/flash_attention.cc | 111 ++++++++++++------ .../contrib_ops/webgpu/bert/flash_attention.h | 24 ++-- .../webgpu/bert/flash_attention.wgsl.template | 79 +++++++------ .../flash_attention_decode_qkv.wgsl.template | 52 ++++++-- ...h_attention_decode_vx_reduce.wgsl.template | 10 +- .../webgpu/bert/group_query_attention.cc | 10 +- ..._rotary_embedding_and_copykv.wgsl.template | 3 +- .../group_query_attention_op_test.cc | 66 ++++++++++- 8 files changed, 249 insertions(+), 106 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index d9d299d4fd5d9..4e926c7efa597 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -55,6 +55,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + if (prepare_indirect_dispatch_) { + sh.AddInput("total_sequence_length_input", ShaderUsage::None); + } const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); @@ -97,8 +100,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (use_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::None); } - // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output + // If prepare_indirect_dispatch is enabled, add total_sequence_length_input + // and indirect_buffer output. total_sequence_length_input is the global max + // total sequence length across the batch (from GQA input #6); using it for + // dispatch sizing covers right-padded batches where batch 0 is not the max. if (prepare_indirect_dispatch_) { + shader.AddInput("total_sequence_length_input", ShaderUsage::None); shader.AddOutput("indirect_buffer", ShaderUsage::None); } @@ -109,11 +116,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { " let num_head_id = output_indices[1];\n" " let batch = output_indices[0];\n"; if (use_seqlen_k_) { - shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n"; + shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[batch]) + 1u;\n"; } else { shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; } - shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; + // Right-padded batches with prompt shorter than kv_sequence_length would underflow u32; clamp to 0. + shader.MainFunctionBody() << " let past_sequence_length = select(total_seq_length - uniforms.kv_sequence_length, 0u, total_seq_length <= uniforms.kv_sequence_length);\n"; if (past_present_share_buffer_) { shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"; } else { @@ -124,7 +132,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (prepare_indirect_dispatch_) { shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; shader.MainFunctionBody() << " if (global_idx == 0u) {\n" - << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" + << " let global_total_seq_length = u32(total_sequence_length_input[0]);\n" + << " let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -152,7 +161,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles, + const Tensor* total_seqlen) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. @@ -188,6 +198,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_past) { program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, @@ -262,9 +275,15 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } + if (use_indirect_dispatch_) { + // Global max total sequence length across batches (from GQA input #6). + // Used in indirect-dispatch mode for the workgroup_idx slicing so that + // batch 0's per-batch length cannot undersize the dispatch grid. + shader.AddInput("total_sequence_length_input", ShaderUsage::None); + } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } @@ -282,6 +301,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), @@ -293,7 +313,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k, const Tensor* total_seqlen) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -303,13 +323,16 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool is_unidirectional = parameters.is_unidirectional_; - FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (use_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -320,10 +343,12 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } if (use_indirect_dispatch) { @@ -332,7 +357,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -343,6 +368,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, {attn_bias_dim1}, + {attn_bias_dim3}, {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); @@ -351,7 +377,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { @@ -364,7 +390,7 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_VARIABLE(input, input), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(output, output)); @@ -379,17 +405,17 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch, const Tensor* head_sink, - uint32_t m_tile) { + uint32_t m_tile, + bool use_seqlen_k) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; bool has_head_sink = head_sink != nullptr; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k}; program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_head_sink) { @@ -399,7 +425,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) + .CacheHint(tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, @@ -415,7 +441,8 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, - const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink) { + const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink, + const Tensor* total_seqlen) { constexpr uint32_t tile_size = 64; // Create present_key and present_value tensors if they are nullptr. @@ -437,7 +464,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co present_value = &internal_present_value; } - const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Read seqlens_k per batch_idx in the shader whenever seqlens_k is supplied. + // This covers both graph-capture (total_sequence_length_ is 0 on the host) and + // right-padded batches (batch_size > 1 with distinct per-batch totals), and lets + // batch=1 share the same path. When seqlens_k is null, kernels fall back to + // uniforms.total_sequence_length. + const bool use_seqlen_k = seqlen_k != nullptr; // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; @@ -453,8 +485,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Prepare indirect dispatch buffer for split-reduce path with static KV cache. // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && + // Static KV cache (past_present_share_buffer_) is guaranteed by GQA's + // ORT_ENFORCE when graph capture is enabled. + const bool use_indirect_dispatch = seqlen_k != nullptr && + total_seqlen != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions @@ -492,10 +526,11 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size, num_q_tiles)); + indirect_buffer_ptr, tile_size, num_q_tiles, + total_seqlen)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles, total_seqlen)); } // Extract present_sequence_length directly from present_key tensor shape @@ -555,10 +590,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) @@ -572,7 +609,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {alpha}, {num_seq_tile}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {attn_bias_dim3}}); return context.RunProgram(program); } @@ -596,27 +634,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co &metadata, seqlen_k, parameters, indirect_buffer_ptr, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length, m_tile)); + present_sequence_length, m_tile, use_seqlen_k, total_seqlen)); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - head_sink, m_tile)); + num_present_sequence_length_tile, tile_size, + head_sink, m_tile, use_seqlen_k)); return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { - const bool kv_empty = parameters.kv_sequence_length_ == 0; - // FlashAttention here does not implement right-padded per-batch prefill, so the - // first disjunction restricts it to inputs where padding cannot occur: - // - batch_size_ == 1: single sequence, no padding possible. - // - seqlen_k == nullptr: no per-batch lengths, padding inexpressible. - // - kv_empty (shared-KV layer): FA is mandatory; that path takes a different shader. - // The remaining conjuncts exclude packed-QKV (handled by a separate rotary kernel), - // mismatched head/value sizes, and head_size alignments unsupported by the kernel. - return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && - !parameters.is_packed_qkv_ && +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } @@ -631,7 +660,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size, uint32_t num_q_tiles) { + uint32_t tile_size, uint32_t num_q_tiles, + const Tensor* total_seqlen) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -669,6 +699,9 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {cos_cache, ProgramTensorMetadataDependency::Rank, components}, {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }); + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, {present_key, ProgramTensorMetadataDependency::None, components}, {present_value, ProgramTensorMetadataDependency::None, components}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 218baf926173f..85ba61c1d20b5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -125,7 +125,8 @@ class FlashAttentionProgram final : public Program { {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}, {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32}, - {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}); + {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}, + {"attn_bias_dim3", ProgramUniformVariableDataType::Uint32}); private: bool has_attention_bias_; @@ -148,8 +149,9 @@ class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -195,17 +199,18 @@ class FlashAttentionDecodeVxReduceProgram final : public Program u32 { - return u32(seqlens_k[0]) + 1u; +// When seqlens_k is provided, total_sequence_length is read per batch from the GPU buffer. +fn get_total_sequence_length(batch_idx: u32) -> u32 { + return u32(seqlens_k[batch_idx]) + 1u; } #else -// When graph capture is disabled, total_sequence_length comes from uniforms -fn get_total_sequence_length() -> u32 { +// Without seqlens_k, total_sequence_length comes from uniforms (max across batches). +fn get_total_sequence_length(batch_idx: u32) -> u32 { return uniforms.total_sequence_length; } #endif @@ -65,20 +65,18 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ var qk_scores : array; -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); v_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); @@ -95,15 +93,19 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { } #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> q_element_t { - if (k_idx_global >= get_total_sequence_length()) { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> q_element_t { + if (k_idx_global >= total_seq) { return q_element_t(0); } let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); - return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + get_total_sequence_length())]); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; + return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + stride_total_seq - 1u)]); } #endif @@ -111,24 +113,24 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he // For max performance max_k_step should be the same as sg_size, however we might run out of registers // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); k_tile[slot][idx % head_size_vec] = val; } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); v_tile[slot][idx % head_size_vec] = val; } } @@ -160,18 +162,22 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { #endif #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (k_idx_global >= get_total_sequence_length()) { + if (k_idx_global >= total_seq) { return vec4(0); } // Handle broadcasting: if dimension size is 1, use index 0 let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; let offset = offset_base + k_idx_global; - let offset_max = offset_base + get_total_sequence_length(); + let offset_max = offset_base + stride_total_seq - 1u; let c1 = q_element_t(attention_bias[min(offset, offset_max)]); let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]); let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]); @@ -179,7 +185,7 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he return vec4(c1, c2, c3, c4); } #else -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { return vec4(0); } #endif @@ -226,11 +232,14 @@ $MAIN { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; #endif - let total_sequence_length = get_total_sequence_length(); + let total_sequence_length = get_total_sequence_length(batch_idx); #if is_unidirectional // If attention is unidirectional, set the loop bound to enforce causal masking. - let past_sequence_length = total_sequence_length - uniforms.new_sequence_length; + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); let max_causal_len_for_workgroup = past_sequence_length + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup); @@ -244,8 +253,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx); - loadv(k_start, batch_head_idx, local_idx); + loadk(k_start, batch_head_idx, local_idx, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, total_sequence_length); workgroupBarrier(); for (var k = 0u; k < max_k_step; k++) { @@ -254,7 +263,7 @@ $MAIN { score += dot(q_tile[i], k_tile[k][i]); } #if has_attention_bias - score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx); + score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx, total_sequence_length); #endif qk_scores[k] = select(min_value, score, k_start + k < seq_causal_length); } @@ -302,8 +311,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx, capped_sg_size); - loadv(k_start, batch_head_idx, local_idx, capped_sg_size); + loadk(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); workgroupBarrier(); // Compute QKt @@ -361,11 +370,11 @@ $MAIN { qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx); - qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx); + qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx, total_sequence_length); + qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx, total_sequence_length); if (sg_size > 8) { - qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx); - qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx); + qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx, total_sequence_length); + qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx, total_sequence_length); } // Neuter qk values where K is out of bounds. diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template index 524a18ca43245..778e07fbf63ff 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -10,6 +10,7 @@ #param tile_size #param tile_size_k_vec #param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -34,18 +35,22 @@ var tile_max: array; var tile_sum: array; #if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + - bias_head_idx * uniforms.new_sequence_length * total_seq_length + - q_idx * total_seq_length + + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + + q_idx * stride_total_seq + k_idx; return attention_bias[offset]; } #else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { return q_element_t(0); } @@ -54,12 +59,14 @@ var tile_sum: array; $MAIN { let local_row = u32(local_idx / tile_size_k_vec); let local_col = local_idx % tile_size_k_vec; + // total_sequence_length used for workgroup_idx slicing must match the host-side dispatch + // grid, i.e. the global maximum across batches. Per-batch total is derived separately below. #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + let global_total_sequence_length = u32(total_sequence_length_input[0]); #else - let total_sequence_length = uniforms.total_sequence_length; + let global_total_sequence_length = uniforms.total_sequence_length; #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_total_seq_length_tile = (global_total_sequence_length + tile_size - 1) / tile_size; let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; @@ -71,9 +78,28 @@ $MAIN { if (batch_idx >= uniforms.batch_size) { return; } + // Per-batch total_sequence_length used for K/V bounds, causal mask, and softmax range. + #if use_seqlen_k + let total_sequence_length = u32(seqlens_k[batch_idx]) + 1u; + #else + let total_sequence_length = global_total_sequence_length; + #endif let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + // If this workgroup's tile lies entirely beyond this batch's per-batch total_sequence_length, + // write neutral metadata so VxReduce contributes nothing for these tiles, then exit early. + if (total_seq_offset >= total_sequence_length) { + if (local_idx == 0u) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx_local = q_base + m; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx_local) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(-3.4028234663852886e+38f, 0.0f)); + } + } + return; + } + // ============================================================ // Phase 1: QK^T computation // ============================================================ @@ -109,6 +135,12 @@ $MAIN { } // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask +#if is_unidirectional + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); +#endif for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { let q_idx = q_base + m; if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { @@ -117,9 +149,9 @@ $MAIN { sum += inner_qk_values[m][local_idx][i]; } - sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx); #if is_unidirectional - if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + if (total_seq_offset + local_idx > past_sequence_length + q_idx) { sum = q_element_t(-65504.0f); } #endif diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index a3ce0b68cb659..628ad835a9d4c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -5,7 +5,7 @@ #param m_tile #param seq_tile_size #param tile_size -#param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -32,8 +32,12 @@ $MAIN { } let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + // Per-batch total_sequence_length: short batches contributed neutral metadata + // (-inf, 0) for tiles beyond their per-batch total, so reading only this batch's + // tiles ensures softmax rescaling is not skewed by garbage tiles. + #if use_seqlen_k + let batch_idx_for_seqlen = batch_head_idx / uniforms.num_heads; + let total_sequence_length = u32(seqlens_k[batch_idx_for_seqlen]) + 1u; let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; #else let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 36d688c9723fd..24ace3487a4c5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -327,6 +327,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& past_value->DataRaw() == present_value->DataRaw(); ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length."); + ORT_ENFORCE(!context.IsGraphCaptureEnabled() || parameters.past_present_share_buffer_, + "Graph capture requires past/present KV cache to share the same buffer (static KV cache)."); Tensor qSplit; Tensor kSplit; @@ -350,7 +352,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context); } if (kv_empty) { @@ -381,7 +383,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled // query points to packed QKV, K and V are nullptr since they're not needed return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink); + present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink, + total_seqlen_tensor); } // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); @@ -472,7 +475,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink); + present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink, + total_seqlen_tensor); } // Non-flash attention path does not support kv_sequence_length==0 (shared KV layers). diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 97c610fb90024..e3d92c036d2c1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -43,7 +43,8 @@ $MAIN { #if prepare_indirect_dispatch if (global_idx == 0u) { - let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + let global_total_seq_length = u32(total_sequence_length_input[0]); + let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index c6f6e0affc163..28c41b5bf5ed4 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2613,7 +2613,8 @@ static std::vector RunGQAPackedQKVRotaryPrefill( int head_size, const std::vector& seqlens_k_data, const std::vector& packed_qkv_data, - GqaTargetEp target_ep = GqaTargetEp::kCpu) { + GqaTargetEp target_ep = GqaTargetEp::kCpu, + bool smooth_softmax = false) { const int hidden_size = num_heads * head_size; const int kv_hidden_size = kv_num_heads * head_size; const int qkv_hidden = hidden_size + 2 * kv_hidden_size; @@ -2632,6 +2633,12 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); tester.AddAttribute("do_rotary", static_cast(1)); + if (smooth_softmax) { + // smooth_softmax disqualifies the WebGPU FlashAttention path via the outer + // gating in GroupQueryAttention::ComputeInternal, routing this case through + // ApplyAttention instead. + tester.AddAttribute("smooth_softmax", static_cast(1)); + } // Packed QKV: pass through `query` input, leave key/value as optional edges. if (use_fp16) { @@ -2722,7 +2729,9 @@ static std::vector RunGQAPackedQKVRotaryPrefill( // output matches its single-prompt reference. Both reference and batched runs // go through the same EP, so this validates per-batch consistency within each // EP rather than cross-EP equivalence. -static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { +static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, + const std::vector& real_lens = {4, 2, 6}, + bool smooth_softmax = false) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2731,10 +2740,10 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { constexpr int kv_hidden_size = kv_num_heads * head_size; constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; - // Real prompt lengths per batch; max = sequence_length (right-padding extends + // Per-batch real prompt lengths; max = sequence_length (right-padding extends // shorter batches up to this length). The bug only manifests when at least // one batch is shorter than sequence_length. - const std::vector real_lens = {4, 2, 6}; + ASSERT_EQ(static_cast(real_lens.size()), batch_size); const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); std::vector packed_batched; @@ -2755,7 +2764,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { /*batch_size=*/1, /*sequence_length=*/real_len, num_heads, kv_num_heads, head_size, /*seqlens_k_data=*/{static_cast(real_len - 1)}, - packed_single, target_ep); + packed_single, target_ep, smooth_softmax); } // Now run all batches together with right-padding. @@ -2765,7 +2774,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { } const auto batched_output = RunGQAPackedQKVRotaryPrefill( batch_size, sequence_length, num_heads, kv_num_heads, head_size, - seqlens_k_data, packed_batched, target_ep); + seqlens_k_data, packed_batched, target_ep, smooth_softmax); // Guard the regression deterministically: every element of the batched output // (including padding rows) must be finite. The CPU root cause is uninitialized @@ -2818,5 +2827,50 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu); } +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with per-batch +// real_lens whose max crosses the prefill threshold (sequence_length >= 32) so +// the WebGPU EP picks FlashAttentionProgram (single-kernel prefill path with +// subgroup shuffles) instead of the split-reduce decode path. This exercises +// the prefill flash-attention kernel under right-padded batches with do_rotary, +// which is the path used by Phi-style models during batched prefill. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // sequence_length = max(real_lens) = 33 > 32 -> FlashAttentionProgram path. + // Mixed shorter batches (12, 20) ensure right-padding is non-trivial. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 33}); +} + +// Stress the FlashAttention prefill path with a per-batch spread that exceeds +// the indirect-dispatch tile size (64). batch 0 has the SHORTEST real length; +// batch 2 has the LONGEST. This is the data pattern that would surface the +// indirect-dispatch undersizing bug when graph capture is enabled (where the +// dispatch grid is sized from a GPU buffer rather than the host scalar). +// OpTester does not toggle graph capture, so this test exercises the new +// total_sequence_length_input shader plumbing on the non-graph-capture path; +// the graph-capture path is covered end-to-end by phi4-graph-prune verification. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // spread = 96 - 20 = 76 > tile_size(64), batch 0 is not the max. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 96}); +} + +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with +// smooth_softmax=1 so the WebGPU EP bypasses CanApplyFlashAttention and routes +// through ApplyAttention (non-flash path). Covers right-padded batched prefill +// on the non-flash attention path (used by e.g. Phi-4 attention variants). +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {4, 2, 6}, /*smooth_softmax=*/true); +} + } // namespace test } // namespace onnxruntime From a203dfafc94b5446a59ddd67e92e6b4b66b01d7e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 25 Jun 2026 23:29:40 -0700 Subject: [PATCH 08/10] [CPU] Add FP32 GEMV decode kernel for GroupQueryAttention (#29216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description PR1 https://github.com/microsoft/onnxruntime/pull/28962 adds flash attention for **prefill**, and removed flash decoding. This PR will add optimized kernel for **single-token decode**, which will be faster than other kernels including flash decoding. This PR builds on the prefill-only flash attention change and additionally introduces a dedicated decode kernel. #### What's included - **Decode (GEMV) kernel** — A dedicated single-token decode kernel (`MlasGQADecodeGQAThreaded`) for `sequence_length == 1`, parallelized over (batch, head) with a two-pass softmax, using GEMV (`acc[8]`-lane dot product / AXPY) helpers instead of per-block M=1 SGEMM calls. This fixes the per-block SGEMM decode regression. - The FP32 flash gate (`group_query_attention.cc`) is enabled for `total_sequence_length > 1`, routing prefill to the tiled kernel and decode to the GEMV kernel. - The quantized KV-cache path is unchanged (FP32-only scope). #### Results (AMD EPYC 7763, AVX2, 8 threads) - **Decode:** correctness ~1e-8 vs naive; long-context decode ~1.0–1.5x (T = 4097 ~1.3–1.5x). ### Motivation and Context The naive GQA path materializes the full score matrix, which is memory-bound for long sequences. Flash attention reduces memory traffic for prefill, and the GEMV decode kernel avoids SGEMM overhead for the M=1 decode case. ### Testing - Built with `--compile_no_warning_as_error`. - Correctness verified against the naive path for both prefill and decode (max abs diff ~1e-8). - Benchmarked via `benchmark_gqa_cpu_flash.py`. --- docs/contrib_ops/cpu/gqa.md | 79 ++- .../contrib_ops/cpu/bert/gqa_attention_base.h | 51 +- .../cpu/bert/group_query_attention.cc | 10 +- onnxruntime/core/mlas/inc/mlas.h | 14 +- onnxruntime/core/mlas/lib/flashattn_gqa.cpp | 528 +++++++++++++++++- .../test/python/transformers/test_gqa_cpu.py | 64 +++ 6 files changed, 707 insertions(+), 39 deletions(-) diff --git a/docs/contrib_ops/cpu/gqa.md b/docs/contrib_ops/cpu/gqa.md index 8b81fdba8f1a6..ffce29682ca42 100644 --- a/docs/contrib_ops/cpu/gqa.md +++ b/docs/contrib_ops/cpu/gqa.md @@ -58,7 +58,7 @@ Both the non-quantized and quantized paths have two execution strategies: - **Naive (full materialization)**: Computes the full `[S, T]` attention score matrix, applies masking and softmax, then computes the SV product. Simple but memory-intensive for long sequences. - **Flash Attention (tiled, online softmax)**: Processes K/V in L2-cache-sized blocks using the online softmax algorithm (Milakov & Gimelshein, 2018). Avoids materializing the full attention matrix, reducing peak memory from O(S×T) to O(S×Bc) per head. Multi-threaded via the MLAS thread pool. -The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, and online-softmax structure. The quantized path additionally provides a two-phase flash-decoding strategy for single-token decode; the non-quantized FP32 path is limited to prefill (`sequence_length > 1`) and uses the naive path for decode. +The quantized path uses `MlasFlashAttentionQuantizedKV` (`flashattn_qkv.cpp`); the non-quantized FP32 path uses `MlasFlashAttentionGQA` (`flashattn_gqa.cpp`). Both share the same tiling, masking, online-softmax, and flash-decoding structure. The flash path is selected by default when conditions are met (see below). Set `ORT_GQA_DISABLE_FLASH_ATTENTION=1` to force the naive path (applies to both the quantized and non-quantized paths). @@ -213,7 +213,7 @@ The partials buffer is allocated alongside the per-thread scratch in a single al ## Non-Quantized Flash Attention Path -The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, and masking structure. Unlike the quantized path, it is limited to prefill / chunked-prefill (`sequence_length > 1`); single-token decode (`sequence_length == 1`) uses the naive path, which is why there is no flash-decoding variant here. +The non-quantized flash attention path (`MlasFlashAttentionGQA`, in `flashattn_gqa.cpp`) is the FP32-KV-cache counterpart of the quantized path. It is selected for the `float` kernel specialization and reuses the same tiling, online-softmax, masking, and flash-decoding structure. ### Differences from the Quantized Path @@ -242,7 +242,8 @@ computes a per-q_block bound `ir >= kv_causal_limit`, instead of computing and then discarding the masked upper-triangle QK/SV GEMMs. This skips roughly half of the QK/SV work for square prefill (S = T) and is the main reason the FP32 flash path is faster than naive even at short sequence lengths -(see the benchmark results below). +(see the benchmark results below). Decode (q_block of size 1 at the cache tail) attends to all +KV positions, so the bound equals `total_seqlen` and nothing is skipped. ### Activation Conditions @@ -250,18 +251,52 @@ The non-quantized flash path is selected when ALL of the following hold: - The kernel specialization is `float` (FP16 uses the naive path) - `ORT_GQA_DISABLE_FLASH_ATTENTION` environment variable is not set (or set to `0`) -- `sequence_length > 1` (prefill / chunked-prefill; single-token decode uses the naive path) +- `total_sequence_length > 1` - No softcap - No smooth softmax - No head sink - No output QK capture - `present_key` and `present_value` are provided -Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, and shared past/present buffers are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. +Attention bias, causal masking, local window attention, GQA head grouping (`num_heads != kv_num_heads`), ragged per-batch sequence lengths, shared past/present buffers, and flash decoding are all supported, mirroring the quantized flash path. When any condition is not met, the kernel falls back to the naive full-materialization path. + +### Block Sizes, Threading, and Flash Decoding + +Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`), and the two-phase flash-decoding strategy for single-token decode are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. + +#### Decode uses a dedicated GEMV kernel (`sequence_length == 1`) + +The tiled online-softmax SGEMM kernel (`MlasFlashAttentionGQAThreaded`) is used **only for +prefill** (`sequence_length > 1`), where each KV tile is reused across the `q_block_size` +query rows and tiling delivers real cache-locality and SGEMM packing benefits. + +For single-token decode the query tile has `M = 1`, so every K/V element is streamed +exactly once with no reuse across query rows. Tiling provides **no** cache-locality +benefit, and routing the `1 × T × H` work through `MlasSgemmOperation` pays the SGEMM +B-packing/setup cost on every call — which previously made the flash decode path *slower* +than the naive path (≈0.4–0.6x) for short-to-medium total sequence lengths. -### Block Sizes and Threading +Decode is therefore handled by a dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`), +dispatched whenever `sequence_length == 1` and flash decoding is not active. It +parallelizes over `(batch, head)` and, per head, computes the attention directly with two +matrix-vector products and a two-pass softmax: + +- **QK GEMV** — `scores[t] = scale · dot(q, K[t])` for `t ∈ [0, total_seqlen)`. +- two-pass softmax over `scores` using the dispatched `ReduceMaximumF32Kernel` / + `ComputeSumExpF32Kernel` helpers. +- **SV GEMV** — `out[h] = Σ_t probs[t] · V[t][h]`, then normalize by `1/Σ probs`. + +Both GEMV helpers (`MlasGQADecodeQK`, `MlasGQADecodeSV`) live in the baseline-ISA MLAS +translation unit, so their inner loops use independent accumulator lanes / map-style +updates that vectorize under SSE2 without `-ffast-math`. Decode needs no causal mask (the +single new token is the most recent position and attends to every cached token); only +optional local-window masking and additive attention bias are applied. The kernel streams +K and V exactly once each, so it is memory-bandwidth bound. + +The two-phase flash-decoding path (active when `batch × heads < threads`, KV partitioned +across idle threads) now also uses these GEMV helpers for its per-chunk QK and SV products +instead of `M = 1` SGEMM calls, removing the same packing overhead. -Block-size selection (`kv_block_size`, `q_block_size`), `(batch, head, q_block)` task partitioning, and the per-thread working buffer layout (`l`, `m`, `scores`, `temp_output`) are identical to the quantized path described above. The only difference is that the per-thread `temp_output` tile is accumulated directly by the SV SGEMM rather than via a fused dequantization. Because this path is prefill-only, it does not include the quantized path's two-phase flash-decoding strategy for single-token decode. ## MLAS Dispatch Paths @@ -513,13 +548,31 @@ offsets the intrinsic per-KV-block online-softmax overhead (running max/exp/outp The same advantage holds single-threaded (1.4\u20131.8x at threads=1), confirming the gain is algorithmic rather than purely from threading. -#### Decode (S = 1, token generation) +#### Latency — Decode (S = 1, token generation) + +For single-token decode at this head configuration (`batch\u00d7heads = 16 > threads = 8`, so +flash decoding KV-partitioning is not active), the workload per `Run` is tiny (a `1 × T × H` +GEMV pair per head) and operator-level latency is dominated by fixed per-`Run` overhead +(session dispatch, KV-cache concatenation), so operator-level measurements on the EPYC dev +box are extremely noisy. The numbers below come from a min-of-many-repeats MLAS-path harness +to suppress that jitter. -Single-token decode (`sequence_length == 1`) is **not** handled by the FP32 flash path; it falls -back to the naive path. Decode produces only a `[1, total_sequence_length]` score row per head, -so there is nothing to tile away, and the extra online-softmax bookkeeping made the flash kernel -slower and noisier in practice. Restricting the flash path to prefill (`sequence_length > 1`) keeps -the consistent prefill win without regressing decode. +| Total Seqlen | Naive (ms) | Flash (ms) | Speedup | +|---:|---:|---:|---:| +| 513 | 0.50 | 0.42 | ~1.0\u20131.2x (noisy) | +| 1025 | 0.78 | 0.69 | ~1.0\u20131.1x (noisy) | +| 2049 | 1.89 | 1.73 | ~1.0\u20131.1x (noisy) | +| 4097 | 6.1 | 4.5 | 1.35\u20131.5x | + +Decode is now handled by the dedicated GEMV kernel (`MlasGQADecodeGQAThreaded`) instead of +the prefill tiling kernel; see *Decode uses a dedicated GEMV kernel* above. Replacing the +per-head `M = 1` `MlasSgemmOperation` QK/SV calls with direct GEMVs removes the SGEMM +B-packing overhead that previously made flash decode noticeably **slower** than naive +(measured ≈0.4\u20130.6x across all lengths before the change). Flash decode is now at parity +for short/medium sequences (where the work is memory-bandwidth bound and overhead-dominated) +and consistently ahead for long contexts (T≥4097, ~1.4\u20131.5x) where the streamed +single-pass KV access wins. Short decode remains overhead-bound rather than algorithm-bound, +so it is not the target of the prefill-oriented causal early-termination optimization. ## Current CPU Limitations The current CPU GroupQueryAttention implementation has a few important limitations: diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 413483756ed5c..60fa4f0c4ada1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -1091,16 +1091,49 @@ class GQAAttentionBase { int thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); thread_count = std::max(thread_count, 1); - // Per-thread scratch: l + m + scores[q_block_size * kv_block_size] + temp_output[q_block_size * head_size] - const size_t buffer_size_per_thread = - (SafeInt(q_block_size) * 2 + // l + m - SafeInt(q_block_size) * kv_block_size + // scores - SafeInt(q_block_size) * head_size) * // temp_output - sizeof(float); - size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count; + // Flash decoding: for decode (sequence_length==1), partition KV across threads + // to improve parallelism when batch*heads < thread_count. This KV-split is only + // wired into the unified kernel (common_past_seqlen >= 0); the ragged/per-batch + // fallback runs the single-pass decode kernel instead, which needs a larger + // per-thread scratch (scores[total_seqlen] + temp_output[head_size]). Gating on + // common_past_seqlen >= 0 keeps the per-thread buffer sizing below consistent + // with the kernel that actually runs. + const int kv_chunk_count = (max_total_seqlen + kv_block_size - 1) / kv_block_size; + const bool use_flash_decoding = (sequence_length == 1 && + common_past_seqlen >= 0 && + batch_size * num_heads_ < thread_count && + kv_chunk_count > 1); + + size_t buffer_size_per_thread; + size_t partials_buffer_bytes = 0; + if (use_flash_decoding) { + // Flash decoding: per-thread scratch only needs scores[kv_block_size] + buffer_size_per_thread = SafeInt(kv_block_size) * sizeof(float); + // Partials: [batch * num_heads * kv_chunk_count * (2 + head_size)] floats + partials_buffer_bytes = SafeInt(batch_size) * num_heads_ * + kv_chunk_count * (2 + head_size) * sizeof(float); + } else if (sequence_length == 1) { + // Decode (GEMV kernel, no Q/KV tiling): per-thread scratch holds the full + // score row scores[total_seqlen] plus a temp output accumulator[head_size]. + buffer_size_per_thread = + (SafeInt(max_total_seqlen) + head_size) * sizeof(float); + } else { + buffer_size_per_thread = + (SafeInt(q_block_size) * 2 + // l + m + SafeInt(q_block_size) * kv_block_size + // scores + SafeInt(q_block_size) * head_size) * // temp_output + sizeof(float); + } + size_t total_buffer_bytes = SafeInt(buffer_size_per_thread) * thread_count + partials_buffer_bytes; auto flash_buffer_alloc = allocator->Alloc(total_buffer_bytes); BufferUniquePtr flash_buffer(flash_buffer_alloc, BufferDeleter(allocator)); + // Partials buffer is placed after per-thread scratch + float* partials_ptr = use_flash_decoding + ? reinterpret_cast(reinterpret_cast(flash_buffer_alloc) + + buffer_size_per_thread * thread_count) + : nullptr; + const float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; // If all batch items share the same past_seqlen, use the unified flash kernel. @@ -1133,6 +1166,8 @@ class GQAAttentionBase { args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; args.attention_bias_broadcast_batch = attention_bias_broadcast_batch; args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = partials_ptr; + args.kv_chunk_count = kv_chunk_count; MlasFlashAttentionGQA(&args, tp); } else { @@ -1185,6 +1220,8 @@ class GQAAttentionBase { args.attention_bias_seqlen_stride = attention_bias_seqlen_stride; args.attention_bias_broadcast_batch = true; // batch offset handled above args.attention_bias_broadcast_head = attention_bias_broadcast_head; + args.flash_decoding_partials = nullptr; // per-batch doesn't use flash decoding + args.kv_chunk_count = 0; MlasFlashAttentionGQA(&args, tp); } diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 692a5759d1adc..debda282eb4f1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -348,13 +348,13 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // kernel to avoid materializing the full attention score matrix. Falls back to the // naive path when an unsupported feature is requested (softcap, smooth softmax, // head sink, or QK output). + // + // Prefill (sequence_length > 1) uses the tiled kernel; single-token decode + // (sequence_length == 1 with total_sequence_length > 1) uses the dedicated GEMV + // decode kernel. Both are reached when total_sequence_length > 1. if constexpr (std::is_same_v) { - // Restrict the flash path to prefill / chunked-prefill (query length > 1). Single-token - // decode (sequence_length == 1) has no flash benefit: the naive score matrix is only - // [1, total_sequence_length] per head, so there is nothing to tile away, and the extra - // online-softmax bookkeeping makes it slower in practice. const bool use_flash = !disable_gqa_flash_ && - parameters.sequence_length > 1 && + parameters.total_sequence_length > 1 && softcap_ == 0.0f && !use_smooth_softmax_ && head_sink_data == nullptr && diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 2410fcc83e7cd..cbca2d85a97a4 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2302,9 +2302,9 @@ MlasFlashAttention( // // Adapts the online-softmax tiled algorithm to operate on an FP32 present // K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]). -// Supports GQA head grouping, causal masking, local window attention, and -// additive attention bias. Intended for prefill / chunked-prefill -// (sequence_length > 1). +// Supports GQA head grouping, causal masking, local window attention, +// additive attention bias, and an optional flash-decoding split over the KV +// sequence dimension for the single-token decode case. // struct MlasFlashAttentionGQAArgs { int batch_size; @@ -2320,7 +2320,7 @@ struct MlasFlashAttentionGQAArgs { int kv_block_size; // key/value tile size (Bc) float scale; // QK scaling factor int thread_count; // number of partitions / threads - float* buffer; // per-thread scratch + float* buffer; // per-thread scratch (+ optional flash-decoding partials) size_t buffer_size_per_thread; const float* query; // [batch, num_heads, sequence_length, head_size] BNSH @@ -2334,6 +2334,12 @@ struct MlasFlashAttentionGQAArgs { int attention_bias_seqlen_stride; bool attention_bias_broadcast_batch; bool attention_bias_broadcast_head; + + // Flash decoding (sequence_length == 1): partition KV across threads. + // Set flash_decoding_partials != nullptr to enable; otherwise the standard + // per-(batch, head, q_block) partitioning is used. + float* flash_decoding_partials; + int kv_chunk_count; }; /** diff --git a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp index 4d0ff65733a44..0f1210ca1e3c5 100644 --- a/onnxruntime/core/mlas/lib/flashattn_gqa.cpp +++ b/onnxruntime/core/mlas/lib/flashattn_gqa.cpp @@ -17,11 +17,14 @@ Module Name: an FP32 present K/V cache laid out as BNSH ([batch, kv_num_heads, seqlen_present, head_size]) and to support GQA head grouping (num_heads % kv_num_heads == 0), causal masking, local window - attention, and additive attention bias. Intended for prefill / - chunked-prefill (sequence_length > 1). + attention, additive attention bias, and an optional flash-decoding split + over the KV sequence dimension for single-token decode. - QK^T and S*V use the single-threaded SGEMM primitive MlasSgemmOperation; - the outer parallelism is provided by MlasExecuteThreaded. + For multi-token prefill (sequence_length > 1) QK^T and S*V use the + single-threaded SGEMM primitive MlasSgemmOperation. For single-token decode + (sequence_length == 1, including the flash-decoding KV split) the M == 1 + GEMVs use the local MlasGQADecodeQK / MlasGQADecodeSV helpers to avoid SGEMM + packing overhead. The outer parallelism is provided by MlasExecuteThreaded. --*/ @@ -32,6 +35,71 @@ Module Name: #include "mlasi.h" +// +// Decode (sequence_length == 1) GEMV helpers. +// +// With a single query token the QK^T and S*V products degenerate into +// matrix-vector products. Computing them directly streams the K and V cache +// exactly once and avoids the SGEMM B-packing overhead that otherwise dominates +// the tiny M = 1 GEMMs. These helpers live in the baseline-ISA MLAS translation +// unit, so the inner loops are written with independent accumulator lanes and a +// map-style update so the compiler can vectorize them without -ffast-math +// (which would be required to reassociate a plain scalar float reduction). +// + +// QK^T GEMV: scores[t] = scale * dot(q[0..H), K[t*H .. t*H+H)) for t in [0, n_kv). +static void +MlasGQADecodeQK( + const float* q, + const float* k_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float scale, + float* scores +) +{ + constexpr std::ptrdiff_t kLanes = 8; + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float* krow = k_cache + t * head_size; + float acc[kLanes] = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}; + std::ptrdiff_t h = 0; + for (; h + kLanes <= head_size; h += kLanes) { + for (std::ptrdiff_t j = 0; j < kLanes; ++j) { + acc[j] += q[h + j] * krow[h + j]; + } + } + float sum = ((acc[0] + acc[1]) + (acc[2] + acc[3])) + + ((acc[4] + acc[5]) + (acc[6] + acc[7])); + for (; h < head_size; ++h) { + sum += q[h] * krow[h]; + } + scores[t] = sum * scale; + } +} + +// S*V GEMV (accumulate): out[h] = sum_t probs[t] * V[t*H + h] for h in [0, head_size). +// `out` is overwritten (initialized to zero) before accumulation. +static void +MlasGQADecodeSV( + const float* probs, + const float* v_cache, + std::ptrdiff_t n_kv, + std::ptrdiff_t head_size, + float* out +) +{ + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] = 0.0f; + } + for (std::ptrdiff_t t = 0; t < n_kv; ++t) { + const float p = probs[t]; + const float* vrow = v_cache + t * head_size; + for (std::ptrdiff_t h = 0; h < head_size; ++h) { + out[h] += p * vrow[h]; + } + } +} + void MlasFlashAttentionGQAThreaded( void* argptr, @@ -294,6 +362,415 @@ MlasFlashAttentionGQAThreaded( } } +// +// Flash Decoding: Phase 1 - parallel partial attention over (batch, head, kv_chunk). +// Each task computes attention for one KV chunk and stores (m, l, partial_output) +// into the partials buffer. +// +void +MlasFlashDecodingGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t past_seqlen = static_cast(args->past_seqlen); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + // Partials layout per entry: [m, l, output[head_size]] + const ptrdiff_t partial_stride = 2 + head_size; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // Total tasks: (batch, head, kv_chunk) + const ptrdiff_t total_task_count = batch_size * num_heads * kv_chunk_count; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + // Decompose task_index into (batch_idx, head_idx, kv_chunk_idx) + ptrdiff_t tmp = task_index; + ptrdiff_t kv_chunk_idx = tmp % kv_chunk_count; + tmp /= kv_chunk_count; + ptrdiff_t head_idx = tmp % num_heads; + ptrdiff_t batch_idx = tmp / num_heads; + + // Per-thread scratch buffer: just scores[kv_block_size] + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + + // KV block range for this chunk + const ptrdiff_t ir = kv_chunk_idx * kv_block_size; + const size_t row_size_kv = static_cast(std::min(kv_block_size, total_seqlen - ir)); + + // Determine KV head index for GQA head sharing + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + + // K/V cache pointers + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size] (sequence_length=1). + // The batch stride is supplied separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMV for this KV chunk (M = 1) + const float* k_block = k_cache_head + static_cast(ir) * static_cast(head_size); + MlasGQADecodeQK(q_ptr, k_block, static_cast(row_size_kv), head_size, scale, scores); + + // Step 1b: Apply attention bias if present + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_seqlen_stride = + static_cast(args->attention_bias_seqlen_stride); + const ptrdiff_t bias_matrix_size = bias_seqlen_stride; // S=1 + // The bias tensor has shape [batch|1, num_heads|1, S, T]; the batch stride + // uses the actual head extent (1 when the head dim is broadcast). + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * + bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset + ir; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + scores[jcol] += bias_row[jcol]; + } + } + + // Step 2: Apply causal mask + const ptrdiff_t global_q_pos = past_seqlen; // sequence_length=1, q_idx=0 + const ptrdiff_t causal_limit = global_q_pos + 1; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos >= causal_limit) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + + // Apply local window masking if enabled + if (local_window_size >= 0) { + const ptrdiff_t window_start = + (causal_limit > local_window_size) ? (causal_limit - local_window_size) : 0; + for (ptrdiff_t jcol = 0; jcol < static_cast(row_size_kv); ++jcol) { + ptrdiff_t kv_pos = ir + jcol; + if (kv_pos < window_start) { + scores[jcol] = std::numeric_limits::lowest(); + } + } + } + + // Step 3: Compute local softmax statistics (m, l) and exp scores +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, row_size_kv); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, row_size_kv); +#endif + + // Pointer to this task's partial in the partials buffer + const ptrdiff_t partial_index = + (batch_idx * num_heads + head_idx) * kv_chunk_count + kv_chunk_idx; + float* partial = args->flash_decoding_partials + partial_index * partial_stride; + float* partial_m = partial; + float* partial_l = partial + 1; + float* partial_output = partial + 2; + + if (rowmax == std::numeric_limits::lowest()) { + // Entire chunk is masked: store sentinel + *partial_m = std::numeric_limits::lowest(); + *partial_l = 0.0f; + memset(partial_output, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + *partial_m = rowmax; + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, row_size_kv, &negmax); +#endif + *partial_l = rowsum; + + // Step 4: S_exp * V_block -> partial_output (M = 1) + const float* v_block = v_cache_head + static_cast(ir) * static_cast(head_size); + MlasGQADecodeSV(scores, v_block, static_cast(row_size_kv), head_size, partial_output); + } +} + +// +// Flash Decoding: Phase 2 - reduce partials for each (batch, head) into final output. +// +void +MlasFlashDecodingGQAReduceThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t kv_chunk_count = static_cast(args->kv_chunk_count); + const ptrdiff_t thread_count = static_cast(args->thread_count); + const ptrdiff_t partial_stride = 2 + head_size; + + // Total reduction tasks: one per (batch, head) + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t head_idx = task_index % num_heads; + ptrdiff_t batch_idx = task_index / num_heads; + + // Pointer to this (batch, head)'s partials: kv_chunk_count entries + const float* partials_base = args->flash_decoding_partials + + task_index * kv_chunk_count * partial_stride; + + // Find global max across all chunks + float global_m = std::numeric_limits::lowest(); + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + float chunk_m = partials_base[c * partial_stride]; + global_m = std::max(global_m, chunk_m); + } + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + static_cast(batch_idx) * static_cast(num_heads) * static_cast(head_size) + + static_cast(head_idx) * static_cast(head_size); + + // If all chunks are masked, output zeros + if (global_m == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + // Accumulate rescaled outputs and l values + float global_l = 0.0f; + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + + for (ptrdiff_t c = 0; c < kv_chunk_count; ++c) { + const float* partial = partials_base + c * partial_stride; + float chunk_m = partial[0]; + float chunk_l = partial[1]; + const float* chunk_output = partial + 2; + + if (chunk_l <= 0.0f) { + continue; // masked chunk contributes nothing + } + + float rescale = std::exp(chunk_m - global_m); + global_l += rescale * chunk_l; + + // partial_output = S_exp * V where sum(S_exp) = l_c (unnormalized). + // Rescale by exp(m_c - global_m) to align all chunks to the same max. + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] += rescale * chunk_output[i]; + } + } + + // output = sum_c(rescale_c * partial_output_c) / global_l + float inv_l = (global_l > 0.0f) ? (1.0f / global_l) : 0.0f; + for (ptrdiff_t i = 0; i < head_size; ++i) { + output_ptr[i] *= inv_l; + } + } +} + +// +// Decode kernel for sequence_length == 1 without KV-split (batch * heads >= +// thread_count). Parallelizes over (batch, head); each task attends the single +// query token to the whole KV cache with a pair of GEMVs and a two-pass softmax. +// Decode needs no causal masking (the single new token is the most recent +// position and attends to every cached token); only optional local-window +// masking and additive bias are applied. +// +void +MlasGQADecodeGQAThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionGQAArgs* args = + reinterpret_cast(argptr); + + const ptrdiff_t batch_size = static_cast(args->batch_size); + const ptrdiff_t num_heads = static_cast(args->num_heads); + const ptrdiff_t kv_num_heads = static_cast(args->kv_num_heads); + const ptrdiff_t total_seqlen = static_cast(args->total_seqlen); + const ptrdiff_t head_size = static_cast(args->head_size); + const ptrdiff_t local_window_size = static_cast(args->local_window_size); + const float scale = args->scale; + + float* buffer = args->buffer; + const ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + const ptrdiff_t thread_count = static_cast(args->thread_count); + + const size_t kv_num_heads_factor = static_cast(num_heads / kv_num_heads); + const size_t kv_head_stride = + static_cast(args->seqlen_present_kv) * static_cast(head_size); + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + // One task per (batch, head). + const ptrdiff_t total_task_count = batch_size * num_heads; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + // Local-window low bound: decode can attend to KV positions [window_start, total_seqlen). + // causal_limit == past_seqlen + 1 == total_seqlen for the single new token. + const ptrdiff_t window_start = + (local_window_size >= 0 && total_seqlen > local_window_size) ? (total_seqlen - local_window_size) : 0; + + // Per-thread scratch: scores[total_seqlen] followed by temp_output[head_size]. + char* buffer_ptr = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* scores = reinterpret_cast(buffer_ptr); + float* temp_output = scores + total_seqlen; + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + const ptrdiff_t head_idx = task_index % num_heads; + const ptrdiff_t batch_idx = task_index / num_heads; + + // KV head index for GQA head sharing. + const size_t kv_head_idx = static_cast(head_idx) / kv_num_heads_factor; + const size_t kv_batch_head_offset = + (static_cast(batch_idx) * static_cast(kv_num_heads) + kv_head_idx) * + kv_head_stride; + const float* k_cache_head = args->k_cache + kv_batch_head_offset; + const float* v_cache_head = args->v_cache + kv_batch_head_offset; + + // Q pointer: layout [batch, num_heads, 1, head_size]; batch stride supplied + // separately to support packed-QKV input. + const float* q_ptr = args->query + + static_cast(batch_idx) * args->q_batch_stride + + static_cast(head_idx) * static_cast(head_size); + + // Step 1: QK^T GEMV -> scores[0..T) + MlasGQADecodeQK(q_ptr, k_cache_head, total_seqlen, head_size, scale, scores); + + // Step 1b: additive attention bias (shape [batch|1, num_heads|1, S=1, T]). + if (args->attention_bias != nullptr) { + const ptrdiff_t bias_matrix_size = + static_cast(args->attention_bias_seqlen_stride); // S == 1 + const ptrdiff_t bias_head_extent = + args->attention_bias_broadcast_head ? 1 : static_cast(num_heads); + ptrdiff_t bias_offset = 0; + if (!args->attention_bias_broadcast_batch) { + bias_offset += static_cast(batch_idx) * bias_head_extent * bias_matrix_size; + } + if (!args->attention_bias_broadcast_head) { + bias_offset += static_cast(head_idx) * bias_matrix_size; + } + const float* bias_row = args->attention_bias + bias_offset; + for (ptrdiff_t t = 0; t < total_seqlen; ++t) { + scores[t] += bias_row[t]; + } + } + + // Step 2: local-window masking (no causal mask needed for decode). + if (window_start > 0) { + for (ptrdiff_t t = 0; t < window_start; ++t) { + scores[t] = std::numeric_limits::lowest(); + } + } + + // Step 3: softmax over scores[0..T). +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(scores, total_seqlen); +#else + float rowmax = MlasReduceMaximumF32Kernel(scores, total_seqlen); +#endif + + // Output layout: [batch, sequence_length=1, num_heads, head_size] + float* output_ptr = args->output + + (static_cast(batch_idx) * static_cast(num_heads) + + static_cast(head_idx)) * static_cast(head_size); + + if (rowmax == std::numeric_limits::lowest()) { + memset(output_ptr, 0, static_cast(head_size) * sizeof(float)); + continue; + } + + float negmax = -rowmax; +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(scores, scores, total_seqlen, &negmax); +#endif + + // Step 4: S_exp * V GEMV -> temp_output, then normalize by 1/l. + MlasGQADecodeSV(scores, v_cache_head, total_seqlen, head_size, temp_output); + + const float inv_l = (rowsum > 0.0f) ? (1.0f / rowsum) : 0.0f; + for (ptrdiff_t h = 0; h < head_size; ++h) { + output_ptr[h] = temp_output[h] * inv_l; + } + } +} + void MLASCALL MlasFlashAttentionGQA( @@ -301,10 +778,41 @@ MlasFlashAttentionGQA( MLAS_THREADPOOL* ThreadPool ) { - MlasExecuteThreaded( - MlasFlashAttentionGQAThreaded, - static_cast(args), - static_cast(args->thread_count), - ThreadPool - ); + if (args->sequence_length == 1) { + // Decode: M = 1, use the GEMV kernels (no SGEMM packing overhead). + if (args->flash_decoding_partials != nullptr) { + // Flash decoding: two-phase approach when KV is partitioned across threads. + // Phase 1: parallel partial computation over (batch, head, kv_chunk). + MlasExecuteThreaded( + MlasFlashDecodingGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + // Phase 2: reduce partials into final output (parallel over batch*heads). + MlasExecuteThreaded( + MlasFlashDecodingGQAReduceThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } else { + // Single-pass decode parallelized over (batch, head). + MlasExecuteThreaded( + MlasGQADecodeGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } + } else { + // Prefill (sequence_length > 1): tiled online-softmax SGEMM kernel. + MlasExecuteThreaded( + MlasFlashAttentionGQAThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool + ); + } } + diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 19968db98edd7..c438790ab5950 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -19,6 +19,7 @@ import torch from bert_padding import pad_input, unpad_input from einops import rearrange, repeat +from env_var_helper import scoped_env_var from onnx import TensorProto, helper from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -2645,6 +2646,69 @@ def test_gqa_past(self): qk_output, ) + def test_gqa_decode_flash_vs_naive_parity(self): + # The FP32 flash gate enables the dedicated GEMV decode kernel (and the + # flash-decoding KV-split reduction) for sequence_length == 1 with + # total_sequence_length > 1. Run the same decode configs against the + # reference twice: once with the flash path enabled (default) and once + # with it disabled via ORT_GQA_DISABLE_FLASH_ATTENTION=1 (naive path). + # If both paths match the reference, the decode kernel and KV-split + # reduction are correct -- including the bias and local-window cases. + print("-------- TEST GQA DECODE FLASH VS NAIVE PARITY ---------") + + # FP32 only: the GEMV decode kernel and flash gate are float-only. + torch_type = torch.float32 + numpy_type = numpy.float32 + ort_type = TensorProto.FLOAT + rtol = 1e-3 + atol = 1e-3 + + batches = [1, 3] + # (sequence_length == 1) decode. Include a long KV length so that the + # flash-decoding KV-split path (kv_chunk_count > 1) is exercised. + seqs = [(1, 128), (1, 2048)] + num_h = [(9, 3)] + h_sizes = [64, 128] + + # "0" keeps the flash path enabled; "1" forces the naive path. Reseed per + # phase so both paths are validated against the reference on identical + # inputs, independent of test execution order. + for env_value in ["0", "1"]: + with scoped_env_var("ORT_GQA_DISABLE_FLASH_ATTENTION", env_value): + print(f" flash {'disabled (naive path)' if env_value == '1' else 'enabled'}") + random.seed(69) + torch.manual_seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for has_attn in [False, True]: + config = Config( + b, + s, + s2, + 0, + n, + n2, + h, + False, + has_attn, + False, + QKOutputType.NO_OUTPUT, + ) + all_close = parity_check_gqa_past( + config, + torch_type=torch_type, + numpy_type=numpy_type, + ort_type=ort_type, + local=local, + past_format=Formats.BNSH, + rtol=rtol, + atol=atol, + ) + self.assertTrue(all_close) + def test_gqa_interactive_one_batch(self): print("-------- TEST GQA INTERACTIVE ---------") batches = [1] From 0271287d570c633c477288364d1b4aaca4ac9778 Mon Sep 17 00:00:00 2001 From: Sayan Shaw <52221015+sayanshaw24@users.noreply.github.com> Date: Fri, 26 Jun 2026 09:46:23 -0700 Subject: [PATCH 09/10] Fix unbounded lifetime on WithOutputTensor in Rust bindings (#29251) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Fix unbounded lifetime on WithOutputTensor in Rust bindings ## Description The `WithOutputTensor<'a, T>` struct had a free lifetime parameter `'a` on its `TryFrom` impl that was unconstrained by any input. Combined with the `Deref` impl (whose `Target = ArrayView<'a, T, IxDyn>` exposed a `Clone`-able view), it was possible for the `ArrayView` to outlive the underlying `OrtOutputTensor` buffer owner. This change restructures `WithOutputTensor` to eliminate the unbounded lifetime: - Removes the `'a` lifetime parameter from `WithOutputTensor`, `OrtOutput`, and `Session::run` - Removes the `Deref` impl (the escape hatch) - Replaces the stored `ArrayView<'a, T>` with a raw pointer + shape - Adds a `view(&self)` method returning `ArrayView<'_, T, IxDyn>` — the view lifetime is now tied to `&self` - Updates all call sites (examples, integration tests) to use `.view()` ## Motivation The C API contract (`onnxruntime_c_api.h`) explicitly bounds the data pointer lifetime to the `OrtValue`: the pointer is only valid until the value is destroyed. The Rust type system must enforce this invariant. Previously it did not — the `ArrayView` could be cloned out and observed after the `OrtValue` was freed. ## API Change ```rust // Before: Deref-based access let output = outputs[0].float_array().unwrap(); let sum: f32 = output.iter().sum(); // After: explicit view() call let output = outputs[0].float_array().unwrap(); let sum: f32 = output.view().iter().sum(); ``` ## Testing Existing integration tests updated to use the new `view()` API. The fix is enforced at compile time by the borrow checker — the previously problematic pattern now produces a lifetime error. Co-authored-by: Sayan Shaw --- rust/onnxruntime/examples/issue22.rs | 2 +- rust/onnxruntime/examples/sample.rs | 5 +- rust/onnxruntime/src/session.rs | 6 +- .../src/tensor/ort_output_tensor.rs | 117 ++++++++++-------- rust/onnxruntime/tests/integration_tests.rs | 6 +- 5 files changed, 80 insertions(+), 56 deletions(-) diff --git a/rust/onnxruntime/examples/issue22.rs b/rust/onnxruntime/examples/issue22.rs index 6c96e899fa774..1fb7fe28ff123 100644 --- a/rust/onnxruntime/examples/issue22.rs +++ b/rust/onnxruntime/examples/issue22.rs @@ -51,5 +51,5 @@ fn main() { let outputs = session.run(inputs).unwrap(); - print!("outputs: {:#?}", outputs[0].float_array().unwrap()); + print!("outputs: {:#?}", outputs[0].float_array().unwrap().view()); } diff --git a/rust/onnxruntime/examples/sample.rs b/rust/onnxruntime/examples/sample.rs index 9af5cf733ccae..b6f351b2082ed 100644 --- a/rust/onnxruntime/examples/sample.rs +++ b/rust/onnxruntime/examples/sample.rs @@ -73,10 +73,11 @@ fn run() -> Result<(), Error> { let outputs = session.run(input_tensor_values)?; let output = outputs[0].float_array().unwrap(); + let view = output.view(); - assert_eq!(output.shape(), output0_shape.as_slice()); + assert_eq!(view.shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, view[[0, i, 0, 0]]); } Ok(()) diff --git a/rust/onnxruntime/src/session.rs b/rust/onnxruntime/src/session.rs index 326426e35982c..d475d1b724111 100644 --- a/rust/onnxruntime/src/session.rs +++ b/rust/onnxruntime/src/session.rs @@ -410,10 +410,10 @@ impl Session { /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus /// used for the input data here. - pub fn run<'input, 'output>( - &'output self, + pub fn run<'input>( + &self, mut input_arrays: impl AsMut<[Box]> + 'input, - ) -> Result>> { + ) -> Result> { let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; diff --git a/rust/onnxruntime/src/tensor/ort_output_tensor.rs b/rust/onnxruntime/src/tensor/ort_output_tensor.rs index 83663c0d303f8..727cae1db0ef4 100644 --- a/rust/onnxruntime/src/tensor/ort_output_tensor.rs +++ b/rust/onnxruntime/src/tensor/ort_output_tensor.rs @@ -71,22 +71,41 @@ impl Drop for OrtOutputTensor { } /// An Output tensor with the ptr and the item that will copy from the ptr. -#[derive(Debug)] -pub struct WithOutputTensor<'a, T> { - #[allow(dead_code)] +/// +/// The view is materialized on each access via [`view()`](Self::view) to ensure the +/// borrowed lifetime is tied to `&self`, preventing the view from outliving the +/// underlying buffer owned by the `OrtOutputTensor`. +pub struct WithOutputTensor { pub(crate) tensor: OrtOutputTensor, - item: ArrayView<'a, T, ndarray::IxDyn>, + data_ptr: *const T, + shape: Vec, } -impl<'a, T> std::ops::Deref for WithOutputTensor<'a, T> { - type Target = ArrayView<'a, T, ndarray::IxDyn>; +impl Debug for WithOutputTensor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("WithOutputTensor") + .field("tensor", &self.tensor) + .field("data_ptr", &self.data_ptr) + .field("shape", &self.shape) + .finish() + } +} - fn deref(&self) -> &Self::Target { - &self.item +// SAFETY: The data pointer is derived from OrtOutputTensor which owns the allocation. +// Access is only possible through &self (via view()), so Send/Sync follow from T: Send/Sync. +unsafe impl Send for WithOutputTensor {} +unsafe impl Sync for WithOutputTensor {} + +impl WithOutputTensor { + /// Returns an [`ArrayView`] over the output tensor data. + /// + /// The returned view borrows `self`, so it cannot outlive the tensor owner. + pub fn view(&self) -> ArrayView<'_, T, ndarray::IxDyn> { + unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&self.shape), self.data_ptr) } } } -impl<'a, T> TryFrom for WithOutputTensor<'a, T> +impl TryFrom for WithOutputTensor where T: TypeToTensorElementDataType, { @@ -110,45 +129,45 @@ where status_to_result(status).map_err(OrtError::IsTensor)?; assert_ne!(output_array_ptr, std::ptr::null_mut()); - let array_view = - unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&value.shape), output_array_ptr) }; + let shape = value.shape.clone(); Ok(WithOutputTensor { tensor: value, - item: array_view, + data_ptr: output_array_ptr, + shape, }) } } /// The onnxruntime Run output type. -pub enum OrtOutput<'a> { +pub enum OrtOutput { /// Tensor of f32s - Float(WithOutputTensor<'a, f32>), + Float(WithOutputTensor), /// Tensor of f64s - Double(WithOutputTensor<'a, f64>), + Double(WithOutputTensor), /// Tensor of u8s - UInt8(WithOutputTensor<'a, u8>), + UInt8(WithOutputTensor), /// Tensor of u16s - UInt16(WithOutputTensor<'a, u16>), + UInt16(WithOutputTensor), /// Tensor of u32s - UInt32(WithOutputTensor<'a, u32>), + UInt32(WithOutputTensor), /// Tensor of u64s - UInt64(WithOutputTensor<'a, u64>), + UInt64(WithOutputTensor), /// Tensor of i8s - Int8(WithOutputTensor<'a, i8>), + Int8(WithOutputTensor), /// Tensor of i16s - Int16(WithOutputTensor<'a, i16>), + Int16(WithOutputTensor), /// Tensor of i32s - Int32(WithOutputTensor<'a, i32>), + Int32(WithOutputTensor), /// Tensor of i64s - Int64(WithOutputTensor<'a, i64>), + Int64(WithOutputTensor), /// Tensor of Strings - String(WithOutputTensor<'a, String>), + String(WithOutputTensor), } -impl<'a> OrtOutput<'a> { - /// Return `WithOutputTensor<'a, f32>` which derefs into an `ArrayView`. - pub fn float_array(&self) -> Option<&WithOutputTensor<'a, f32>> { +impl OrtOutput { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn float_array(&self) -> Option<&WithOutputTensor> { if let Self::Float(item) = self { Some(item) } else { @@ -156,8 +175,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, f64>` which derefs into an `ArrayView`. - pub fn double_array(&self) -> Option<&WithOutputTensor<'a, f64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn double_array(&self) -> Option<&WithOutputTensor> { if let Self::Double(item) = self { Some(item) } else { @@ -165,8 +184,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u8>` which derefs into an `ArrayView`. - pub fn uint8_array(&self) -> Option<&WithOutputTensor<'a, u8>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint8_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt8(item) = self { Some(item) } else { @@ -174,8 +193,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u16>` which derefs into an `ArrayView`. - pub fn uint16_array(&self) -> Option<&WithOutputTensor<'a, u16>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint16_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt16(item) = self { Some(item) } else { @@ -183,8 +202,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u32>` which derefs into an `ArrayView`. - pub fn uint32_array(&self) -> Option<&WithOutputTensor<'a, u32>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint32_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt32(item) = self { Some(item) } else { @@ -192,8 +211,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, u64>` which derefs into an `ArrayView`. - pub fn uint64_array(&self) -> Option<&WithOutputTensor<'a, u64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn uint64_array(&self) -> Option<&WithOutputTensor> { if let Self::UInt64(item) = self { Some(item) } else { @@ -201,8 +220,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i8>` which derefs into an `ArrayView`. - pub fn int8_array(&self) -> Option<&WithOutputTensor<'a, i8>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int8_array(&self) -> Option<&WithOutputTensor> { if let Self::Int8(item) = self { Some(item) } else { @@ -210,8 +229,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i16>` which derefs into an `ArrayView`. - pub fn int16_array(&self) -> Option<&WithOutputTensor<'a, i16>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int16_array(&self) -> Option<&WithOutputTensor> { if let Self::Int16(item) = self { Some(item) } else { @@ -219,8 +238,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i32>` which derefs into an `ArrayView`. - pub fn int32_array(&self) -> Option<&WithOutputTensor<'a, i32>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int32_array(&self) -> Option<&WithOutputTensor> { if let Self::Int32(item) = self { Some(item) } else { @@ -228,8 +247,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, i64>` which derefs into an `ArrayView`. - pub fn int64_array(&self) -> Option<&WithOutputTensor<'a, i64>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn int64_array(&self) -> Option<&WithOutputTensor> { if let Self::Int64(item) = self { Some(item) } else { @@ -237,8 +256,8 @@ impl<'a> OrtOutput<'a> { } } - /// Return `WithOutputTensor<'a, String>` which derefs into an `ArrayView`. - pub fn string_array(&self) -> Option<&WithOutputTensor<'a, String>> { + /// Return `WithOutputTensor` which provides a `view()` method for an `ArrayView`. + pub fn string_array(&self) -> Option<&WithOutputTensor> { if let Self::String(item) = self { Some(item) } else { @@ -247,10 +266,10 @@ impl<'a> OrtOutput<'a> { } } -impl<'a> TryFrom for OrtOutput<'a> { +impl TryFrom for OrtOutput { type Error = OrtError; - fn try_from(value: OrtOutputTensor) -> Result> { + fn try_from(value: OrtOutputTensor) -> Result { unsafe { let mut shape_info = std::ptr::null_mut(); diff --git a/rust/onnxruntime/tests/integration_tests.rs b/rust/onnxruntime/tests/integration_tests.rs index 7843fe269e5e4..1c096400eccf7 100644 --- a/rust/onnxruntime/tests/integration_tests.rs +++ b/rust/onnxruntime/tests/integration_tests.rs @@ -112,6 +112,7 @@ mod download { // and iterate on resulting probabilities, creating an index to later access labels. let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -209,6 +210,7 @@ mod download { let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -301,6 +303,7 @@ mod download { let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -398,6 +401,7 @@ mod download { let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .iter() .copied() @@ -515,7 +519,7 @@ mod download { let output = outputs[0].float_array().unwrap(); // The image should have doubled in size - assert_eq!(output.shape(), [1, 448, 448, 3]); + assert_eq!(output.view().shape(), [1, 448, 448, 3]); } } From 5f49a372d6aa990b23d516e2ff13dde5849f1895 Mon Sep 17 00:00:00 2001 From: Sanaa Hamel Date: Fri, 26 Jun 2026 12:58:09 -0400 Subject: [PATCH 10/10] fix(ci): incorrect identity for azcopy (#29274) ### Description Use new GitHub CI identity for azcopy. ### Motivation and Context GitHub CI pools have been assigned a new identity. --- .github/workflows/windows_cuda.yml | 4 ++-- .github/workflows/windows_cuda_plugin.yml | 4 ++-- .github/workflows/windows_gpu_doc_gen.yml | 2 +- .github/workflows/windows_openvino.yml | 2 +- .github/workflows/windows_qnn_x64.yml | 2 +- .github/workflows/windows_tensorrt.yml | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index c6346d9c4e932..9776c22fedabb 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -148,7 +148,7 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows GPU CUDA CI Pipeline Test Job @@ -260,4 +260,4 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e diff --git a/.github/workflows/windows_cuda_plugin.yml b/.github/workflows/windows_cuda_plugin.yml index 538c1d783cd68..c2d84d7be482a 100644 --- a/.github/workflows/windows_cuda_plugin.yml +++ b/.github/workflows/windows_cuda_plugin.yml @@ -118,7 +118,7 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows CUDA Plugin EP Test @@ -214,4 +214,4 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e diff --git a/.github/workflows/windows_gpu_doc_gen.yml b/.github/workflows/windows_gpu_doc_gen.yml index 5e50a970875fc..b41de5542b0c6 100644 --- a/.github/workflows/windows_gpu_doc_gen.yml +++ b/.github/workflows/windows_gpu_doc_gen.yml @@ -44,7 +44,7 @@ jobs: setVcvars: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e runs-on: [ "self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10", diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index 52581c7d0a5f5..d87a4919a86fd 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -26,7 +26,7 @@ jobs: timeout-minutes: 240 env: AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e OnnxRuntimeBuildDirectory: ${{ github.workspace }} DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' diff --git a/.github/workflows/windows_qnn_x64.yml b/.github/workflows/windows_qnn_x64.yml index 35c620ca6f650..5560de89040de 100644 --- a/.github/workflows/windows_qnn_x64.yml +++ b/.github/workflows/windows_qnn_x64.yml @@ -29,7 +29,7 @@ jobs: QnnLibKind: [shared_lib, static_lib] env: AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index d5710795942d1..3ad0076de6d52 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -154,7 +154,7 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e test: name: Windows GPU TensorRT CI Pipeline Test Job @@ -265,4 +265,4 @@ jobs: DocUpdateNeeded: false ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + AZCOPY_MSI_CLIENT_ID: d712a4c7-a0cf-4e87-af75-31510eba0a8e