diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index d9d299d4fd5d9..4e926c7efa597 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -55,6 +55,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); const auto& cos_cache = sh.AddInput("cos_cache", ShaderUsage::UseUniform); const auto& sin_cache = sh.AddInput("sin_cache", ShaderUsage::UseUniform); + if (prepare_indirect_dispatch_) { + sh.AddInput("total_sequence_length_input", ShaderUsage::None); + } const auto& query = sh.AddOutput("query", ShaderUsage::UseUniform); const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform); @@ -97,8 +100,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (use_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::None); } - // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output + // If prepare_indirect_dispatch is enabled, add total_sequence_length_input + // and indirect_buffer output. total_sequence_length_input is the global max + // total sequence length across the batch (from GQA input #6); using it for + // dispatch sizing covers right-padded batches where batch 0 is not the max. 
if (prepare_indirect_dispatch_) { + shader.AddInput("total_sequence_length_input", ShaderUsage::None); shader.AddOutput("indirect_buffer", ShaderUsage::None); } @@ -109,11 +116,12 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { " let num_head_id = output_indices[1];\n" " let batch = output_indices[0];\n"; if (use_seqlen_k_) { - shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[0u]) + 1u;\n"; + shader.MainFunctionBody() << " let total_seq_length = u32(seqlen_k[batch]) + 1u;\n"; } else { shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; } - shader.MainFunctionBody() << " let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; + // Right-padded batches with prompt shorter than kv_sequence_length would underflow u32; clamp to 0. + shader.MainFunctionBody() << " let past_sequence_length = select(total_seq_length - uniforms.kv_sequence_length, 0u, total_seq_length <= uniforms.kv_sequence_length);\n"; if (past_present_share_buffer_) { shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n"; } else { @@ -124,7 +132,8 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { if (prepare_indirect_dispatch_) { shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; shader.MainFunctionBody() << " if (global_idx == 0u) {\n" - << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" + << " let global_total_seq_length = u32(total_sequence_length_input[0]);\n" + << " let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -152,7 +161,8 @@ Status 
CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles, + const Tensor* total_seqlen) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. @@ -188,6 +198,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_past) { program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, @@ -262,9 +275,15 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } + if (use_indirect_dispatch_) { + // Global max total sequence length across batches (from GQA input #6). 
+ // Used in indirect-dispatch mode for the workgroup_idx slicing so that + // batch 0's per-batch length cannot undersize the dispatch grid. + shader.AddInput("total_sequence_length_input", ShaderUsage::None); + } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } @@ -282,6 +301,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), @@ -293,7 +313,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k, const Tensor* total_seqlen) { const float alpha = parameters.scale_ == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -303,13 +323,16 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool is_unidirectional = parameters.is_unidirectional_; - FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (use_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -320,10 +343,12 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } if (use_indirect_dispatch) { @@ -332,7 +357,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, head_size_vec, 
has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -343,6 +368,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, {attn_bias_dim1}, + {attn_bias_dim3}, {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); @@ -351,7 +377,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { + if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { @@ -364,7 +390,7 @@ Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& sha WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_VARIABLE(input, input), WGSL_TEMPLATE_VARIABLE(metadata, metadata), WGSL_TEMPLATE_VARIABLE(output, output)); @@ -379,17 +405,17 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch, const Tensor* head_sink, - uint32_t m_tile) { + uint32_t m_tile, + bool use_seqlen_k) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * 
components; bool has_head_sink = head_sink != nullptr; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k}; program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); - if (use_indirect_dispatch) { + if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_head_sink) { @@ -399,7 +425,7 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) + .CacheHint(tile_size, seq_tile_size, has_head_sink, m_tile, use_seqlen_k) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, @@ -415,7 +441,8 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value, const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k, - const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink) { + const Tensor* cos_cache, const Tensor* sin_cache, const Tensor* head_sink, + const Tensor* total_seqlen) { 
constexpr uint32_t tile_size = 64; // Create present_key and present_value tensors if they are nullptr. @@ -437,7 +464,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co present_value = &internal_present_value; } - const bool use_seqlen_k = seqlen_k != nullptr && context.IsGraphCaptureEnabled(); + // Read seqlens_k per batch_idx in the shader whenever seqlens_k is supplied. + // This covers both graph-capture (total_sequence_length_ is 0 on the host) and + // right-padded batches (batch_size > 1 with distinct per-batch totals), and lets + // batch=1 share the same path. When seqlens_k is null, kernels fall back to + // uniforms.total_sequence_length. + const bool use_seqlen_k = seqlen_k != nullptr; // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; @@ -453,8 +485,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Prepare indirect dispatch buffer for split-reduce path with static KV cache. // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. - const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && - seqlen_k != nullptr && + // Static KV cache (past_present_share_buffer_) is guaranteed by GQA's + // ORT_ENFORCE when graph capture is enabled. 
+ const bool use_indirect_dispatch = seqlen_k != nullptr && + total_seqlen != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions @@ -492,10 +526,11 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size, num_q_tiles)); + indirect_buffer_ptr, tile_size, num_q_tiles, + total_seqlen)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles, total_seqlen)); } // Extract present_sequence_length directly from present_key tensor shape @@ -555,10 +590,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; + uint32_t attn_bias_dim3 = 0; if (has_attention_bias) { const auto& bias_shape = attention_bias->Shape(); attn_bias_dim0 = static_cast(bias_shape[0]); attn_bias_dim1 = static_cast(bias_shape[1]); + attn_bias_dim3 = static_cast(bias_shape[3]); } program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) @@ -572,7 +609,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {alpha}, {num_seq_tile}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {attn_bias_dim3}}); return context.RunProgram(program); } @@ -596,27 +634,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co &metadata, seqlen_k, parameters, indirect_buffer_ptr, num_total_seq_length_tile, num_present_sequence_length_tile, 
tile_size, use_indirect_dispatch, - present_sequence_length, m_tile)); + present_sequence_length, m_tile, use_seqlen_k, total_seqlen)); ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - head_sink, m_tile)); + num_present_sequence_length_tile, tile_size, + head_sink, m_tile, use_seqlen_k)); return Status::OK(); } -bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) { - const bool kv_empty = parameters.kv_sequence_length_ == 0; - // FlashAttention here does not implement right-padded per-batch prefill, so the - // first disjunction restricts it to inputs where padding cannot occur: - // - batch_size_ == 1: single sequence, no padding possible. - // - seqlen_k == nullptr: no per-batch lengths, padding inexpressible. - // - kv_empty (shared-KV layer): FA is mandatory; that path takes a different shader. - // The remaining conjuncts exclude packed-QKV (handled by a separate rotary kernel), - // mismatched head/value sizes, and head_size alignments unsupported by the kernel. 
- return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) && - !parameters.is_packed_qkv_ && +bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } @@ -631,7 +660,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size, uint32_t num_q_tiles) { + uint32_t tile_size, uint32_t num_q_tiles, + const Tensor* total_seqlen) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -669,6 +699,9 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {cos_cache, ProgramTensorMetadataDependency::Rank, components}, {sin_cache, ProgramTensorMetadataDependency::Rank, components}, }); + if (prepare_indirect_dispatch) { + program.AddInput({total_seqlen, ProgramTensorMetadataDependency::None}); + } program.AddOutputs({{query, ProgramTensorMetadataDependency::None, components}, {present_key, ProgramTensorMetadataDependency::None, components}, {present_value, ProgramTensorMetadataDependency::None, components}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 218baf926173f..85ba61c1d20b5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -125,7 +125,8 @@ class FlashAttentionProgram final : public Program { {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}, {"attn_bias_dim0", ProgramUniformVariableDataType::Uint32}, - {"attn_bias_dim1", 
ProgramUniformVariableDataType::Uint32}); + {"attn_bias_dim1", ProgramUniformVariableDataType::Uint32}, + {"attn_bias_dim3", ProgramUniformVariableDataType::Uint32}); private: bool has_attention_bias_; @@ -148,8 +149,9 @@ class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -195,17 +199,18 @@ class FlashAttentionDecodeVxReduceProgram final : public Program u32 { - return u32(seqlens_k[0]) + 1u; +// When seqlens_k is provided, total_sequence_length is read per batch from the GPU buffer. +fn get_total_sequence_length(batch_idx: u32) -> u32 { + return u32(seqlens_k[batch_idx]) + 1u; } #else -// When graph capture is disabled, total_sequence_length comes from uniforms -fn get_total_sequence_length() -> u32 { +// Without seqlens_k, total_sequence_length comes from uniforms (max across batches). 
+fn get_total_sequence_length(batch_idx: u32) -> u32 { return uniforms.total_sequence_length; } #endif @@ -65,20 +65,18 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ var qk_scores : array; -fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; - let total_seq = get_total_sequence_length(); for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); v_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); @@ -95,15 +93,19 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { } #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> q_element_t { - if (k_idx_global >= get_total_sequence_length()) { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> q_element_t { + if (k_idx_global >= total_seq) { return q_element_t(0); } let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let 
offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); - return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + get_total_sequence_length())]); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; + return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + stride_total_seq - 1u)]); } #endif @@ -111,24 +113,24 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he // For max performance max_k_step should be the same as sg_size, however we might run out of registers // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). 
-fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) { +fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); k_tile[slot][idx % head_size_vec] = val; } } -fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) { +fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, total_seq : u32) { // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); - let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length()); + let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); v_tile[slot][idx % head_size_vec] = val; } } @@ -160,18 +162,22 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) { #endif #if has_attention_bias -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (k_idx_global >= get_total_sequence_length()) { + if (k_idx_global >= 
total_seq) { return vec4(0); } // Handle broadcasting: if dimension size is 1, use index 0 let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() + - bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length(); + // Stride along the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq; let offset = offset_base + k_idx_global; - let offset_max = offset_base + get_total_sequence_length(); + let offset_max = offset_base + stride_total_seq - 1u; let c1 = q_element_t(attention_bias[min(offset, offset_max)]); let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]); let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]); @@ -179,7 +185,7 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he return vec4(c1, c2, c3, c4); } #else -fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { +fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4 { return vec4(0); } #endif @@ -226,11 +232,14 @@ $MAIN { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; #endif - let total_sequence_length = get_total_sequence_length(); + let total_sequence_length = 
get_total_sequence_length(batch_idx); #if is_unidirectional // If attention is unidirectional, set the loop bound to enforce causal masking. - let past_sequence_length = total_sequence_length - uniforms.new_sequence_length; + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. + let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); let max_causal_len_for_workgroup = past_sequence_length + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup); @@ -244,8 +253,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx); - loadv(k_start, batch_head_idx, local_idx); + loadk(k_start, batch_head_idx, local_idx, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, total_sequence_length); workgroupBarrier(); for (var k = 0u; k < max_k_step; k++) { @@ -254,7 +263,7 @@ $MAIN { score += dot(q_tile[i], k_tile[k][i]); } #if has_attention_bias - score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx); + score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx, total_sequence_length); #endif qk_scores[k] = select(min_value, score, k_start + k < seq_causal_length); } @@ -302,8 +311,8 @@ $MAIN { for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) { workgroupBarrier(); - loadk(k_start, batch_head_idx, local_idx, capped_sg_size); - loadv(k_start, batch_head_idx, local_idx, capped_sg_size); + loadk(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); + loadv(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length); workgroupBarrier(); // Compute QKt @@ -361,11 +370,11 @@ $MAIN { qk_2[3] += dot(q_own, fetchKTile(7, i, k_local)); } } - qk_1 = qk_1 + 
loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx); - qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx); + qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx, total_sequence_length); + qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx, total_sequence_length); if (sg_size > 8) { - qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx); - qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx); + qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx, total_sequence_length); + qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx, total_sequence_length); } // Neuter qk values where K is out of bounds. diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template index 524a18ca43245..778e07fbf63ff 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -10,6 +10,7 @@ #param tile_size #param tile_size_k_vec #param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -34,18 +35,22 @@ var tile_max: array; var tile_sum: array; #if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + - bias_head_idx * uniforms.new_sequence_length * total_seq_length + - q_idx * total_seq_length + + // Stride along 
the last dim of attention_bias matches its actual shape, which may + // differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform + // to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly. + let stride_total_seq = uniforms.attn_bias_dim3; + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq + + bias_head_idx * uniforms.new_sequence_length * stride_total_seq + + q_idx * stride_total_seq + k_idx; return attention_bias[offset]; } #else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> q_element_t { return q_element_t(0); } @@ -54,12 +59,14 @@ var tile_sum: array; $MAIN { let local_row = u32(local_idx / tile_size_k_vec); let local_col = local_idx % tile_size_k_vec; + // total_sequence_length used for workgroup_idx slicing must match the host-side dispatch + // grid, i.e. the global maximum across batches. Per-batch total is derived separately below. #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + let global_total_sequence_length = u32(total_sequence_length_input[0]); #else - let total_sequence_length = uniforms.total_sequence_length; + let global_total_sequence_length = uniforms.total_sequence_length; #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_total_seq_length_tile = (global_total_sequence_length + tile_size - 1) / tile_size; let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; @@ -71,9 +78,28 @@ $MAIN { if (batch_idx >= uniforms.batch_size) { return; } + // Per-batch total_sequence_length used for K/V bounds, causal mask, and softmax range. 
+ #if use_seqlen_k + let total_sequence_length = u32(seqlens_k[batch_idx]) + 1u; + #else + let total_sequence_length = global_total_sequence_length; + #endif let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + // If this workgroup's tile lies entirely beyond this batch's per-batch total_sequence_length, + // write neutral metadata so VxReduce contributes nothing for these tiles, then exit early. + if (total_seq_offset >= total_sequence_length) { + if (local_idx == 0u) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx_local = q_base + m; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx_local) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(-3.4028234663852886e+38f, 0.0f)); + } + } + return; + } + // ============================================================ // Phase 1: QK^T computation // ============================================================ @@ -109,6 +135,12 @@ $MAIN { } // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask +#if is_unidirectional + // Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0. 
+ let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length, + 0u, + total_sequence_length <= uniforms.new_sequence_length); +#endif for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { let q_idx = q_base + m; if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { @@ -117,9 +149,9 @@ $MAIN { sum += inner_qk_values[m][local_idx][i]; } - sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx); #if is_unidirectional - if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + if (total_seq_offset + local_idx > past_sequence_length + q_idx) { sum = q_element_t(-65504.0f); } #endif diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index a3ce0b68cb659..628ad835a9d4c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -5,7 +5,7 @@ #param m_tile #param seq_tile_size #param tile_size -#param use_indirect_dispatch +#param use_seqlen_k #use .getByOffset .setByOffset @@ -32,8 +32,12 @@ $MAIN { } let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; + // Per-batch total_sequence_length: short batches contributed neutral metadata + // (-inf, 0) for tiles beyond their per-batch total, so reading only this batch's + // tiles ensures softmax rescaling is not skewed by garbage tiles. 
+ #if use_seqlen_k + let batch_idx_for_seqlen = batch_head_idx / uniforms.num_heads; + let total_sequence_length = u32(seqlens_k[batch_idx_for_seqlen]) + 1u; let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; #else let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 36d688c9723fd..24ace3487a4c5 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -327,6 +327,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& past_value->DataRaw() == present_value->DataRaw(); ORT_ENFORCE(parameters.total_sequence_length_ <= parameters.seqlen_present_kv_cache_, "Total sequence length cannot be greater than the existing KV cache length."); + ORT_ENFORCE(!context.IsGraphCaptureEnabled() || parameters.past_present_share_buffer_, + "Graph capture requires past/present KV cache to share the same buffer (static KV cache)."); Tensor qSplit; Tensor kSplit; @@ -350,7 +352,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking WebgpuAttentionParameters temp_params = parameters; temp_params.is_packed_qkv_ = false; - will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k); + will_use_flash_attention = CanApplyFlashAttention(temp_params, context); } if (kv_empty) { @@ -381,7 +383,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled // query points to packed QKV, K and V are nullptr since they're not needed return ApplyFlashAttention(query, nullptr, nullptr, attention_bias, output, 
past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink); + present_value, parameters, context, seqlen_k, cos_cache, sin_cache, head_sink, + total_seqlen_tensor); } // Fused: splitQKV + rotary QK qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); @@ -472,7 +475,8 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (will_use_flash_attention) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink); + present_value, parameters, context, seqlen_k, nullptr, nullptr, head_sink, + total_seqlen_tensor); } // Non-flash attention path does not support kv_sequence_length==0 (shared KV layers). diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 97c610fb90024..e3d92c036d2c1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -43,7 +43,8 @@ $MAIN { #if prepare_indirect_dispatch if (global_idx == 0u) { - let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; + let global_total_seq_length = u32(total_sequence_length_input[0]); + let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc 
index 645564d01abc0..54d7c790f77d8 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -2510,7 +2510,8 @@ static std::vector RunGQAPackedQKVRotaryPrefill( int head_size, const std::vector& seqlens_k_data, const std::vector& packed_qkv_data, - GqaTargetEp target_ep = GqaTargetEp::kCpu) { + GqaTargetEp target_ep = GqaTargetEp::kCpu, + bool smooth_softmax = false) { const int hidden_size = num_heads * head_size; const int kv_hidden_size = kv_num_heads * head_size; const int qkv_hidden = hidden_size + 2 * kv_hidden_size; @@ -2529,6 +2530,12 @@ static std::vector RunGQAPackedQKVRotaryPrefill( tester.AddAttribute("num_heads", static_cast(num_heads)); tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); tester.AddAttribute("do_rotary", static_cast(1)); + if (smooth_softmax) { + // smooth_softmax disqualifies the WebGPU FlashAttention path via the outer + // gating in GroupQueryAttention::ComputeInternal, routing this case through + // ApplyAttention instead. + tester.AddAttribute("smooth_softmax", static_cast(1)); + } // Packed QKV: pass through `query` input, leave key/value as optional edges. if (use_fp16) { @@ -2619,7 +2626,9 @@ static std::vector RunGQAPackedQKVRotaryPrefill( // output matches its single-prompt reference. Both reference and batched runs // go through the same EP, so this validates per-batch consistency within each // EP rather than cross-EP equivalence. 
-static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { +static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep, + const std::vector& real_lens = {4, 2, 6}, + bool smooth_softmax = false) { constexpr int batch_size = 3; constexpr int num_heads = 4; constexpr int kv_num_heads = 2; @@ -2628,10 +2637,10 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { constexpr int kv_hidden_size = kv_num_heads * head_size; constexpr int qkv_hidden = hidden_size + 2 * kv_hidden_size; - // Real prompt lengths per batch; max = sequence_length (right-padding extends + // Per-batch real prompt lengths; max = sequence_length (right-padding extends // shorter batches up to this length). The bug only manifests when at least // one batch is shorter than sequence_length. - const std::vector real_lens = {4, 2, 6}; + ASSERT_EQ(static_cast(real_lens.size()), batch_size); const int sequence_length = *std::max_element(real_lens.begin(), real_lens.end()); std::vector packed_batched; @@ -2652,7 +2661,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { /*batch_size=*/1, /*sequence_length=*/real_len, num_heads, kv_num_heads, head_size, /*seqlens_k_data=*/{static_cast(real_len - 1)}, - packed_single, target_ep); + packed_single, target_ep, smooth_softmax); } // Now run all batches together with right-padding. @@ -2662,7 +2671,7 @@ static void RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp target_ep) { } const auto batched_output = RunGQAPackedQKVRotaryPrefill( batch_size, sequence_length, num_heads, kv_num_heads, head_size, - seqlens_k_data, packed_batched, target_ep); + seqlens_k_data, packed_batched, target_ep, smooth_softmax); // Guard the regression deterministically: every element of the batched output // (including padding rows) must be finite. 
The CPU root cause is uninitialized @@ -2715,5 +2724,50 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefill_WebGPU) { RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu); } +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with per-batch +// real_lens whose max crosses the prefill threshold (sequence_length >= 32) so +// the WebGPU EP picks FlashAttentionProgram (single-kernel prefill path with +// subgroup shuffles) instead of the split-reduce decode path. This exercises +// the prefill flash-attention kernel under right-padded batches with do_rotary, +// which is the path used by Phi-style models during batched prefill. +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // sequence_length = max(real_lens) = 33 > 32 -> FlashAttentionProgram path. + // Mixed shorter batches (12, 20) ensure right-padding is non-trivial. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 33}); +} + +// Stress the FlashAttention prefill path with a per-batch spread that exceeds +// the indirect-dispatch tile size (64). batch 0 has a SHORT real length (not the max); +// batch 2 has the LONGEST. This is the data pattern that would surface the +// indirect-dispatch undersizing bug when graph capture is enabled (where the +// dispatch grid is sized from a GPU buffer rather than the host scalar). +// OpTester does not toggle graph capture, so this test exercises the new +// total_sequence_length_input shader plumbing on the non-graph-capture path; +// the graph-capture path is covered end-to-end by phi4-graph-prune verification. 
+TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillFlashAttentionLargeSpread_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + // spread = 96 - 20 = 76 > tile_size(64), batch 0 is not the max. + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {20, 12, 96}); +} + +// Same property as BatchedRightPaddedRotaryPrefill_WebGPU, but with +// smooth_softmax=1 so the WebGPU EP bypasses CanApplyFlashAttention and routes +// through ApplyAttention (non-flash path). Covers right-padded batched prefill +// on the non-flash attention path (used by e.g. Phi-4 attention variants). +TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_WebGPU) { + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + if (!webgpu_ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {4, 2, 6}, /*smooth_softmax=*/true); +} + } // namespace test } // namespace onnxruntime