Skip to content
111 changes: 72 additions & 39 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Large diffs are not rendered by default.

24 changes: 15 additions & 9 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@
{"alpha", ProgramUniformVariableDataType::Float32},
{"num_seq_tile", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});
{"attn_bias_dim1", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim3", ProgramUniformVariableDataType::Uint32});

private:
bool has_attention_bias_;
Expand All @@ -148,8 +149,9 @@
bool has_attention_bias, uint32_t tile_size, int head_size_vec,
bool use_indirect_dispatch, bool q_BNSH = false,
bool is_unidirectional = false,
uint32_t m_tile = 1)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), m_tile_(m_tile) {
uint32_t m_tile = 1,
bool use_seqlen_k = false)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -164,6 +166,7 @@
{"batch_size", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim1", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim3", ProgramUniformVariableDataType::Uint32},
{"new_sequence_length", ProgramUniformVariableDataType::Uint32});

private:
Expand All @@ -174,12 +177,13 @@
bool q_BNSH_;
bool is_unidirectional_;
uint32_t m_tile_;
bool use_seqlen_k_;
};

class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
public:
FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1)
: Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) {
FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool has_head_sink = false, uint32_t m_tile = 1, bool use_seqlen_k = false)

Check warning on line 185 in onnxruntime/contrib_ops/webgpu/bert/flash_attention.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/bert/flash_attention.h:185: Add #include <string> for string [build/include_what_you_use] [4]
: Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), has_head_sink_(has_head_sink), m_tile_(m_tile), use_seqlen_k_(use_seqlen_k) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -195,17 +199,18 @@
private:
uint32_t tile_size_;
uint32_t seq_tile_size_;
bool use_indirect_dispatch_;
bool has_head_sink_;
uint32_t m_tile_;
bool use_seqlen_k_;
};

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr,
const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr);
const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr,
const Tensor* total_seqlen = nullptr);

bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);
bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);

// Split packed QKV with Q/K rotary embedding and copy KV cache fusion
Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
Expand All @@ -218,7 +223,8 @@
Tensor* present_key,
Tensor* present_value,
Tensor* indirect_buffer,
uint32_t tile_size, uint32_t num_q_tiles);
uint32_t tile_size, uint32_t num_q_tiles,
const Tensor* total_seqlen = nullptr);
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
79 changes: 44 additions & 35 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ const head_size : u32 = qkv_head_size;
const num_heads : u32 = qkv_num_heads;

#if use_seqlen_k
// When graph capture is enabled, total_sequence_length is read from GPU buffer
fn get_total_sequence_length() -> u32 {
return u32(seqlens_k[0]) + 1u;
// When seqlens_k is provided, total_sequence_length is read per batch from the GPU buffer.
// The stored entry is converted to u32 and incremented by one to form the returned length.
fn get_total_sequence_length(batch_idx: u32) -> u32 {
return u32(seqlens_k[batch_idx]) + 1u;
}
#else
// When graph capture is disabled, total_sequence_length comes from uniforms
fn get_total_sequence_length() -> u32 {
// Without seqlens_k, total_sequence_length comes from uniforms (max across batches).
// batch_idx is not read in this variant; it is accepted so that both variants
// are invoked with the same argument list.
fn get_total_sequence_length(batch_idx: u32) -> u32 {
return uniforms.total_sequence_length;
}
#endif
Expand Down Expand Up @@ -65,20 +65,18 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_

var<private> qk_scores : array<q_element_t, max_k_step>;

// Fills k_tile with max_k_step rows of present_key, starting at row k_start.
// Each invocation begins at element local_idx and advances by workgroup_size_x,
// so the head_size_vec * max_k_step elements of the tile are split across the workgroup.
//   k_start        - first key row covered by this tile
//   batch_head_idx - combined batch/head index; divided by uniforms.n_reps before
//                    addressing present_key (presumably to share KV heads across
//                    query heads - confirm against the host code)
//   local_idx      - starting element for this invocation
//   total_seq      - rows with index >= total_seq are stored as zero
fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32) {
fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) {
let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
k_start * head_size_vec;
let total_seq = get_total_sequence_length();
for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) {
// slot is the row inside the tile; idx % head_size_vec is the column within that row.
let slot = u32(idx / head_size_vec);
k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq);
}
}

fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32) {
fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) {
let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
v_start * head_size_vec;
let total_seq = get_total_sequence_length();
for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) {
let slot = u32(idx / head_size_vec);
v_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq);
Expand All @@ -95,40 +93,44 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) {
}

#if has_attention_bias
// Returns the attention_bias element for (batch_idx, head_idx, q_idx_global, k_idx_global).
// Yields zero when k_idx_global is at or beyond the valid key length.
fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> q_element_t {
if (k_idx_global >= get_total_sequence_length()) {
fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> q_element_t {
if (k_idx_global >= total_seq) {
return q_element_t(0);
}
// Broadcasting: a batch/head index at or beyond the corresponding bias dimension maps to index 0.
let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() +
bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length();
return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + get_total_sequence_length())]);
// Stride along the last dim of attention_bias matches its actual shape, which may
// differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform
// to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly.
let stride_total_seq = uniforms.attn_bias_dim3;
// offset_base is the flat index of the first element of the addressed bias row.
let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq +
bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq;
// The read index is clamped so it never exceeds the last element of that row.
return q_element_t(attention_bias[min(offset_base + k_idx_global, offset_base + stride_total_seq - 1u)]);
}
#endif

#else
// For max performance max_k_step should be the same as sg_size, however we might run out of registers
// for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16).

// Fills k_tile with k_step rows of present_key, starting at row k_start.
// Each invocation begins at element local_idx and advances by workgroup_size_x.
// Rows whose index (k_start + slot) is at or beyond the valid key length are stored as zero.
fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32) {
fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, total_seq : u32) {
// Stored as float16[batch_size,num_heads,present_sequence_length,head_size]
let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
k_start * head_size_vec;
for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) {
// slot is the row inside the tile; idx % head_size_vec is the column within that row.
let slot = u32(idx / head_size_vec);
let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < get_total_sequence_length());
let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq);
k_tile[slot][idx % head_size_vec] = val;
}
}

