From 591df5b11914e5ce4636b378e8b3c871da3ec695 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 11 Jun 2026 16:43:28 +0800 Subject: [PATCH 1/7] webgpu: fix GQA batched right-padded prefill with do_rotary When GenAI runs a batched prefill with prompts of unequal lengths, short prompts are right-padded up to the batch max sequence_length and each batch's real length is reported via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary embedding shaders computed past_seqlen = (seqlens_k[b]+1) - sequence_length per batch, which underflowed u32 for any batch shorter than sequence_length. The resulting astronomically large position_id indexed past the cos/sin caches and produced garbage rotated Q/K, which manifested as gibberish output text for the shorter batches in the batch. Clamp past_seqlen to 0 in all three rotary embedding shaders: RotaryEmbeddingProgram (seqlens variant), FusedQKRotaryEmbeddingProgram, and the split_packed_qkv_with_rotary_embedding template. Also extend CanApplyFlashAttention to bypass FlashAttention for batched cases with per-batch seqlens (which exercise the unpatched and-copykv variant), while still allowing it for shared-KV layers where it is mandatory. Adds a regression test exercising the packed-QKV do_rotary path with three batches of unequal real lengths. --- .../webgpu/bert/flash_attention.cc | 6 +- .../contrib_ops/webgpu/bert/flash_attention.h | 2 +- .../webgpu/bert/group_query_attention.cc | 2 +- .../webgpu/bert/rotary_embedding.cc | 6 +- ...ed_qkv_with_rotary_embedding.wgsl.template | 3 +- .../group_query_attention_op_test.cc | 193 ++++++++++++++++++ 6 files changed, 205 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 02e764d01e05e..9be6a047cea9c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -606,8 +606,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { - return !parameters.is_packed_qkv_ && +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { + const bool kv_empty = parameters.kv_sequence_length_ == 0; + return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && + !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 3da6b33b4dc0e..218baf926173f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -205,7 +205,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr, const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr); -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context); +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr); // Split packed QKV with Q/K rotary embedding and copy KV cache fusion Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 930cb296122ce..36d688c9723fd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -350,7 +350,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); } if (kv_empty) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index 1b11a69de7824..b4fbdc555a6d5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -44,7 +44,8 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { << " let seqlen_i = " << position_ids_or_seqlens.GetByOffset("batch_idx") << ";\n" << " let seqlen = u32(seqlen_i);\n" " let total_seqlen = seqlen + 1u;\n" - " let past_seqlen = total_seqlen - uniforms.global_shape[1];\n" + " // Right-padded batches with prompt shorter than global_shape[1] would underflow u32; clamp to 0.\n" + " let past_seqlen = select(total_seqlen - uniforms.global_shape[1], 0u, total_seqlen <= uniforms.global_shape[1]);\n" " let position_id = past_seqlen + bsnh[1];\n" << " let i = dot(bsnh, uniforms.input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let j = i + select(half_rotary_emb_dim, 1u, " << interleaved_str << ");\n" @@ -200,7 +201,8 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c << " let seqlen_i = " << seqlens.GetByOffset("batch_idx") << ";\n" << " let seqlen = u32(seqlen_i);\n" << " let total_seqlen = seqlen + 1u;\n" - << " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n" + << " // Right-padded batches with prompt shorter than q_global_shape[1] would underflow u32; clamp to 0.\n" + << " let past_seqlen = select(total_seqlen - uniforms.q_global_shape[1], 0u, total_seqlen <= uniforms.q_global_shape[1]);\n" << " let position_id = past_seqlen + sequence_idx;\n" << " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template index 7fcdfcfddfb25..51eda83d089f1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding.wgsl.template @@ -30,7 +30,8 @@ $MAIN { let seqlen_i = seqlens.getByOffset(batch_idx); let seqlen = u32(seqlen_i); let total_seqlen = seqlen + 1u; - let past_seqlen = total_seqlen - uniforms.sequence_length; + // Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0. + let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length); let position_id = past_seqlen + seq_idx; #if use_multi_rotary_cache_concat let base_position = select(0u, multi_rotary_cache_concat_offset, total_seqlen > multi_rotary_cache_concat_offset); diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 821f43971848a..e342c872fd9b4 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2392,5 +2392,198 @@ TEST(GroupQueryAttentionTest, WebGPU_SharedKV_SlidingWindow) { tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// --------------------------------------------------------------------------- +// WebGPU: batched right-padded packed-QKV prefill regression +// +// In a multi-batch prefill where individual prompts have different real lengths, +// GenAI right-pads short prompts up to the max sequence_length and reports each +// batch's real length via seqlens_k[b] = real_len[b] - 1. The WebGPU rotary +// embedding shader for packed-QKV computes past_seqlen = (seqlens_k[b] + 1) - +// sequence_length per-batch. For a short batch whose real_len < sequence_length, +// that subtraction underflowed u32, producing astronomically large position_ids +// that read out-of-bounds from cos/sin caches -- garbage values manifesting as +// gibberish output text. The fix clamps past_seqlen to 0 during prefill. +// +// This test exercises the packed-QKV do_rotary path (which dispatches +// SplitPackedQKVWithRotaryEmbeddingProgram). It compares each batch's +// real-last-token output against a single-batch reference for the same prompt. +// --------------------------------------------------------------------------- + +// Builds a packed QKV tensor with deterministic values at real positions and +// zeros at right-padded positions. Layout per token: [Q(hidden), K(kv), V(kv)]. +// Uses values of order ~1.0 (well above the 5e-3 mismatch tolerance) so the +// rotated-vs-unrotated divergence is unambiguously detectable. +static void FillBatchedRightPaddedPackedQKV(int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const std::vector& real_lens, + std::vector& packed_out) { + const int hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int token_size = hidden_size + 2 * kv_hidden_size; + packed_out.assign(batch_size * sequence_length * token_size, 0.0f); + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + for (int s = 0; s < real_len; ++s) { + float* token = &packed_out[(b * sequence_length + s) * token_size]; + for (int c = 0; c < hidden_size; ++c) { + token[c] = 0.1f + 0.3f * static_cast(((b * 7 + s * 3 + c) % 13) + 1); + } + for (int c = 0; c < kv_hidden_size; ++c) { + token[hidden_size + c] = + 0.1f + 0.25f * static_cast(((b * 5 + s * 2 + c) % 11) + 1); + token[hidden_size + kv_hidden_size + c] = + 0.1f + 0.2f * static_cast(((b * 3 + s + c) % 9) + 1); + } + } + } +} + +// Runs a packed-QKV GQA prefill with do_rotary=1 and the given per-batch +// seqlens_k. Returns the output tensor [batch_size, sequence_length, hidden_size]. +static std::vector RunGQAPackedQKVRotaryPrefill( + int batch_size, + int sequence_length, + int num_heads, + int kv_num_heads, + int head_size, + const std::vector& seqlens_k_data, + const std::vector& packed_qkv_data) { + const int hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int qkv_hidden = hidden_size + 2 * kv_hidden_size; + const int total_sequence_length = sequence_length; // prefill: no past + const int half_rotary = head_size / 2; + const int max_seq_len = sequence_length + 8; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("do_rotary", static_cast(1)); + + // Packed QKV: pass through `query` input, leave key/value as optional edges. + tester.AddInput("query", {batch_size, sequence_length, qkv_hidden}, packed_qkv_data); + tester.AddOptionalInputEdge(); // key (signals packed) + tester.AddOptionalInputEdge(); // value (signals packed) + + tester.AddOptionalInputEdge(); // past_key + tester.AddOptionalInputEdge(); // past_value + + tester.AddInput("seqlens_k", {batch_size}, seqlens_k_data); + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, + /*is_initializer=*/true); + + std::vector cos_cache(max_seq_len * half_rotary); + std::vector sin_cache(max_seq_len * half_rotary); + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int d = 0; d < half_rotary; ++d) { + const float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / + static_cast(head_size)); + cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq); + sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq); + } + } + tester.AddInput("cos_cache", {max_seq_len, half_rotary}, cos_cache); + tester.AddInput("sin_cache", {max_seq_len, half_rotary}, sin_cache); + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, 0.0f)); + const int present_size = batch_size * kv_num_heads * total_sequence_length * head_size; + tester.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size}, + std::vector(present_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, head_size}, + std::vector(present_size, 0.0f)); + + tester.SetOutputTolerance(1e6f); // We fetch and compare outputs ourselves. + + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + + auto fetches = tester.GetFetches(); + const float* out_data = fetches[0].Get().Data(); + return std::vector(out_data, out_data + output_size); +} + +// Regression for u32 underflow in WebGPU SplitPackedQKVWithRotaryEmbedding +// shader during right-padded batched prefill. Runs each prompt singly to build +// a reference, then runs all prompts as a right-padded batch and asserts that +// each batch's real-last-token output matches its single-prompt reference. +TEST(GroupQueryAttentionTest, WebGPU_BatchedRightPaddedRotaryPrefill) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + + constexpr int batch_size = 3; + constexpr int num_heads = 4; + constexpr int kv_num_heads = 2; + constexpr int head_size = 16; // multiple of 4 for FlashAttention gate; rotary half = 8 + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; + + // Real prompt lengths per batch; max = sequence_length (right-padding extends + // shorter batches up to this length). The bug only manifests when at least + // one batch is shorter than sequence_length. + const std::vector real_lens = {4, 2, 6}; + const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); + + std::vector packed_batched; + FillBatchedRightPaddedPackedQKV(batch_size, sequence_length, num_heads, kv_num_heads, + head_size, real_lens, packed_batched); + + // Build single-prompt references by extracting each batch's real-len slice + // and running it as a batch_size=1 prefill (which is known correct). + std::vector> ref_outputs(batch_size); + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + std::vector packed_single(real_len * qkv_hidden); + for (int s = 0; s < real_len; ++s) { + std::copy_n(&packed_batched[(b * sequence_length + s) * qkv_hidden], qkv_hidden, + &packed_single[s * qkv_hidden]); + } + ref_outputs[b] = RunGQAPackedQKVRotaryPrefill( + /*batch_size=*/1, /*sequence_length=*/real_len, + num_heads, kv_num_heads, head_size, + /*seqlens_k_data=*/{static_cast(real_len - 1)}, + packed_single); + } + + // Now run all batches together with right-padding. + std::vector seqlens_k_data(batch_size); + for (int b = 0; b < batch_size; ++b) { + seqlens_k_data[b] = static_cast(real_lens[b] - 1); + } + const auto batched_output = RunGQAPackedQKVRotaryPrefill( + batch_size, sequence_length, num_heads, kv_num_heads, head_size, + seqlens_k_data, packed_batched); + + // Each batch's real-last-token output (used to predict next token) must match + // its single-prompt reference. The tolerance is loose enough for fp16 rounding + // while still catching the underflow bug (which produces values that differ + // by orders of magnitude or are NaN/Inf). + constexpr float tolerance = 5e-3f; + for (int b = 0; b < batch_size; ++b) { + const int real_len = real_lens[b]; + const int q_last = real_len - 1; + const float* batched_last = + batched_output.data() + (b * sequence_length + q_last) * hidden_size; + const float* ref_last = ref_outputs[b].data() + q_last * hidden_size; + for (int c = 0; c < hidden_size; ++c) { + EXPECT_NEAR(batched_last[c], ref_last[c], tolerance) + << "batch " << b << " real_len=" << real_len + << " channel " << c << " mismatch"; + } + } +} + } // namespace test } // namespace onnxruntime From 4c0492c711d296091edf62c00ff0e38d9e65e274 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 11 Jun 2026 17:43:37 +0800 Subject: [PATCH 2/7] webgpu: support batch_size > 1 on the FlashAttention path The FlashAttention path on WebGPU previously gated batched GQA out via (batch_size_ == 1 || seqlen_k == nullptr || kv_empty) in CanApplyFlashAttention because three shaders hardcoded seqlens_k[0] and several KV-cache write offsets / causal-mask / rotary-position derivations underflowed u32 when the per-batch prompt was shorter than the max-across-batches sequence_length (right-padded batches). This change reads seqlens_k per batch in all FlashAttention shaders (prefill + decode split-reduce + CopyKVCache + the rotary-and-copyKV template), clamps every past_X = total_X - new_X subtraction to avoid u32 underflow, and decouples attention_bias stride (still allocated to the global max total_sequence_length) from the per-batch OOB check. The decode_qkv shader retains a workgroup-grid sized to the global max total_sequence_length tile count (so workgroup_idx slicing remains self-consistent across batches), and early-exits with neutral metadata (-inf, 0) for tiles beyond a short batch's per-batch total so the VxReduce online softmax rescaling is not skewed by garbage tiles. A new use_seqlen_k template parameter (separate from use_indirect_dispatch, which still requires graph capture) drives the per-batch path; it is enabled whenever seqlen_k is provided and (graph_capture || batch_size_ > 1). Verified: - All 7 GroupQueryAttentionTest.WebGPU_* op tests pass, including BatchedRightPaddedRotaryPrefill which now exercises FlashAttention instead of ApplyAttention. - phi4-prune three-prompt batched generation produces coherent, correct outputs on WebGPU matching the CPU reference. - phi4-prune single-prompt generation regression: coherent output. - whisper-tiny-int4 transcription regression: 2/2 byte-exact with CPU. --- .../webgpu/bert/flash_attention.cc | 48 +++++++----- .../contrib_ops/webgpu/bert/flash_attention.h | 11 ++- .../webgpu/bert/flash_attention.wgsl.template | 77 ++++++++++--------- .../flash_attention_decode_qkv.wgsl.template | 51 +++++++++--- ...h_attention_decode_vx_reduce.wgsl.template | 10 +++ ..._rotary_embedding_and_copykv.wgsl.template | 3 +- 6 files changed, 132 insertions(+), 68 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 9be6a047cea9c..e22bf20a1e27d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -109,11 +109,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { " let num_head_id = output_indices[1];\n" " let batch = output_indices[0];\n"; if (use_seqlen_k_) { - shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n"; + shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[batch]) + 1u;\n"; } else { shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; } - shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; + // Right-padded batches with prompt shorter than kv_sequence_length would underflow u32; clamp to 0. + shader.MainFunctionBody() << " let past_sequence_length = select(total_seq_length - uniforms.kv_sequence_length, 0u, total_seq_length <= uniforms.kv_sequence_length);\n"; if (past_present_share_buffer_) { shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"; } else { @@ -262,7 +263,8 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (use_indirect_dispatch_) { + const bool needs_seqlens_k = use_indirect_dispatch_ || use_seqlen_k_; + if (needs_seqlens_k) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_attention_bias_) { @@ -282,6 +284,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), @@ -293,7 +296,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -303,11 +306,12 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool is_unidirectional = parameters.is_unidirectional_; - FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (use_indirect_dispatch) { + const bool needs_seqlens_k = use_indirect_dispatch || use_seqlen_k; + if (needs_seqlens_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_attention_bias) { @@ -332,7 +336,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -351,7 +355,8 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { + const bool needs_seqlens_k = use_indirect_dispatch_ || use_seqlen_k_; + if (needs_seqlens_k) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { @@ -365,6 +370,7 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_VARIABLE(input, input), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(output, output)); @@ -381,15 +387,17 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& uint32_t seq_tile_size, bool use_indirect_dispatch, const Tensor* head_sink, - uint32_t m_tile) { + uint32_t m_tile, + bool use_seqlen_k) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; bool has_head_sink = head_sink != nullptr; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile, use_seqlen_k}; program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); - if (use_indirect_dispatch) { + const bool needs_seqlens_k = use_indirect_dispatch || use_seqlen_k; + if (needs_seqlens_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_head_sink) { @@ -399,7 +407,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) + .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile, use_seqlen_k) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, @@ -437,7 +445,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co present_value = &internal_present_value; } - const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Read seqlens_k per batch_idx in the shader whenever: + // (a) graph capture is enabled (total_sequence_length_ is 0 on the host), OR + // (b) batch_size > 1 (right-padded batches have distinct per-batch totals). + // Otherwise the kernels fall back to uniforms.total_sequence_length. + const bool use_seqlen_k = seqlen_k != nullptr && + (context.IsGraphCaptureEnabled() || parameters.batch_size_ > 1); // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; @@ -596,20 +609,19 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co &metadata, seqlen_k, parameters, indirect_buffer_ptr, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length, m_tile)); + present_sequence_length, m_tile, use_seqlen_k)); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - head_sink, m_tile)); + head_sink, m_tile, use_seqlen_k)); return Status::OK(); } bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { - const bool kv_empty = parameters.kv_sequence_length_ == 0; - return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && - !parameters.is_packed_qkv_ && + ORT_UNUSED_PARAMETER(seqlen_k); + return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 218baf926173f..ad1a73cc5fe4d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -148,8 +148,9 @@ class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -198,6 +200,7 @@ class FlashAttentionDecodeVxReduceProgram final : public Program u32 { - return u32(seqlens_k[0]) + 1u; +// When seqlens_k is provided, total_sequence_length is read per batch from the GPU buffer. +fn get_total_sequence_length(batch_idx: u32) -> u32 { + return u32(seqlens_k[batch_idx]) + 1u; } #else -// When graph capture is disabled, total_sequence_length comes from uniforms -fn get_total_sequence_length() -> u32 { +// Without seqlens_k, total_sequence_length comes from uniforms (max across batches). +fn get_total_sequence_length(batch_idx: u32) -> u32 { return uniforms.total_sequence_length; } #endif @@ -65,20 +65,18 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ var qk_scores : array; -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); v_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); @@ -95,15 +93,18 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { } #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> q_element_t { - if (k_idx_global >= get_total_sequence_length()) { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> q_element_t { + if (k_idx_global >= total_seq) { return q_element_t(0); } let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); - return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + get_total_sequence_length())]); + // The attention_bias tensor is allocated to the global max total_sequence_length, + // so stride math uses uniforms.total_sequence_length even when per-batch values are smaller. + let stride_total_seq = uniforms.total_sequence_length; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; + return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + stride_total_seq)]); } #endif @@ -111,24 +112,24 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he // For max performance max_k_step should be the same as sg_size, however we might run out of registers // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); k_tile[slot][idx % head_size_vec] = val; } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); v_tile[slot][idx % head_size_vec] = val; } } @@ -160,18 +161,21 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { #endif #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (k_idx_global >= get_total_sequence_length()) { + if (k_idx_global >= total_seq) { return vec4(0); } // Handle broadcasting: if dimension size is 1, use index 0 let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); + // The attention_bias tensor is allocated to the global max total_sequence_length, + // so stride math uses uniforms.total_sequence_length even when per-batch values are smaller. + let stride_total_seq = uniforms.total_sequence_length; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; let offset = offset_base + k_idx_global; - let offset_max = offset_base + get_total_sequence_length(); + let offset_max = offset_base + stride_total_seq; let c1 = q_element_t(attention_bias[min(offset, offset_max)]); let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]); let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]); @@ -179,7 +183,7 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he return vec4(c1, c2, c3, c4); } #else -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { return vec4(0); } #endif @@ -226,11 +230,14 @@ $MAIN { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; #endif - let total_sequence_length = get_total_sequence_length(); + let total_sequence_length = get_total_sequence_length(batch_idx); #if is_unidirectional // If attention is unidirectional, set the loop bound to enforce causal masking. - let past_sequence_length = total_sequence_length - uniforms.new_sequence_length; + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); let max_causal_len_for_workgroup = past_sequence_length + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup); @@ -244,8 +251,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx); - loadv(k_start, batch_head_idx, local_idx); + loadk(k_start, batch_head_idx, local_idx, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, total_sequence_length); workgroupBarrier(); for (var k = 0u; k < max_k_step; k++) { @@ -254,7 +261,7 @@ $MAIN { score += dot(q_tile[i], k_tile[k][i]); } #if has_attention_bias - score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx); + score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx, total_sequence_length); #endif qk_scores[k] = select(min_value, score, k_start + k < seq_causal_length); } @@ -302,8 +309,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx, capped_sg_size); - loadv(k_start, batch_head_idx, local_idx, capped_sg_size); + loadk(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); workgroupBarrier(); // Compute QKt @@ -361,11 +368,11 @@ $MAIN { qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx); - qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx); + qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx, total_sequence_length); + qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx, total_sequence_length); if (sg_size > 8) { - qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx); - qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx); + qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx, total_sequence_length); + qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx, total_sequence_length); } // Neuter qk values where K is out of bounds. diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template index 524a18ca43245..8e8708a206229 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -10,6 +10,7 @@ #param tile_size #param tile_size_k_vec #param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -34,18 +35,21 @@ var tile_max: array; var tile_sum: array; #if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + - bias_head_idx * uniforms.new_sequence_length * total_seq_length + - q_idx * total_seq_length + + // The attention_bias tensor is allocated to the global max total_sequence_length, + // so stride math uses uniforms.total_sequence_length even when per-batch values are smaller. + let stride_total_seq = uniforms.total_sequence_length; + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + + q_idx * stride_total_seq + k_idx; return attention_bias[offset]; } #else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { return q_element_t(0); } @@ -54,12 +58,14 @@ var tile_sum: array; $MAIN { let local_row = u32(local_idx / tile_size_k_vec); let local_col = local_idx % tile_size_k_vec; + // total_sequence_length used for workgroup_idx slicing must match the host-side dispatch + // grid, i.e. the global maximum across batches. Per-batch total is derived separately below. #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + let global_total_sequence_length = u32(seqlens_k[0]) + 1u; #else - let total_sequence_length = uniforms.total_sequence_length; + let global_total_sequence_length = uniforms.total_sequence_length; #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_total_seq_length_tile = (global_total_sequence_length + tile_size - 1) / tile_size; let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; @@ -71,9 +77,28 @@ $MAIN { if (batch_idx >= uniforms.batch_size) { return; } + // Per-batch total_sequence_length used for K/V bounds, causal mask, and softmax range. + #if use_seqlen_k + let total_sequence_length = u32(seqlens_k[batch_idx]) + 1u; + #else + let total_sequence_length = global_total_sequence_length; + #endif let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + // If this workgroup's tile lies entirely beyond this batch's per-batch total_sequence_length, + // write neutral metadata so VxReduce contributes nothing for these tiles, then exit early. + if (total_seq_offset >= total_sequence_length) { + if (local_idx == 0u) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx_local = q_base + m; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx_local) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(-3.4028234663852886e+38f, 0.0f)); + } + } + return; + } + // ============================================================ // Phase 1: QK^T computation // ============================================================ @@ -109,6 +134,12 @@ $MAIN { } // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask +#if is_unidirectional + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); +#endif for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { let q_idx = q_base + m; if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { @@ -117,9 +148,9 @@ $MAIN { sum += inner_qk_values[m][local_idx][i]; } - sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx); #if is_unidirectional - if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + if (total_seq_offset + local_idx > past_sequence_length + q_idx) { sum = q_element_t(-65504.0f); } #endif diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index a3ce0b68cb659..0bb042d34944f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -6,6 +6,7 @@ #param seq_tile_size #param tile_size #param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -32,12 +33,21 @@ $MAIN { } let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; + // Per-batch total_sequence_length: short batches contributed neutral metadata + // (-inf, 0) for tiles beyond their per-batch total, so reading only this batch's + // tiles ensures softmax rescaling is not skewed by garbage tiles. + #if use_seqlen_k + let batch_idx_for_seqlen = batch_head_idx / uniforms.num_heads; + let total_sequence_length = u32(seqlens_k[batch_idx_for_seqlen]) + 1u; + let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; + #else #if use_indirect_dispatch let total_sequence_length = u32(seqlens_k[0]) + 1u; let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; #else let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; #endif + #endif for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { let q_idx = q_base + m; diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 7b09a3a6af080..97c610fb90024 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -31,7 +31,8 @@ $MAIN { let seqlen = u32(seqlen_i); let total_seqlen = seqlen + 1u; - let past_seqlen = total_seqlen - uniforms.sequence_length; + // Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0. + let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length); // `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value let position_id = past_seqlen + seq_idx; #if use_multi_rotary_cache_concat From 0c1512afbf640e1f1e69db0f848d42e931c8bedf Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 25 Jun 2026 13:14:51 +0800 Subject: [PATCH 3/7] webgpu: Tidy flash-attention seqlens_k handling and add prefill test - Drop the now-unused seqlen_k parameter from CanApplyFlashAttention and update the GQA caller. The argument was already ORT_UNUSED_PARAMETER. - Simplify use_seqlen_k to (seqlen_k != nullptr) so batch=1 and batch>1 share one path. The previous (graph_capture || batch_size > 1) qualifier was redundant: both conditions imply seqlen_k is supplied, and reading seqlens_k[batch_idx] in the shader is a no-op for batch=1. - Read attention_bias's actual last-dim stride from a new attn_bias_dim3 uniform instead of uniforms.total_sequence_length. The shader stride must match the tensor's storage shape, which can differ from the per-step total (e.g. graph capture sets total_sequence_length=0 on the host, which would have produced a zero-stride bias offset). - Strip the unreachable indirect-dispatch branch in flash_attention_decode_ vx_reduce.wgsl.template. After the use_seqlen_k simplification, use_indirect_dispatch=true implies use_seqlen_k=true (the host gate requires seqlen_k != nullptr), so the inner branch could never execute. Keep the use_indirect_dispatch flag elsewhere for forward compatibility. - Add BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU test that exercises the FlashAttentionProgram prefill path (sequence_length=33 crosses the split-reduce threshold of 32) with right-padded batches and do_rotary, matching the Phi-style batched prefill scenario. --- .../webgpu/bert/flash_attention.cc | 24 +++++++++++-------- .../contrib_ops/webgpu/bert/flash_attention.h | 6 +++-- .../webgpu/bert/flash_attention.wgsl.template | 14 ++++++----- .../flash_attention_decode_qkv.wgsl.template | 7 +++--- ...h_attention_decode_vx_reduce.wgsl.template | 6 ----- .../webgpu/bert/group_query_attention.cc | 2 +- .../group_query_attention_op_test.cc | 23 +++++++++++++++--- 7 files changed, 51 insertions(+), 31 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index e22bf20a1e27d..4cd7baeba5c1d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -324,10 +324,12 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } if (use_indirect_dispatch) { @@ -347,6 +349,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, {attn_bias_dim1}, + {attn_bias_dim3}, {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); @@ -369,7 +372,6 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_VARIABLE(input, input), WGSL_TEMPLATE_VARIABLE(metadata, metadata), @@ -445,12 +447,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co present_value = &internal_present_value; } - // Read seqlens_k per batch_idx in the shader whenever: - // (a) graph capture is enabled (total_sequence_length_ is 0 on the host), OR - // (b) batch_size > 1 (right-padded batches have distinct per-batch totals). - // Otherwise the kernels fall back to uniforms.total_sequence_length. - const bool use_seqlen_k = seqlen_k != nullptr && - (context.IsGraphCaptureEnabled() || parameters.batch_size_ > 1); + // Read seqlens_k per batch_idx in the shader whenever seqlens_k is supplied. + // This covers both graph-capture (total_sequence_length_ is 0 on the host) and + // right-padded batches (batch_size > 1 with distinct per-batch totals), and lets + // batch=1 share the same path. When seqlens_k is null, kernels fall back to + // uniforms.total_sequence_length. + const bool use_seqlen_k = seqlen_k != nullptr; // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; @@ -568,10 +570,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) @@ -585,7 +589,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {alpha}, {num_seq_tile}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {attn_bias_dim3}}); return context.RunProgram(program); } @@ -619,8 +624,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { - ORT_UNUSED_PARAMETER(seqlen_k); +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index ad1a73cc5fe4d..6207b94c4c0ff 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -125,7 +125,8 @@ class FlashAttentionProgram final : public Program { {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}, {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32}, - {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}); + {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}, + {"attn_bias_dim3", ProgramUniformVariableDataType::Uint32}); private: bool has_attention_bias_; @@ -165,6 +166,7 @@ class FlashAttentionDecodeQKVProgram final : public Program= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - // The attention_bias tensor is allocated to the global max total_sequence_length, - // so stride math uses uniforms.total_sequence_length even when per-batch values are smaller. - let stride_total_seq = uniforms.total_sequence_length; + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + stride_total_seq)]); @@ -169,9 +170,10 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he // Handle broadcasting: if dimension size is 1, use index 0 let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - // The attention_bias tensor is allocated to the global max total_sequence_length, - // so stride math uses uniforms.total_sequence_length even when per-batch values are smaller. - let stride_total_seq = uniforms.total_sequence_length; + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; let offset = offset_base + k_idx_global; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template index 8e8708a206229..14af843b01dbd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -39,9 +39,10 @@ var tile_sum: array; { let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - // The attention_bias tensor is allocated to the global max total_sequence_length, - // so stride math uses uniforms.total_sequence_length even when per-batch values are smaller. - let stride_total_seq = uniforms.total_sequence_length; + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx * stride_total_seq + diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index 0bb042d34944f..628ad835a9d4c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -5,7 +5,6 @@ #param m_tile #param seq_tile_size #param tile_size -#param use_indirect_dispatch #param use_seqlen_k #use .getByOffset .setByOffset @@ -41,13 +40,8 @@ $MAIN { let total_sequence_length = u32(seqlens_k[batch_idx_for_seqlen]) + 1u; let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; #else - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; - let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; - #else let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; #endif - #endif for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { let q_idx = q_base + m; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 36d688c9723fd..930cb296122ce 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -350,7 +350,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context); } if (kv_empty) { diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 645564d01abc0..7f9e5caff3f06 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2619,7 +2619,8 @@ static std::vector RunGQAPackedQKVRotaryPrefill( // output matches its single-prompt reference. Both reference and batched runs // go through the same EP, so this validates per-batch consistency within each // EP rather than cross-EP equivalence. -static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { +static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, + const std::vector& real_lens = {4, 2, 6}) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2628,10 +2629,10 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { constexpr int kv_hidden_size = kv_num_heads * head_size; constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; - // Real prompt lengths per batch; max = sequence_length (right-padding extends + // Per-batch real prompt lengths; max = sequence_length (right-padding extends // shorter batches up to this length). The bug only manifests when at least // one batch is shorter than sequence_length. - const std::vector real_lens = {4, 2, 6}; + ASSERT_EQ(static_cast(real_lens.size()), batch_size); const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); std::vector packed_batched; @@ -2715,5 +2716,21 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu); } +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with per-batch +// real_lens whose max crosses the prefill threshold (sequence_length >= 32) so +// the WebGPU EP picks FlashAttentionProgram (single-kernel prefill path with +// subgroup shuffles) instead of the split-reduce decode path. This exercises +// the prefill flash-attention kernel under right-padded batches with do_rotary, +// which is the path used by Phi-style models during batched prefill. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // sequence_length = max(real_lens) = 33 > 32 -> FlashAttentionProgram path. + // Mixed shorter batches (12, 20) ensure right-padding is non-trivial. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 33}); +} + } // namespace test } // namespace onnxruntime From f624ecb4ce7fdf78d1f0f5ecc79b91de1a63f2a8 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 25 Jun 2026 13:25:55 +0800 Subject: [PATCH 4/7] webgpu: Add non-flash attention path coverage to batched right-padded prefill tests Thread a smooth_softmax flag through RunGQAPackedQKVRotaryPrefill / RunBatchedRightPaddedRotaryPrefillForEP and add a new WebGPU test: BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU With smooth_softmax=1 the WebGPU EP skips CanApplyFlashAttention and routes through ApplyAttention, so the three WebGPU prefill tests now cover all three batched right-padded code paths: - split-reduce decode shader (FlashAttentionDecodeQKV + VxReduce) - fused prefill shader (FlashAttentionProgram) - non-flash attention path (ApplyAttention) Verified by temporarily printing the chosen path during dispatch: each test hits its intended path with batch_size=3 and right-padded seqlens_k. --- .../group_query_attention_op_test.cc | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 7f9e5caff3f06..5c831d22cd3fd 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2510,7 +2510,8 @@ static std::vector RunGQAPackedQKVRotaryPrefill( int head_size, const std::vector& seqlens_k_data, const std::vector& packed_qkv_data, - GqaTargetEp target_ep = GqaTargetEp::kCpu) { + GqaTargetEp target_ep = GqaTargetEp::kCpu, + bool smooth_softmax = false) { const int hidden_size = num_heads * head_size; const int kv_hidden_size = kv_num_heads * head_size; const int qkv_hidden = hidden_size + 2 * kv_hidden_size; @@ -2529,6 +2530,10 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); tester.AddAttribute("do_rotary", static_cast(1)); + if (smooth_softmax) { + // smooth_softmax disqualifies WebGPU CanApplyFlashAttention -> routes to ApplyAttention. + tester.AddAttribute("smooth_softmax", static_cast(1)); + } // Packed QKV: pass through `query` input, leave key/value as optional edges. if (use_fp16) { @@ -2620,7 +2625,8 @@ static std::vector RunGQAPackedQKVRotaryPrefill( // go through the same EP, so this validates per-batch consistency within each // EP rather than cross-EP equivalence. static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, - const std::vector& real_lens = {4, 2, 6}) { + const std::vector& real_lens = {4, 2, 6}, + bool smooth_softmax = false) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2653,7 +2659,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, /*batch_size=*/1, /*sequence_length=*/real_len, num_heads, kv_num_heads, head_size, /*seqlens_k_data=*/{static_cast(real_len - 1)}, - packed_single, target_ep); + packed_single, target_ep, smooth_softmax); } // Now run all batches together with right-padding. @@ -2663,7 +2669,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, } const auto batched_output = RunGQAPackedQKVRotaryPrefill( batch_size, sequence_length, num_heads, kv_num_heads, head_size, - seqlens_k_data, packed_batched, target_ep); + seqlens_k_data, packed_batched, target_ep, smooth_softmax); // Guard the regression deterministically: every element of the batched output // (including padding rows) must be finite. The CPU root cause is uninitialized @@ -2732,5 +2738,17 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttention_WebG RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 33}); } +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with +// smooth_softmax=1 so the WebGPU EP bypasses CanApplyFlashAttention and routes +// through ApplyAttention (non-flash path). Covers right-padded batched prefill +// on the non-flash attention path (used by e.g. Phi-4 attention variants). +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {4, 2, 6}, /*smooth_softmax=*/true); +} + } // namespace test } // namespace onnxruntime From 9c5bff92902e5c2789869746ddead09037699dd7 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 25 Jun 2026 13:54:07 +0800 Subject: [PATCH 5/7] webgpu: Simplify seqlens_k binding gate to use_seqlen_k_ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit use_indirect_dispatch implies seqlen_k != nullptr (graph-capture path), and use_seqlen_k_ = (seqlen_k != nullptr), so use_indirect_dispatch_ ⇒ use_seqlen_k_. The needs_seqlens_k = use_indirect_dispatch_ || use_seqlen_k_ disjunction is therefore equivalent to use_seqlen_k_ at every call site. Drop the local. --- .../contrib_ops/webgpu/bert/flash_attention.cc | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 4cd7baeba5c1d..29ef2de6b228f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -263,8 +263,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const bool needs_seqlens_k = use_indirect_dispatch_ || use_seqlen_k_; - if (needs_seqlens_k) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_attention_bias_) { @@ -310,8 +309,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - const bool needs_seqlens_k = use_indirect_dispatch || use_seqlen_k; - if (needs_seqlens_k) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_attention_bias) { @@ -358,8 +356,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); - const bool needs_seqlens_k = use_indirect_dispatch_ || use_seqlen_k_; - if (needs_seqlens_k) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { @@ -398,8 +395,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile, use_seqlen_k}; program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); - const bool needs_seqlens_k = use_indirect_dispatch || use_seqlen_k; - if (needs_seqlens_k) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_head_sink) { From 2adc29a7a38740beb5262510542668787cdb7451 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 25 Jun 2026 15:37:57 +0800 Subject: [PATCH 6/7] webgpu: Use global total_sequence_length for indirect dispatch sizing In graph-capture mode the host total_sequence_length scalar is 0 and the dispatch grid for the flash-attention pipeline is computed on the GPU. The three shaders that prepare or consume the indirect-dispatch buffer (CopyKVCache, SplitPackedQKVWithRotaryEmbeddingAndCopyKV, FlashAttentionDecodeQKV) previously sized the grid from seqlens_k[batch=0] + 1. For batched right-padded prefill, batch 0 is not guaranteed to hold the maximum KV span, so when the spread across the batch crosses a tile boundary other batches lose tiles and produce wrong output. Thread GQA's input #6 (total_sequence_length, GPU-resident exactly when graph capture is enabled) through ApplyFlashAttention into the three shaders and use it for the indirect-dispatch sizing only. Per-batch seqlens_k[batch] + 1 still drives causal masking and per-batch bounds inside the kernels. Also enforce in GroupQueryAttention that graph capture implies past_present_share_buffer_, so the use_indirect_dispatch predicate only needs to check seqlen_k, total_seqlen, and IsGraphCaptureEnabled. Address PR review: - Clamp attention_bias load to offset_base + stride_total_seq - 1u in both scalar and vec4 paths so the one-past-end fallback stays within the same row. - Reword the smooth_softmax test comment to reference the outer gating in GroupQueryAttention::ComputeInternal that routes through ApplyAttention. - Extend the indirect-dispatch fix to FlashAttentionDecodeQKV; the new use of use_indirect_dispatch_ also resolves the -Wunused-private-field Clang error on the wasm and arm64 builds. Add BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU with real_lens spread > tile_size so a future regression in the dispatch sizing surfaces in the WebGPU test suite (graph capture itself cannot be toggled from OpTester). --- .../webgpu/bert/flash_attention.cc | 51 +++++++++++++++---- .../contrib_ops/webgpu/bert/flash_attention.h | 6 ++- .../webgpu/bert/flash_attention.wgsl.template | 4 +- .../flash_attention_decode_qkv.wgsl.template | 2 +- .../webgpu/bert/group_query_attention.cc | 8 ++- ..._rotary_embedding_and_copykv.wgsl.template | 3 +- .../group_query_attention_op_test.cc | 21 +++++++- 7 files changed, 75 insertions(+), 20 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 29ef2de6b228f..67417e8f4c7be 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -55,6 +55,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + if (prepare_indirect_dispatch_) { + sh.AddInput("total_sequence_length_input", ShaderUsage::None); + } const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); @@ -97,8 +100,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (use_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::None); } - // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output + // If prepare_indirect_dispatch is enabled, add total_sequence_length_input + // and indirect_buffer output. total_sequence_length_input is the global max + // total sequence length across the batch (from GQA input #6); using it for + // dispatch sizing covers right-padded batches where batch 0 is not the max. if (prepare_indirect_dispatch_) { + shader.AddInput("total_sequence_length_input", ShaderUsage::None); shader.AddOutput("indirect_buffer", ShaderUsage::None); } @@ -125,7 +132,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (prepare_indirect_dispatch_) { shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; shader.MainFunctionBody() << " if (global_idx == 0u) {\n" - << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" + << " let global_total_seq_length = u32(total_sequence_length_input[0]);\n" + << " let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -153,7 +161,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles, + const Tensor* total_seqlen) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. @@ -189,6 +198,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_past) { program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, @@ -266,6 +278,12 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } + if (use_indirect_dispatch_) { + // Global max total sequence length across batches (from GQA input #6). + // Used in indirect-dispatch mode for the workgroup_idx slicing so that + // batch 0's per-batch length cannot undersize the dispatch grid. + shader.AddInput("total_sequence_length_input", ShaderUsage::None); + } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } @@ -295,7 +313,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k) { + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k, const Tensor* total_seqlen) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -312,6 +330,9 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (use_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -421,7 +442,8 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, - const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink) { + const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink, + const Tensor* total_seqlen) { constexpr uint32_t tile_size = 64; // Create present_key and present_value tensors if they are nullptr. @@ -464,8 +486,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Prepare indirect dispatch buffer for split-reduce path with static KV cache. // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && + // Static KV cache (past_present_share_buffer_) is guaranteed by GQA's + // ORT_ENFORCE when graph capture is enabled. + const bool use_indirect_dispatch = seqlen_k != nullptr && + total_seqlen != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions @@ -503,10 +527,11 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size, num_q_tiles)); + indirect_buffer_ptr, tile_size, num_q_tiles, + total_seqlen)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles, total_seqlen)); } // Extract present_sequence_length directly from present_key tensor shape @@ -610,7 +635,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co &metadata, seqlen_k, parameters, indirect_buffer_ptr, num_total_seq_length_tile, num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length, m_tile, use_seqlen_k)); + present_sequence_length, m_tile, use_seqlen_k, total_seqlen)); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, @@ -636,7 +661,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size, uint32_t num_q_tiles) { + uint32_t tile_size, uint32_t num_q_tiles, + const Tensor* total_seqlen) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -674,6 +700,9 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {cos_cache, ProgramTensorMetadataDependency::Rank, components}, {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }); + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, {present_key, ProgramTensorMetadataDependency::None, components}, {present_value, ProgramTensorMetadataDependency::None, components}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 6207b94c4c0ff..77b787f55aa18 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -208,7 +208,8 @@ class FlashAttentionDecodeVxReduceProgram final : public ProgramDataRaw() == present_value->DataRaw(); ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length."); + ORT_ENFORCE(!context.IsGraphCaptureEnabled() || parameters.past_present_share_buffer_, + "Graph capture requires past/present KV cache to share the same buffer (static KV cache)."); Tensor qSplit; Tensor kSplit; @@ -381,7 +383,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled // query points to packed QKV, K and V are nullptr since they're not needed return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink); + present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink, + total_seqlen_tensor); } // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); @@ -472,7 +475,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink); + present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink, + total_seqlen_tensor); } // Non-flash attention path does not support kv_sequence_length==0 (shared KV layers). diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 97c610fb90024..e3d92c036d2c1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -43,7 +43,8 @@ $MAIN { #if prepare_indirect_dispatch if (global_idx == 0u) { - let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + let global_total_seq_length = u32(total_sequence_length_input[0]); + let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 5c831d22cd3fd..54d7c790f77d8 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2531,7 +2531,9 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); tester.AddAttribute("do_rotary", static_cast(1)); if (smooth_softmax) { - // smooth_softmax disqualifies WebGPU CanApplyFlashAttention -> routes to ApplyAttention. + // smooth_softmax disqualifies the WebGPU FlashAttention path via the outer + // gating in GroupQueryAttention::ComputeInternal, routing this case through + // ApplyAttention instead. tester.AddAttribute("smooth_softmax", static_cast(1)); } @@ -2738,6 +2740,23 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttention_WebG RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 33}); } +// Stress the FlashAttention prefill path with a per-batch spread that exceeds +// the indirect-dispatch tile size (64). batch 0 has the SHORTEST real length; +// batch 2 has the LONGEST. This is the data pattern that would surface the +// indirect-dispatch undersizing bug when graph capture is enabled (where the +// dispatch grid is sized from a GPU buffer rather than the host scalar). +// OpTester does not toggle graph capture, so this test exercises the new +// total_sequence_length_input shader plumbing on the non-graph-capture path; +// the graph-capture path is covered end-to-end by phi4-graph-prune verification. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // spread = 96 - 20 = 76 > tile_size(64), batch 0 is not the max. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 96}); +} + // Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with // smooth_softmax=1 so the WebGPU EP bypasses CanApplyFlashAttention and routes // through ApplyAttention (non-flash path). Covers right-padded batched prefill From 9c23874f6164f40e5b46ecfe6843881584d0579a Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 25 Jun 2026 16:13:46 +0800 Subject: [PATCH 7/7] webgpu: Drop unused use_indirect_dispatch from DecodeVxReduce FlashAttentionDecodeVxReduceProgram no longer branches on use_indirect_dispatch in its shader template (the per-batch iteration is gated by use_seqlen_k instead), so the field is dead and Clang rejects it as -Wunused-private-field on wasm and arm64 builds. Remove the parameter from the program ctor, the member field, the ComputeFlashAttentionDecodeVxReduce signature, and the CacheHint so identical shaders share one cached pipeline. --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 7 +++---- onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | 5 ++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 67417e8f4c7be..4e926c7efa597 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -405,7 +405,6 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch, const Tensor* head_sink, uint32_t m_tile, bool use_seqlen_k) { @@ -413,7 +412,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& constexpr int tile_size = 8; int tile_head_size = tile_size * components; bool has_head_sink = head_sink != nullptr; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile, use_seqlen_k}; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k}; program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); if (use_seqlen_k) { @@ -426,7 +425,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile, use_seqlen_k) + .CacheHint(tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, @@ -639,7 +638,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + num_present_sequence_length_tile, tile_size, head_sink, m_tile, use_seqlen_k)); return Status::OK(); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 77b787f55aa18..85ba61c1d20b5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -182,8 +182,8 @@ class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -199,7 +199,6 @@ class FlashAttentionDecodeVxReduceProgram final : public Program