// Fills v_tile with v_step rows of present_value, starting at row v_start.
// Each invocation begins at element local_idx and advances by workgroup_size_x.
// Rows whose index (v_start + slot) is at or beyond the valid length are stored as zero.
fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32) {
fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, total_seq : u32) {
// Stored as float16[batch_size,num_heads,present_sequence_length,head_size]
let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec +
v_start * head_size_vec;
for (var idx : u32 = local_idx; idx < head_size_vec * v_step; idx += workgroup_size_x) {
// slot is the row inside the tile; idx % head_size_vec is the column within that row.
let slot = u32(idx / head_size_vec);
let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < get_total_sequence_length());
let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq);
v_tile[slot][idx % head_size_vec] = val;
}
}
Expand Down Expand Up @@ -160,26 +162,30 @@ fn writeo(batch_idx : u32, o_idx_global : u32, head_idx : u32) {
#endif

#if has_attention_bias
// Returns four consecutive attention_bias elements for query row q_idx_global,
// starting at key column k_idx_global. Yields a zero vector when k_idx_global is
// at or beyond the valid key length.
fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4<q_element_t> {
// Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
if (k_idx_global >= get_total_sequence_length()) {
if (k_idx_global >= total_seq) {
return vec4<q_element_t>(0);
}
// Handle broadcasting: if dimension size is 1, use index 0
let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * get_total_sequence_length() +
bias_head_idx * uniforms.new_sequence_length * get_total_sequence_length() + q_idx_global * get_total_sequence_length();
// Stride along the last dim of attention_bias matches its actual shape, which may
// differ from the per-step total_sequence_length (e.g. graph-capture sets that uniform
// to 0, or batches use a global max). The host passes attn_bias_dim3 explicitly.
let stride_total_seq = uniforms.attn_bias_dim3;
// offset_base is the flat index of the first element of the addressed bias row.
let offset_base = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * stride_total_seq +
bias_head_idx * uniforms.new_sequence_length * stride_total_seq + q_idx_global * stride_total_seq;
let offset = offset_base + k_idx_global;
let offset_max = offset_base + get_total_sequence_length();
let offset_max = offset_base + stride_total_seq - 1u;
// Each of the four reads is clamped to offset_max, so lanes past the end of the
// row repeat the element at offset_max instead of indexing further.
let c1 = q_element_t(attention_bias[min(offset, offset_max)]);
let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]);
let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]);
let c4 = q_element_t(attention_bias[min(offset + 3, offset_max)]);
return vec4<q_element_t>(c1, c2, c3, c4);
}
#else
// No attention bias is bound in this configuration: every argument is ignored and
// a zero vector is returned, so call sites can add the result unconditionally.
fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4<q_element_t> {
fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, head_idx : u32, total_seq : u32) -> vec4<q_element_t> {
return vec4<q_element_t>(0);
}
#endif
Expand Down Expand Up @@ -226,11 +232,14 @@ $MAIN {
var previous_max : q_element_t = min_value;
var previous_denom : q_element_t = 0;
#endif
let total_sequence_length = get_total_sequence_length();
let total_sequence_length = get_total_sequence_length(batch_idx);

#if is_unidirectional
// If attention is unidirectional, set the loop bound to enforce causal masking.
let past_sequence_length = total_sequence_length - uniforms.new_sequence_length;
// Right-padded batches with prompt shorter than new_sequence_length would underflow u32; clamp to 0.
let past_sequence_length = select(total_sequence_length - uniforms.new_sequence_length,
0u,
total_sequence_length <= uniforms.new_sequence_length);
let max_causal_len_for_workgroup = past_sequence_length +
(workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x;
let loop_bound = min(total_sequence_length, max_causal_len_for_workgroup);
Expand All @@ -244,8 +253,8 @@ $MAIN {

for (var k_start = 0u; k_start < loop_bound; k_start += max_k_step) {
workgroupBarrier();
loadk(k_start, batch_head_idx, local_idx);
loadv(k_start, batch_head_idx, local_idx);
loadk(k_start, batch_head_idx, local_idx, total_sequence_length);
loadv(k_start, batch_head_idx, local_idx, total_sequence_length);
workgroupBarrier();

for (var k = 0u; k < max_k_step; k++) {
Expand All @@ -254,7 +263,7 @@ $MAIN {
score += dot(q_tile[i], k_tile[k][i]);
}
#if has_attention_bias
score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx);
score += loadAttentionBias(batch_idx, q_idx_global, k_start + k, head_idx, total_sequence_length);
#endif
qk_scores[k] = select(min_value, score, k_start + k < seq_causal_length);
}
Expand Down Expand Up @@ -302,8 +311,8 @@ $MAIN {

for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) {
workgroupBarrier();
loadk(k_start, batch_head_idx, local_idx, capped_sg_size);
loadv(k_start, batch_head_idx, local_idx, capped_sg_size);
loadk(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length);
loadv(k_start, batch_head_idx, local_idx, capped_sg_size, total_sequence_length);
workgroupBarrier();

// Compute QKt
Expand Down Expand Up @@ -361,11 +370,11 @@ $MAIN {
qk_2[3] += dot(q_own, fetchKTile(7, i, k_local));
}
}
qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx);
qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx);
qk_1 = qk_1 + loadAttentionBias(batch_idx, q_idx_global, k_start, head_idx, total_sequence_length);
qk_2 = qk_2 + loadAttentionBias(batch_idx, q_idx_global, k_start + 4, head_idx, total_sequence_length);
if (sg_size > 8) {
qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx);
qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx);
qk_3 = qk_3 + loadAttentionBias(batch_idx, q_idx_global, k_start + 8, head_idx, total_sequence_length);
qk_4 = qk_4 + loadAttentionBias(batch_idx, q_idx_global, k_start + 12, head_idx, total_sequence_length);
}

// Neuter qk values where K is out of bounds.
Expand Down
Loading
Loading