From 1380f413546494da89f03ef576c11db251f23e19 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 26 Jun 2026 14:08:00 -0700 Subject: [PATCH 1/3] Turbo Quant: WebGPU KV cache quantization (squashed) --- .../cpu/bert/group_query_attention_helper.h | 58 +- .../webgpu/bert/flash_attention.cc | 260 ++++-- .../contrib_ops/webgpu/bert/flash_attention.h | 17 +- .../webgpu/bert/flash_attention.wgsl.template | 103 +++ .../flash_attention_decode_qkv.wgsl.template | 134 ++- .../webgpu/bert/group_query_attention.cc | 41 +- .../webgpu/bert/hadamard_transform.cc | 60 ++ .../webgpu/bert/hadamard_transform.h | 62 ++ .../bert/hadamard_transform.wgsl.template | 67 ++ .../hadamard_transform_common.wgsl.template | 100 +++ .../indirect_dispatch_common.wgsl.template | 28 + ..._rotary_embedding_and_copykv.wgsl.template | 4 + .../bert/turbo_quant_common.wgsl.template | 33 + .../bert/turbo_quant_dequant.wgsl.template | 19 + ..._quant_fused_rotary_hadamard.wgsl.template | 237 ++++++ .../webgpu/bert/turbo_quant_hadamard.cc | 278 +++++++ .../webgpu/bert/turbo_quant_hadamard.h | 155 ++++ .../bert/turbo_quant_hadamard.wgsl.template | 197 +++++ .../core/providers/webgpu/compute_context.h | 14 + .../webgpu/webgpu_execution_provider.cc | 1 + .../webgpu/webgpu_execution_provider.h | 4 + .../webgpu/webgpu_provider_factory.cc | 11 + .../webgpu/webgpu_provider_options.h | 7 + .../group_query_attention_op_test.cc | 776 ++++++++++++++++++ 24 files changed, 2557 insertions(+), 109 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h create mode 100644 onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/hadamard_transform_common.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/indirect_dispatch_common.wgsl.template create mode 100644 
onnxruntime/contrib_ops/webgpu/bert/turbo_quant_common.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/turbo_quant_dequant.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/turbo_quant_fused_rotary_hadamard.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h create mode 100644 onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.wgsl.template diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 3429ca5f5be52..c52ba3a686797 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -97,7 +97,7 @@ Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const template Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, int kv_cache_bit_width, - int& past_sequence_length) { + int& past_sequence_length, int kv_cache_extra_bits = 0) { const auto& past_key_dims = past_key->Shape().GetDims(); const auto& past_value_dims = past_value->Shape().GetDims(); @@ -140,17 +140,25 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_ // We assume all sequence in past kv are right-padded to max or past sequence length past_sequence_length = static_cast(past_key_dims[2]); - // For 4-bit quantized KV cache, actual dimension is head_size / 2 because 2 nibbles are packed into one byte. - // Note that we have checked that head_size is a multiple of 8 in Check_QKV. - int packed_head_size = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size; + // Compute expected KV cache head dimension from quantization parameters. + // kv_cache_bit_width: bits per element (4 or 8). 0 means no quantization. 
+ // kv_cache_extra_bits: additional metadata bits per head + // (e.g., 32bits for TurboQuant storing scale). + int packed_head_size; + if (kv_cache_bit_width == 0) { + packed_head_size = head_size; + } else { + int bits_per_element = static_cast(past_key->DataType()->Size()) * 8; + packed_head_size = (head_size * kv_cache_bit_width + kv_cache_extra_bits) / bits_per_element; + } if (past_key_dims[3] != packed_head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", + "Input 'past_key' dimension 3 should match the packed KV head dimension, got ", past_key_dims[3], " expected ", packed_head_size); } if (past_value_dims[3] != packed_head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", + "Input 'past_value' dimension 3 should match the packed KV head dimension, got ", past_value_dims[3], " expected ", packed_head_size); } return Status::OK(); @@ -206,7 +214,12 @@ Status CheckInputs(const T* query, const T* total_seqlen, float scale, float softcap, - int kv_cache_bit_width) { + int kv_cache_bit_width, + int max_threads_per_block = 0, + int kv_cache_extra_bits = 0) { + if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); + } // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache // past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr // past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr @@ -246,10 +259,15 @@ Status CheckInputs(const T* query, kv_sequence_length = sequence_length; } + if (kv_cache_extra_bits != 0 && kv_cache_bit_width == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_cache_extra_bits requires kv_cache_bit_width to be non-zero."); + } + // Check past-present KV int32_t past_sequence_length = 0; if (past_key != 
nullptr && past_value != nullptr) { - ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, kv_cache_bit_width, past_sequence_length)); + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, kv_cache_bit_width, past_sequence_length, kv_cache_extra_bits)); // When past KV exists, Q and K/V must have the same sequence length, // UNLESS kv_sequence_length is 0 (shared KV: new K/V are empty, past buffer // already contains the full shared KV cache — no append needed). @@ -377,30 +395,6 @@ Status CheckInputs(const T* query, return Status::OK(); } -template -Status CheckInputs(const T* query, - const T* key, - const T* value, - const T* past_key, - const T* past_value, - const T* cos_cache, - const T* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const T* seqlens_k, - const T* total_seqlen, - float scale, - float softcap, - int kv_cache_bit_width, - int max_threads_per_block) { - if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); - } - - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width); -} - template Status CheckCustomAttentionInputs(const T* position_ids, const T* attention_bias, diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 4e926c7efa597..45f1af8939a2c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -3,6 +3,8 @@ #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/webgpu/bert/flash_attention.h" +#include "contrib_ops/webgpu/bert/hadamard_transform.h" +#include "contrib_ops/webgpu/bert/turbo_quant_hadamard.h" #include 
"contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -26,8 +28,8 @@ namespace webgpu { // workgroup_id (x, y, z) into a single linear workgroup_idx. // Caller contract: must register a storage output named exactly // `indirect_buffer` of array with at least 3 elements. -constexpr const char kNormalizeDispatchGroupSizeFn[] = R"( -fn normalize_dispatch_group_size(x: u32, y: u32, z: u32) { +constexpr const char kPopulateIndirectDispatchBufferFn[] = R"( +fn populate_indirect_dispatch_buffer(x: u32, y: u32, z: u32) { let limit = 65535u; // WebGPU spec maxComputeWorkgroupsPerDimension if (x <= limit && y <= limit && z <= limit) { indirect_buffer[0] = x; @@ -65,7 +67,6 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha if (prepare_indirect_dispatch_) { sh.AddOutput("indirect_buffer", ShaderUsage::None); - sh.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; } return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", @@ -130,7 +131,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Add indirect dispatch logic for thread 0 if (prepare_indirect_dispatch_) { - shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; + shader.AdditionalImplementation() << kPopulateIndirectDispatchBufferFn; shader.MainFunctionBody() << " if (global_idx == 0u) {\n" << " let global_total_seq_length = u32(total_sequence_length_input[0]);\n" << " let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" @@ -257,6 +258,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("output", ShaderUsage::UseUniform); return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template", + WGSL_TEMPLATE_PARAMETER(compressed_head_size_u32, compressed_head_size_u32_), WGSL_TEMPLATE_PARAMETER(has_attention_bias, 
has_attention_bias_), WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_), @@ -267,13 +269,14 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_), WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_), + WGSL_TEMPLATE_PARAMETER(turbo_quant, turbo_quant_), WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_)); } Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); if (use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); @@ -293,6 +296,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const uint32_t tile_size_k_vec = 8; const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template", + WGSL_TEMPLATE_PARAMETER(compressed_head_size_u32, compressed_head_size_u32_), WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), @@ -300,6 +304,7 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), WGSL_TEMPLATE_PARAMETER(tile_size, 
tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), + WGSL_TEMPLATE_PARAMETER(turbo_quant, turbo_quant_), WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), @@ -313,20 +318,24 @@ Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k, const Tensor* total_seqlen) { + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile, bool use_seqlen_k, const Tensor* total_seqlen, + bool turbo_quant, int compressed_head_size_u32) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; const bool has_attention_bias = attention_bias != nullptr; const int components = 4; + // TurboQuant changes view of kv cache from fp16/fp32 to packed u32. + // It already packs 4 float values into a single u32, so KV cache tensors use 1 component. + const int kv_cache_components = turbo_quant ? 
1 : components; const int head_size_vec = parameters.v_head_size_ / components; bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; bool is_unidirectional = parameters.is_unidirectional_; - FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k}; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k, turbo_quant, compressed_head_size_u32}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, - {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + {present_key, ProgramTensorMetadataDependency::TypeAndRank, kv_cache_components}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, kv_cache_components}}); if (use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } @@ -357,7 +366,7 @@ Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& conte program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile, use_seqlen_k, turbo_quant, compressed_head_size_u32) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -445,6 +454,24 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const Tensor* total_seqlen) { constexpr uint32_t tile_size = 64; + 
const bool turbo_quant_enabled = context.KvCacheQuantizationEnabled(); + if (turbo_quant_enabled && (parameters.head_size_ < 8 || (parameters.head_size_ & (parameters.head_size_ - 1)) != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV cache quantization requires head_size >= 8 and a power of 2. Got head_size=", + parameters.head_size_); + } + + // Compressed head dimension, expressed in two units: + // compressed_head_size_u32 — u32 words per head (1 scale + head_size/8 packed 4-bit indices), + // passed to the shaders as the packed KV dimension. + // present_last_dim — the same span counted in Q elements (fp16/fp32), used to size an + // internally-allocated present buffer so its u32 view lines up + // (compressed_head_size_u32 * 4 bytes == present_last_dim * sizeof(Q elem)). + const int compressed_head_size_u32 = turbo_quant_enabled ? (parameters.head_size_ / 8 + 1) : 0; + const int64_t present_last_dim = + turbo_quant_enabled + ? static_cast(compressed_head_size_u32) * 4 / static_cast(Q->DataType()->Size()) + : parameters.head_size_; // Create present_key and present_value tensors if they are nullptr. // Skip allocation for kv_empty — present will be aliased to past below. 
Tensor internal_present_key; @@ -453,13 +480,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const bool kv_empty = (parameters.kv_sequence_length_ == 0); if (!kv_empty && present_key == nullptr) { TensorShapeVector present_kv_shape({parameters.batch_size_, present_kv_heads, - parameters.total_sequence_length_, parameters.head_size_}); + parameters.total_sequence_length_, present_last_dim}); internal_present_key = context.CreateGPUTensor(Q->DataType(), TensorShape(present_kv_shape)); present_key = &internal_present_key; } if (!kv_empty && present_value == nullptr) { TensorShapeVector present_kv_shape({parameters.batch_size_, present_kv_heads, - parameters.total_sequence_length_, parameters.head_size_}); + parameters.total_sequence_length_, present_last_dim}); internal_present_value = context.CreateGPUTensor(Q->DataType(), TensorShape(present_kv_shape)); present_value = &internal_present_value; } @@ -473,6 +500,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; + // Declare rotated_q at function scope so the pointer remains valid + Tensor rotated_q; // Compute m_tile early so it can be passed to CopyKVCache for indirect dispatch. const uint32_t m_tile = parameters.sequence_length_ >= 4 ? 4u : (parameters.sequence_length_ >= 2 ? 2u : 1u); @@ -498,13 +527,16 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr); + // kv_empty (kv_sequence_length_ == 0) occurs in KV-shared / cross-layer KV reuse layers: the + // layer computes its own Q but borrows another layer's already-populated KV cache instead of + // producing new K/V. There is nothing to copy, so CopyKVCache is skipped and attention reads + // the past buffers directly. 
Because no new KV is written, present buffers are intentionally + // not allocated above and some call sites pass nullptr present outputs — so we alias past as + // present here. if (kv_empty) { - // kv_sequence_length==0: K/V inputs are empty (shared KV layer). - // Skip CopyKVCache and fused split+rotary+copyKV. - // Use past_key/past_value directly as the present buffers for attention. - // Note: do_rotary is always false here because GQA passes cos_cache=nullptr, sin_cache=nullptr - // for kv_empty layers (rotary is applied to Q separately in GQA before calling ApplyFlashAttention). - ORT_ENFORCE(!do_rotary, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV should not be used with kv_sequence_length==0."); + // do_rotary must be false here: GQA passes cos_cache=nullptr, sin_cache=nullptr for kv_empty + // layers (rotary is applied to Q separately in GQA before calling ApplyFlashAttention). + ORT_ENFORCE(!do_rotary, "kv_empty (kv_sequence_length==0) is incompatible with fused rotary+copyKV."); ORT_ENFORCE(past_key != nullptr && past_value != nullptr, "kv_empty path requires past KV context (KV-shared layers reuse another layer's cache)."); // When past_present_share_buffer_ is true (MayInplace optimization), present already @@ -515,22 +547,95 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co present_key = const_cast(past_key); present_value = const_cast(past_value); } - } else if (do_rotary) { - ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input."); - ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache."); - - // Q points to the packed QKV tensor in this case, create query output tensor - query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); - - 
ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters, - Q, seqlen_k, - cos_cache, sin_cache, - &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size, num_q_tiles, - total_seqlen)); - Q = &query_output; - } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles, total_seqlen)); + } + + // When TurboQuant is active, create u32 tensor views over present/past KV cache buffers. + Tensor present_key_u32, present_value_u32; + Tensor past_key_u32, past_value_u32; + Tensor* tq_present_key = present_key; + Tensor* tq_present_value = present_value; + const Tensor* tq_past_key = past_key; + const Tensor* tq_past_value = past_value; + if (turbo_quant_enabled) { + const int64_t bytes_per_elem = static_cast(present_key->DataType()->Size()); + const int64_t expected_last_dim_bytes = static_cast(compressed_head_size_u32) * 4; + ORT_RETURN_IF_ERROR( + (present_key->Shape().NumDimensions() == 4 && present_value->Shape().NumDimensions() == 4) + ? Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "TurboQuant expects present_key/present_value to be 4-D tensors.")); + ORT_RETURN_IF_ERROR( + (present_key->Shape()[3] * bytes_per_elem == expected_last_dim_bytes) + ? Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "TurboQuant KV cache shape mismatch for present_key. Expected last_dim_bytes==", + expected_last_dim_bytes, ", got shape=", present_key->Shape().ToString())); + ORT_RETURN_IF_ERROR( + (present_value->Shape()[3] * bytes_per_elem == expected_last_dim_bytes) + ? Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "TurboQuant KV cache shape mismatch for present_value. 
Expected last_dim_bytes==", + expected_last_dim_bytes, ", got shape=", present_value->Shape().ToString())); + + TensorShapeVector u32_present_shape({present_key->Shape()[0], present_key->Shape()[1], + present_key->Shape()[2], + static_cast(compressed_head_size_u32)}); + present_key_u32 = Tensor(DataTypeImpl::GetType(), TensorShape(u32_present_shape), + present_key->MutableDataRaw(), present_key->Location()); + present_value_u32 = Tensor(DataTypeImpl::GetType(), TensorShape(u32_present_shape), + present_value->MutableDataRaw(), present_value->Location()); + tq_present_key = &present_key_u32; + tq_present_value = &present_value_u32; + + if (past_key != nullptr && past_key->SizeInBytes() > 0) { + TensorShapeVector u32_past_shape({past_key->Shape()[0], past_key->Shape()[1], + past_key->Shape()[2], + static_cast(compressed_head_size_u32)}); + // past_key_u32 / past_value_u32 are read-only aliases over the past KV cache buffers. + // The Tensor ctor takes a non-const data pointer, so const_cast is required here, but the + // flash attention kernels only read through tq_past_key / tq_past_value — never write. + past_key_u32 = Tensor(DataTypeImpl::GetType(), TensorShape(u32_past_shape), + const_cast(past_key->DataRaw()), past_key->Location()); + past_value_u32 = Tensor(DataTypeImpl::GetType(), TensorShape(u32_past_shape), + const_cast(past_value->DataRaw()), past_value->Location()); + tq_past_key = &past_key_u32; + tq_past_value = &past_value_u32; + } + } + + // K/V copy is skipped for kv_empty (see the aliasing block above for why). 
+ if (!kv_empty) { + if (do_rotary) { + ORT_ENFORCE(parameters.is_packed_qkv_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires packed QKV input."); + ORT_ENFORCE(parameters.past_present_share_buffer_, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV requires static KV cache."); + + // Q points to the packed QKV tensor in this case, create query output tensor + query_output = context.CreateGPUTensor(Q->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); + + if (turbo_quant_enabled) { + ORT_RETURN_IF_ERROR(TurboQuantApplyRotaryAndCopyToQuantizedKVCache(context, parameters, + Q, seqlen_k, + cos_cache, sin_cache, + &query_output, tq_present_key, tq_present_value, + indirect_buffer_ptr, tile_size, num_q_tiles)); + } else { + ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters, + Q, seqlen_k, + cos_cache, sin_cache, + &query_output, present_key, present_value, + indirect_buffer_ptr, tile_size, num_q_tiles, + total_seqlen)); + } + Q = &query_output; + } else if (turbo_quant_enabled) { + // TurboQuant without rotary: K/V must be non-null (kv_empty already handled above). + ORT_ENFORCE(K != nullptr && V != nullptr, + "TurboQuant requires non-null K/V inputs when kv_sequence_length > 0."); + ORT_RETURN_IF_ERROR(TurboQuantCopyToQuantizedKVCache(context, parameters, K, tq_past_key, tq_present_key, V, tq_past_value, tq_present_value, + tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); + } else { + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? 
seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles, total_seqlen)); + } } // Extract present_sequence_length directly from present_key tensor shape @@ -538,12 +643,28 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]); + // Rotate Q before attention (Hadamard transform for TurboQuant). + if (turbo_quant_enabled) { + rotated_q = context.CreateGPUTensor(Q->DataType(), Q->Shape()); + ORT_RETURN_IF_ERROR(ApplyHadamardTransform(context, Q, &rotated_q, parameters.head_size_)); + Q = &rotated_q; + } + + // When TurboQuant is active, write attention output to a temp buffer, then + // inverse-Hadamard from temp -> final output. + Tensor attn_output_temp; + Tensor* attn_output = output; + if (turbo_quant_enabled) { + attn_output_temp = context.CreateGPUTensor(output->DataType(), output->Shape()); + attn_output = &attn_output_temp; + } + // Route between prefill path (FlashAttentionProgram, single kernel) // and split-reduce decode path (QKV + VxReduce, 2 kernels). // Split-reduce wins for short Q (sequence_length < 32) across all KV // cache lengths measured: 1.13x-2.07x faster at total_sequence_length // 128 / 500 / 2000 on a representative LLM (32 heads, head_size 96). - const bool use_split_reduce = parameters.sequence_length_ < 32; + const bool use_split_reduce = (parameters.sequence_length_ < 32); if (!use_split_reduce) { // Prefill path: FlashAttentionProgram (single kernel with subgroup shuffles) @@ -567,10 +688,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co has_subgroups, q_BNSH, use_seqlen_k, - has_head_sink}; + has_head_sink, + turbo_quant_enabled, + compressed_head_size_u32}; + // When TQ is active, KV cache is u32-packed — use u32 tensor views for present_key/present_value. + const Tensor* fa_present_key = turbo_quant_enabled ? 
tq_present_key : present_key; + const Tensor* fa_present_value = turbo_quant_enabled ? tq_present_value : present_value; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, - {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, - {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}}); + {fa_present_key, ProgramTensorMetadataDependency::TypeAndRank, turbo_quant_enabled ? 1 : 4}, + {fa_present_value, ProgramTensorMetadataDependency::TypeAndRank, turbo_quant_enabled ? 1 : 4}}); if (has_attention_bias) { program.AddInputs({{attention_bias, ProgramTensorMetadataDependency::TypeAndRank}}); } @@ -580,7 +706,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co if (has_head_sink) { program.AddInputs({{head_sink, ProgramTensorMetadataDependency::Type}}); } - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, 4}}); + program.AddOutputs({{attn_output, ProgramTensorMetadataDependency::TypeAndRank, 4}}); const float alpha = parameters.scale_ == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; @@ -600,7 +726,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(prefill_tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, has_subgroups, q_BNSH, use_seqlen_k, has_head_sink, program.max_k_step()) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm, is_nvidia, is_apple, has_subgroups, q_BNSH, use_seqlen_k, has_head_sink, turbo_quant_enabled, compressed_head_size_u32, program.max_k_step()) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, @@ -612,34 +738,42 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {attn_bias_dim1}, {attn_bias_dim3}}); - return context.RunProgram(program); - } + ORT_RETURN_IF_ERROR(context.RunProgram(program)); + } else { + // Split-reduce path (fused QKV + VxReduce). Handles both TQ and non-TQ. 
+ const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; + const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; + + const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, num_present_sequence_length_tile, 2}); + const TensorShape metadata_shape(metadata_dims); + Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); - // Split-reduce path (QKV + VxReduce) - const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; - const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; + const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_}); + const TensorShape out_split_vx_shape(out_split_vx_dims); + Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); - // The metadata is used to store the max and sum of each tile. - const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, num_present_sequence_length_tile, 2}); - const TensorShape metadata_shape(metadata_dims); - Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); + Tensor* qkv_present_key = turbo_quant_enabled ? tq_present_key : present_key; + Tensor* qkv_present_value = turbo_quant_enabled ? 
tq_present_value : present_value; - const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_}); - const TensorShape out_split_vx_shape(out_split_vx_dims); - Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, qkv_present_key, qkv_present_value, + &metadata, seqlen_k, + parameters, indirect_buffer_ptr, num_total_seq_length_tile, + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + present_sequence_length, m_tile, use_seqlen_k, total_seqlen, + turbo_quant_enabled, compressed_head_size_u32)); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, present_key, present_value, - &metadata, seqlen_k, - parameters, indirect_buffer_ptr, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length, m_tile, use_seqlen_k, total_seqlen)); + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, attn_output, seqlen_k, parameters, + num_total_seq_length_tile, + num_present_sequence_length_tile, tile_size, + head_sink, m_tile, use_seqlen_k)); + } - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, - num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, - head_sink, m_tile, use_seqlen_k)); + // Apply inverse Hadamard transform: attn_output_temp -> output. 
+ if (turbo_quant_enabled) { + ORT_RETURN_IF_ERROR(ApplyHadamardTransform(context, attn_output, output, parameters.head_size_)); + } return Status::OK(); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 85ba61c1d20b5..02f25a753fff8 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -84,7 +84,9 @@ class FlashAttentionProgram final : public Program { bool has_subgroups, bool q_BNSH, bool use_seqlen_k = false, - bool has_head_sink = false) + bool has_head_sink = false, + bool turbo_quant = false, + int compressed_head_size_u32 = 0) : Program{kernel_name}, has_attention_bias_(has_attention_bias), is_qualcomm_(is_qualcomm), @@ -96,7 +98,9 @@ class FlashAttentionProgram final : public Program { use_shm_path_(is_apple || is_nvidia || !has_subgroups), q_BNSH_(q_BNSH), use_seqlen_k_(use_seqlen_k), - has_head_sink_(has_head_sink) { + has_head_sink_(has_head_sink), + turbo_quant_(turbo_quant), + compressed_head_size_u32_(compressed_head_size_u32) { if (use_shm_path_) { // Use shared-memory loop-based path with dynamic max_k_step. 
// Compute max_k_step from workgroup shared memory budget: k_tile + v_tile = 2 * element_size * head_size * max_k_step @@ -141,6 +145,8 @@ class FlashAttentionProgram final : public Program { bool use_seqlen_k_; bool has_head_sink_; int max_k_step_; + bool turbo_quant_; + int compressed_head_size_u32_; }; class FlashAttentionDecodeQKVProgram final : public Program { @@ -150,8 +156,9 @@ class FlashAttentionDecodeQKVProgram final : public Program { diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template index 74cebc8184780..65766dc8959fc 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -8,6 +8,8 @@ #param q_BNSH #param qkv_head_size #param qkv_num_heads +#param turbo_quant +#param compressed_head_size_u32 #param use_seqlen_k #param use_shm_path #param max_k_step_param @@ -15,6 +17,20 @@ const head_size : u32 = qkv_head_size; const num_heads : u32 = qkv_num_heads; +#if turbo_quant +#include "bert/turbo_quant_common.wgsl.template" +#include "bert/turbo_quant_dequant.wgsl.template" +const COMPRESSED_HEAD_U32 : u32 = compressed_head_size_u32; + +// Dequantize 4 consecutive elements from a TQ-packed KV cache into a vec4. +fn tq_dequant_vec4(kv_cache: ptr, read>, base: u32, elem_base: u32, scale: f32) -> q_value_t { + let word_idx = elem_base >> 3u; + let packed = (*kv_cache)[base + 1u + word_idx]; + let shift = (elem_base & 4u) << 2u; + return tq_unpack_nibbles(packed >> shift) * q_element_t(scale); +} +#endif + #if use_seqlen_k // When seqlens_k is provided, total_sequence_length is read per batch from the GPU buffer. fn get_total_sequence_length(batch_idx: u32) -> u32 { @@ -40,6 +56,9 @@ const head_size_vec : u32 = head_size / vec_factor; // K and V tiles in shared memory. 
var k_tile : array, max_k_step>; var v_tile : array, max_k_step>; +#if turbo_quant +var tq_lut : array; +#endif // Private memory per lane. var q_tile : array; @@ -66,21 +85,60 @@ fn loadq(batch_idx : u32, q_idx_global : u32, head_idx : u32, alpha : q_element_ var qk_scores : array; fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { +#if turbo_quant + // TurboQuant: KV cache is u32-packed (scale + 4-bit indices). Dequantize on load. + // Parallelize across slots; each lane dequantizes one full row (head_size_vec vec4s). + let kv_head_idx = batch_head_idx / uniforms.n_reps; + for (var slot : u32 = local_idx; slot < max_k_step; slot += workgroup_size_x) { + let seq_idx = k_start + slot; + if (seq_idx < total_seq) { + let base = kv_head_idx * uniforms.present_sequence_length * COMPRESSED_HEAD_U32 + seq_idx * COMPRESSED_HEAD_U32; + let scale = bitcast(present_key[base]); + for (var v : u32 = 0u; v < head_size_vec; v++) { + k_tile[slot][v] = tq_dequant_vec4(&present_key, base, v * 4u, scale); + } + } else { + for (var v : u32 = 0u; v < head_size_vec; v++) { + k_tile[slot][v] = q_value_t(0); + } + } + } +#else let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); k_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); } +#endif } fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, total_seq : u32) { +#if turbo_quant + // TurboQuant: KV cache is u32-packed (scale + 4-bit indices). Dequantize on load. 
+ let kv_head_idx = batch_head_idx / uniforms.n_reps; + for (var slot : u32 = local_idx; slot < max_k_step; slot += workgroup_size_x) { + let seq_idx = v_start + slot; + if (seq_idx < total_seq) { + let base = kv_head_idx * uniforms.present_sequence_length * COMPRESSED_HEAD_U32 + seq_idx * COMPRESSED_HEAD_U32; + let scale = bitcast(present_value[base]); + for (var v : u32 = 0u; v < head_size_vec; v++) { + v_tile[slot][v] = tq_dequant_vec4(&present_value, base, v * 4u, scale); + } + } else { + for (var v : u32 = 0u; v < head_size_vec; v++) { + v_tile[slot][v] = q_value_t(0); + } + } + } +#else let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; for (var idx : u32 = local_idx; idx < head_size_vec * max_k_step; idx += workgroup_size_x) { let slot = u32(idx / head_size_vec); v_tile[slot][idx % head_size_vec] = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); } +#endif } var o_tile : array; @@ -114,6 +172,24 @@ fn loadAttentionBias(batch_idx : u32, q_idx_global : u32, k_idx_global : u32, he // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, total_seq : u32) { +#if turbo_quant + // TurboQuant: KV cache is u32-packed (scale + 4-bit indices). Dequantize on load. 
+ let kv_head_idx = batch_head_idx / uniforms.n_reps; + for (var slot : u32 = local_idx; slot < k_step; slot += workgroup_size_x) { + let seq_idx = k_start + slot; + if (seq_idx < total_seq) { + let base = kv_head_idx * uniforms.present_sequence_length * COMPRESSED_HEAD_U32 + seq_idx * COMPRESSED_HEAD_U32; + let scale = bitcast(present_key[base]); + for (var v : u32 = 0u; v < head_size_vec; v++) { + k_tile[slot][v] = tq_dequant_vec4(&present_key, base, v * 4u, scale); + } + } else { + for (var v : u32 = 0u; v < head_size_vec; v++) { + k_tile[slot][v] = q_value_t(0); + } + } + } +#else // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; @@ -122,9 +198,28 @@ fn loadk(k_start : u32, batch_head_idx : u32, local_idx : u32, k_step : u32, tot let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < total_seq); k_tile[slot][idx % head_size_vec] = val; } +#endif } fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, total_seq : u32) { +#if turbo_quant + // TurboQuant: KV cache is u32-packed (scale + 4-bit indices). Dequantize on load. 
+ let kv_head_idx = batch_head_idx / uniforms.n_reps; + for (var slot : u32 = local_idx; slot < v_step; slot += workgroup_size_x) { + let seq_idx = v_start + slot; + if (seq_idx < total_seq) { + let base = kv_head_idx * uniforms.present_sequence_length * COMPRESSED_HEAD_U32 + seq_idx * COMPRESSED_HEAD_U32; + let scale = bitcast(present_value[base]); + for (var v : u32 = 0u; v < head_size_vec; v++) { + v_tile[slot][v] = tq_dequant_vec4(&present_value, base, v * 4u, scale); + } + } else { + for (var v : u32 = 0u; v < head_size_vec; v++) { + v_tile[slot][v] = q_value_t(0); + } + } + } +#else // Stored as float16[batch_size,num_heads,present_sequence_length,head_size] let offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; @@ -133,6 +228,7 @@ fn loadv(v_start : u32, batch_head_idx : u32, local_idx : u32, v_step : u32, tot let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < total_seq); v_tile[slot][idx % head_size_vec] = val; } +#endif } #if is_qualcomm @@ -217,6 +313,13 @@ $MAIN { return; } +#if turbo_quant + // Load centroid LUT into shared memory once. The workgroupBarrier before loadk/loadv synchronizes. + if (local_idx < 16u) { + tq_lut[local_idx] = TQ_CENTROIDS[local_idx]; + } +#endif + // Load Q let q_idx_global = (workgroup_idx % uniforms.num_seq_tile) * workgroup_size_x + local_idx; let valid_q = q_idx_global < uniforms.new_sequence_length; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template index 778e07fbf63ff..d5d2b89fcf29c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -1,19 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#param compressed_head_size_u32 #param has_attention_bias -#param v_head_size_vec #param is_unidirectional #param m_tile #param q_BNSH #param sub_tile_count #param tile_size #param tile_size_k_vec +#param turbo_quant #param use_indirect_dispatch #param use_seqlen_k +#param v_head_size_vec #use .getByOffset .setByOffset +#if turbo_quant +#include "bert/turbo_quant_common.wgsl.template" +#include "bert/turbo_quant_dequant.wgsl.template" +const COMPRESSED_HEAD_U32 : u32 = compressed_head_size_u32; +const COMPRESSED_HEAD_U32_WITHOUT_SCALE : u32 = COMPRESSED_HEAD_U32 - 1u; +#endif + // Fused QK^T + softmax + V multiply shader. // // Each workgroup processes one KV tile (tile_size rows of present_key/value) @@ -26,11 +35,19 @@ // // The VxReduce shader performs the final rescaling across tiles. +#if turbo_quant +// TQ: preload all Q vec4s and centroid LUT into shared memory. +var all_q: array, m_tile>; +var tq_k_scales: array; +var tq_v_scales: array; +var tq_lut: array; +#else var tile_q: array, m_tile>; +#endif var inner_qk_values: array, tile_size>, m_tile>; -var tile_qk: array, m_tile>; -var tile_output: array, m_tile>; -var qkv_values: array, sub_tile_count>, m_tile>; +var tile_qk: array, m_tile>; +var tile_output: array, m_tile>; +var qkv_values: array, sub_tile_count>, m_tile>; var tile_max: array; var tile_sum: array; @@ -84,6 +101,66 @@ $MAIN { #else let total_sequence_length = global_total_sequence_length; #endif + +#if turbo_quant + let kv_head_offset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * COMPRESSED_HEAD_U32; + + // Preload centroid LUT. + if (local_idx < 16u) { + tq_lut[local_idx] = TQ_CENTROIDS[local_idx]; + } + + // Preload K scales for this tile. 
+ if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { + let scale_base = kv_head_offset + (total_seq_offset + local_idx) * COMPRESSED_HEAD_U32; + tq_k_scales[local_idx] = bitcast(present_key[scale_base]); + tq_v_scales[local_idx] = bitcast(present_value[scale_base]); + } + + // Preload all Q into shared memory. + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; +#if q_BNSH + let q_offset = batch_idx * uniforms.num_heads * uniforms.new_sequence_length * uniforms.head_size_vec + + head_idx * uniforms.new_sequence_length * uniforms.head_size_vec + + q_idx * uniforms.head_size_vec; +#else + let q_offset = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec; +#endif + for (var i = local_idx; i < v_head_size_vec; i += workgroup_size_x) { + all_q[m][i] = q.getByOffset(q_offset + i) * q_element_t(uniforms.alpha); + } + } + workgroupBarrier(); + + // ============================================================ + // Phase 1 (TQ): QK^T with dequantized K from packed u32 + // ============================================================ + // Each thread processes one u32 word per iteration (8 nibbles → 2 vec4 dot products). 
+ for (var kw: u32 = 0u; kw < COMPRESSED_HEAD_U32_WITHOUT_SCALE; kw += tile_size_k_vec) { + let word_idx = kw + local_col; + if (word_idx < COMPRESSED_HEAD_U32_WITHOUT_SCALE) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + let seq_idx = total_seq_offset + row_offset + local_row; + if (seq_idx < total_sequence_length) { + let base = kv_head_offset + seq_idx * COMPRESSED_HEAD_U32; + let packed = present_key[base + 1u + word_idx]; + let k_lo = tq_unpack_nibbles(packed); + let k_hi = tq_unpack_nibbles(packed >> 16u); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let mq_lo = all_q[m][word_idx * 2u]; + let mq_hi = all_q[m][word_idx * 2u + 1u]; + inner_qk_values[m][row_offset + local_row][local_col] += dot(k_lo, mq_lo) + dot(k_hi, mq_hi); + } + } + } + } + } + workgroupBarrier(); + +#else let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; @@ -133,6 +210,7 @@ $MAIN { } workgroupBarrier(); } +#endif // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask #if is_unidirectional @@ -148,6 +226,10 @@ $MAIN { for (var i = 0u; i < tile_size_k_vec; i++) { sum += inner_qk_values[m][local_idx][i]; } +#if turbo_quant + // Apply the deferred scale (L2 norm factored out of the inner loop). 
+ sum *= q_element_t(tq_k_scales[local_idx]); +#endif sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx); #if is_unidirectional @@ -155,7 +237,7 @@ $MAIN { sum = q_element_t(-65504.0f); } #endif - tile_qk[m][local_idx] = present_value_element_t(sum); + tile_qk[m][local_idx] = sum; } workgroupBarrier(); @@ -184,14 +266,51 @@ $MAIN { // Normalize tile_qk with local max/sum for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { - tile_qk[m][local_idx] = present_value_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]); + tile_qk[m][local_idx] = q_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]); } } workgroupBarrier(); +#if turbo_quant + // TQ V multiply: dequantize V from packed u32 on the fly. + for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] = q_value_t(0); + } + workgroupBarrier(); + + if (k + local_col < v_head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + let seq_idx = total_seq_offset + row_offset + local_row; + if (seq_idx < total_sequence_length) { + let elem_base = (k + local_col) * 4u; + let tq_word_idx = elem_base >> 3u; + let base = kv_head_offset + seq_idx * COMPRESSED_HEAD_U32; + let scale = tq_v_scales[row_offset + local_row]; + let packed = present_value[base + 1u + tq_word_idx]; + let tq_shift = (elem_base & 4u) << 2u; + let v_val = tq_unpack_nibbles(packed >> tq_shift) * q_element_t(scale); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] += v_val * tile_qk[m][row_offset + local_row]; + } + } + } + } + workgroupBarrier(); + + if (local_idx < tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < 
uniforms.new_sequence_length; m++) { + for (var i = 0u; i < sub_tile_count; i++) { + tile_output[m][k + local_idx] += qkv_values[m][i][local_idx]; + } + } + } + workgroupBarrier(); + } +#else for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) { for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { - qkv_values[m][local_row][local_col] = present_value_value_t(0); + qkv_values[m][local_row][local_col] = q_value_t(0); } workgroupBarrier(); @@ -216,6 +335,7 @@ $MAIN { } workgroupBarrier(); } +#endif // Write output let tile_idx = workgroup_idx % num_total_seq_length_tile; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 24ace3487a4c5..879a63a2482ca 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -241,6 +241,23 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& } GroupQueryAttentionParameters params = {}; + + // KV cache quantization uses 4-bit quantization with 32 extra bits (1 u32) per head for the L2 norm. + // Requires head_size >= 8 and power-of-2. + const uint32_t kv_cache_bits = context.KvCacheQuantizationBits(); + const bool kv_cache_quant = kv_cache_bits != 0; + const int kv_cache_bit_width = static_cast(kv_cache_bits); + const int kv_cache_extra_bits = kv_cache_quant ? 32 : 0; + if (kv_cache_quant) { + const int qkv_last_dim = static_cast(query->Shape().GetDims()[2]); + const bool is_packed = (key == nullptr); + const int hs = is_packed ? qkv_last_dim / (num_heads_ + 2 * kv_num_heads_) : qkv_last_dim / num_heads_; + if (hs < 8 || (hs & (hs - 1)) != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV cache quantization requires head_size >= 8 and a power of 2. 
Got head_size=", hs); + } + } + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, key, value, @@ -255,8 +272,9 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& total_seqlen_tensor, scale_, softcap_, - 0, - onnxruntime::narrow(context.DeviceLimits().maxComputeInvocationsPerWorkgroup))); + kv_cache_bit_width, + onnxruntime::narrow(context.DeviceLimits().maxComputeInvocationsPerWorkgroup), + kv_cache_extra_bits)); params.use_smooth_softmax = use_smooth_softmax_; params.rotary_interleaved = rotary_interleaved_; @@ -310,11 +328,19 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& output_shape[1] = static_cast(parameters.sequence_length_); output_shape[2] = static_cast(parameters.hidden_size_); Tensor* output = context.Output(0, output_shape); + + // When TurboQuant is enabled, the KV cache head dimension is compressed. + // Derive from quantization parameters: (head_size * bit_width + extra_bits) / bits_per_element. + int64_t kv_head_dim = parameters.head_size_; + if (kv_cache_bit_width > 0) { + int bits_per_element = static_cast(query->DataType()->Size()) * 8; + kv_head_dim = (parameters.head_size_ * kv_cache_bit_width + kv_cache_extra_bits) / bits_per_element; + } std::vector present_dims{ parameters.batch_size_, kv_num_heads_, parameters.seqlen_present_kv_cache_, - parameters.head_size_}; + kv_head_dim}; std::vector present_kv_shape(present_dims); Tensor* present_key = context.Output(1, present_kv_shape); Tensor* present_value = context.Output(2, present_kv_shape); @@ -379,6 +405,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& } } else if (parameters.is_packed_qkv_ && do_rotary_) { // Use the ultimate fused operation when FlashAttention and static KV cache is enabled. + // When TurboQuant is active, ApplyFlashAttention handles the fused split+rotary+Hadamard+quantize path. 
if (will_use_flash_attention && parameters.past_present_share_buffer_) { // Directly call ApplyFlashAttention with fused split/rotary/copyKV enabled // query points to packed QKV, K and V are nullptr since they're not needed @@ -479,6 +506,14 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& total_seqlen_tensor); } + // KV cache quantization compresses the KV cache; non-flash attention paths cannot interpret it. + if (context.KvCacheQuantizationEnabled()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV cache quantization requires flash attention. " + "The non-flash attention path cannot be used with compressed KV caches. " + "Check that smooth_softmax and local_window_size are not set."); + } + // Non-flash attention path does not support kv_sequence_length==0 (shared KV layers). if (kv_empty) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.cc b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.cc new file mode 100644 index 0000000000000..ab984799fdded --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.cc @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/webgpu/bert/hadamard_transform.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status HadamardTransformProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& input = sh.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& output = sh.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + return WGSL_TEMPLATE_APPLY(sh, "bert/hadamard_transform.wgsl.template", + WGSL_TEMPLATE_PARAMETER(components, components_), + WGSL_TEMPLATE_PARAMETER(hadamard_size_log2, slice_size_log2_), + WGSL_TEMPLATE_VARIABLE(input, input), + WGSL_TEMPLATE_VARIABLE(output, output)); +} + +Status ApplyHadamardTransform(onnxruntime::webgpu::ComputeContext& context, + const Tensor* input, + Tensor* output, + int explicit_slice_size) { + const auto& shape = input->Shape(); + ORT_ENFORCE(shape.NumDimensions() >= 1, "Input tensor must have at least 1 dimension."); + + // Use explicit slice size if provided, otherwise derive from last dimension. + const int slice_size = explicit_slice_size > 0 ? explicit_slice_size : static_cast(shape[shape.NumDimensions() - 1]); + ORT_ENFORCE((slice_size & (slice_size - 1)) == 0, "Last dimension must be a power of 2 for Hadamard transform, got ", slice_size); + ORT_ENFORCE(slice_size >= 4, "Last dimension must be at least 4 for vectorized Hadamard transform, got ", slice_size); + + const int slice_size_log2 = Log2OfPowerOfTwo(slice_size); + + const int components = slice_size % 4 == 0 ? 4 : (slice_size % 2 == 0 ? 
2 : 1); + ORT_ENFORCE(shape.Size() % slice_size == 0, "Total tensor size must be divisible by slice_size, got ", shape.Size(), " % ", slice_size, " != 0"); + const uint32_t num_slices = static_cast(shape.Size() / slice_size); + + // Workgroup size: use up to 64 threads. Each thread handles multiple butterfly pairs. + const uint32_t workgroup_size = std::min(static_cast(slice_size / 2), 64u); + + HadamardTransformProgram program(slice_size_log2, components); + program.AddInput({input, ProgramTensorMetadataDependency::TypeAndRank, components}); + program.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, components}); + + program.SetDispatchGroupSize(num_slices) + .SetWorkgroupSize(workgroup_size) + .CacheHint(slice_size_log2, components) + .AddUniformVariables({{num_slices}}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h new file mode 100644 index 0000000000000..a92f1cb669488 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.h @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +// Avoid `using namespace` in headers. Pull in only what we need. +using onnxruntime::webgpu::Program; +using onnxruntime::webgpu::ProgramUniformVariableDataType; +using onnxruntime::webgpu::ShaderHelper; + +// Returns floor(log2(value)) for a positive power-of-two `value` +// (i.e. the bit position of its single set bit). 
+inline int Log2OfPowerOfTwo(int value) { + int log2 = 0; + for (int tmp = value; tmp > 1; tmp >>= 1) { + log2++; + } + return log2; +} + +class HadamardTransformProgram final : public Program { + public: + HadamardTransformProgram(int slice_size_log2, int components) + : Program{"HadamardTransform"}, + slice_size_log2_(slice_size_log2), + components_(components) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"num_slices", ProgramUniformVariableDataType::Uint32}); + + private: + int slice_size_log2_; + int components_; +}; + +// Apply the normalized Walsh-Hadamard transform. +// The normalized Hadamard matrix is symmetric (H == H^T) and orthogonal +// (H @ H^T = I), so applying it twice recovers the original data. +// This means the same function serves as both the forward and inverse transform. +// +// If explicit_slice_size > 0, it is used as the transform dimension size. +// Otherwise, the last dimension of the tensor shape is used. +// The transform size must be a power of 2 (>= 4). +// All elements are divided into slices of that size, each transformed independently. +Status ApplyHadamardTransform(onnxruntime::webgpu::ComputeContext& context, + const Tensor* input, + Tensor* output, + int explicit_slice_size = 0); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.wgsl.template new file mode 100644 index 0000000000000..54dc39929816d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform.wgsl.template @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Fast Walsh-Hadamard Transform (FWHT) applied along the last dimension. +// Reads from a const input buffer, writes to a separate output buffer. 
+// Each workgroup processes one contiguous slice of size 1 << hadamard_size_log2. + +#param hadamard_size_log2 +#param components + +const HADAMARD_SIZE : u32 = 1u << hadamard_size_log2; +const HADAMARD_SIZE_VEC : u32 = HADAMARD_SIZE / components; + +#use .getByOffset .setByOffset + +var hadamard_buffer : array; + +#include "bert/hadamard_transform_common.wgsl.template" + +$MAIN { + let slice_idx = workgroup_idx; + if (slice_idx >= uniforms.num_slices) { + return; + } + + let base_offset = slice_idx * HADAMARD_SIZE_VEC; + + // Load from input into shared memory. + for (var i = local_idx; i < HADAMARD_SIZE_VEC; i += workgroup_size_x) { + let val = input.getByOffset(base_offset + i); +#if components == 4 + hadamard_buffer[i * 4u + 0u] = val[0]; + hadamard_buffer[i * 4u + 1u] = val[1]; + hadamard_buffer[i * 4u + 2u] = val[2]; + hadamard_buffer[i * 4u + 3u] = val[3]; +#elif components == 2 + hadamard_buffer[i * 2u + 0u] = val[0]; + hadamard_buffer[i * 2u + 1u] = val[1]; +#else + hadamard_buffer[i] = val; +#endif + } + workgroupBarrier(); + + // Walsh-Hadamard Transform butterfly passes. + wht_butterfly(local_idx, workgroup_size_x); + + // Normalize by 1/sqrt(HADAMARD_SIZE) to make the transform orthogonal. + let norm = input_element_t(1) / input_element_t(sqrt(f32(HADAMARD_SIZE))); + + // Write results to output buffer. 
+ for (var i = local_idx; i < HADAMARD_SIZE_VEC; i += workgroup_size_x) { +#if components == 4 + let v0 = hadamard_buffer[i * 4u + 0u] * norm; + let v1 = hadamard_buffer[i * 4u + 1u] * norm; + let v2 = hadamard_buffer[i * 4u + 2u] * norm; + let v3 = hadamard_buffer[i * 4u + 3u] * norm; + output.setByOffset(base_offset + i, output_value_t(v0, v1, v2, v3)); +#elif components == 2 + let v0 = hadamard_buffer[i * 2u + 0u] * norm; + let v1 = hadamard_buffer[i * 2u + 1u] * norm; + output.setByOffset(base_offset + i, output_value_t(v0, v1)); +#else + output.setByOffset(base_offset + i, output_value_t(hadamard_buffer[i] * norm)); +#endif + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform_common.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform_common.wgsl.template new file mode 100644 index 0000000000000..7ac43cd37ab93 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/hadamard_transform_common.wgsl.template @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// The Hadamard transform is an orthogonal linear transformation (H @ H^T = I) +// that spreads information uniformly across all dimensions of a vector. Unlike +// random rotations, it requires no learned parameters and is its own inverse — +// applying it twice recovers the original input. Hadamard matrix is also +// symmetric so H == H^T. +// +// In quantization a Hadamard transform is applied to distribute outlier magnitudes +// evenly across dimensions. This prevents a single large value from dominating and +// wasting quantization range, dramatically reducing quantization error. +// After dequantization, the inverse transform (which is the same operation, since H = H^{-1}) +// restores the original representation. +// +// Fast Walsh-Hadamard Transform (WHT) via in-place butterfly operations. 
+// +// Requires the caller to declare: +// var hadamard_buffer : array; // T can be f32 or f16 +// hadamard_size_log2 (via #param in the including template) +// +// The Fast Hadamard Transform avoids materializing the full N×N Hadamard matrix. +// Instead it performs log2(N) butterfly stages, each doing N/2 butterfly ops: +// +// Total butterflies = (N/2) × log2(N) +// +// Each butterfly is just: sum = a + b; diff = a - b; +// +// For N=128: 64 × 7 = 448 butterflies (896 adds/subs) vs 16,384 for dense multiply. +// For N=4096: 2048 × 12 = 24,576 butterflies vs 16,777,216 for dense multiply (~300× less). +// +// This is why methods like QuaRot can apply orthogonal rotations to every activation +// during inference with negligible overhead — it's just shuffles, adds, and subtracts. +// +// Consider the H8 hadamard matrix +// +// [ 1 1 1 1 1 1 1 1 +// 1 -1 1 -1 1 -1 1 -1 +// 1 1 -1 -1 1 1 -1 -1 +// 1 -1 -1 1 1 -1 -1 1 +// 1 1 1 1 -1 -1 -1 -1 +// 1 -1 1 -1 -1 1 -1 1 +// 1 1 -1 -1 -1 -1 1 1 +// 1 -1 -1 1 -1 1 1 -1] +// +// and a vector matrix multiply operation [a b c d e f g h] @ H8. +// +// The butterfly pattern — pair_distance doubles each stage: +// +// Stage 0 (pair_distance=1): pairs are adjacent +// [a b c d e f g h] → [a+b a-b c+d c-d e+f e-f g+h g-h] +// +// Stage 1 (pair_distance=2): pairs are 2 apart +// [A B C D E F G H] → [A+C B+D A-C B-D E+G F+H E-G F-H] +// +// Stage 2 (pair_distance=4): pairs are 4 apart +// [A B C D E F G H] → [A+E B+F C+G D+H A-E B-F C-G D-H] +// +// After log2(N) stages the output is the WHT of the input. + + +fn wht_butterfly(lid: u32, wg_size: u32) { + let hadamard_size = (1u << hadamard_size_log2); + let pair_count = (hadamard_size) / 2u; // hadamard_size/2 pairs per step + var pair_distance = 1u; + for (var step = 0u; step < hadamard_size_log2; step++) { + for (var idx = lid; idx < pair_count; idx += wg_size) { + // The array is divided into groups of (2 * pair_distance). 
+ // A group is the unit for addition/subtraction + // (ab) in stage 0 + // (ABCD) in stage 1 + // + // Each group's first half pairs with its second half: + // + // pair_distance=4, array of 16: + // Group 0: [0 1 2 3 | 4 5 6 7] Group 1: [8 9 10 11 | 12 13 14 15] + // ^---------^ ^------------^ + // first second first second + // + // idx=0 → group=0, group_idx=0 → i = 0*8+0 = 0, j = 0+4 = 4 + // idx=5 → group=1, group_idx=1 → i = 1*8+1 = 9, j = 9+4 = 13 + // So we butterfly hadamard_buffer[i] with hadamard_buffer[j]. + // + // to compute the group we are in, it would be input_index/(pair_distance * 2u) + // since idx iterates over pairs, input_index corresponds to 2 * idx and therefore + // group = idx/pair_distance + let group = idx / pair_distance; + let group_size = (pair_distance * 2u); + let group_idx = idx % pair_distance; + let i = group * group_size + group_idx; + let j = i + pair_distance; + let a = hadamard_buffer[i]; + let b = hadamard_buffer[j]; + hadamard_buffer[i] = a + b; + hadamard_buffer[j] = a - b; + } + workgroupBarrier(); + pair_distance = pair_distance << 1u; + } +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/indirect_dispatch_common.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/indirect_dispatch_common.wgsl.template new file mode 100644 index 0000000000000..79cbcbe6fa2ee --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/indirect_dispatch_common.wgsl.template @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Writes the flash attention dispatch dimensions into indirect_buffer (3 x u32). +// Handles WebGPU's 65535 maxComputeWorkgroupsPerDimension limit by collapsing +// to 2D (sqrt) or 3D (cbrt) when any dimension exceeds the limit. +// Caller contract: indirect_buffer must be a storage output of at least 3 u32 elements.
+fn populate_indirect_dispatch_buffer(x: u32, y: u32, z: u32) { + let limit = 65535u; // default value of the WebGPU maxComputeWorkgroupsPerDimension limit + if (x <= limit && y <= limit && z <= limit) { + indirect_buffer[0] = x; + indirect_buffer[1] = y; + indirect_buffer[2] = z; + return; + } + let size = f32(x) * f32(y) * f32(z); + let dispatch_avg_2d = u32(ceil(sqrt(size))); + if (dispatch_avg_2d <= limit) { + indirect_buffer[0] = dispatch_avg_2d; + indirect_buffer[1] = dispatch_avg_2d; + indirect_buffer[2] = 1u; + return; + } + let dispatch_avg_3d = u32(ceil(pow(size, 1.0 / 3.0))); + indirect_buffer[0] = dispatch_avg_3d; + indirect_buffer[1] = dispatch_avg_3d; + indirect_buffer[2] = dispatch_avg_3d; +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index e3d92c036d2c1..91a087801a252 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -6,6 +6,10 @@ #use guardAgainstOutOfBoundsWorkgroupSizes #use .setByIndices .getByIndices .getByOffset +#if prepare_indirect_dispatch +#include "bert/indirect_dispatch_common.wgsl.template" +#endif + $MAIN { guardAgainstOutOfBoundsWorkgroupSizes(uniforms.dispatch_size); diff --git a/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_common.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_common.wgsl.template new file mode 100644 index 0000000000000..f6c7fcad58eb9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_common.wgsl.template @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// TurboQuant 4-bit codebook constants and the quantization (centroid-index) helper.
+// Shared between turbo_quant_hadamard (quantize) and flash_attention (dequantize). +// +// KV cache layout per head (HEAD_SIZE elements): +// Word 0: bitcast<u32>(fp32 L2 norm) +// Words 1..HEAD_SIZE/8: HEAD_SIZE × 4-bit centroid indices packed 8-per-u32 +// Total words: HEAD_SIZE / 8 + 1 + +// 16 MSE-optimal centroids for unit-sphere coordinates (symmetric: C[i] = -C[15-i]) +const TQ_CENTROIDS = array( + -0.2377, -0.1809, -0.1419, -0.1104, -0.0829, -0.0578, -0.0342, -0.0113, + 0.0113, 0.0342, 0.0578, 0.0829, 0.1104, 0.1419, 0.1809, 0.2377); + +// 15 interior decision boundaries (midpoints between adjacent centroids) for branchless quantization +const TQ_BOUNDARIES = array( + -0.2093, -0.1614, -0.1261, -0.0966, -0.0704, -0.0460, -0.0227, + 0.0000, 0.0227, 0.0460, 0.0704, 0.0966, 0.1261, 0.1614, 0.2093); + +// Branchless centroid index lookup (0..15) for a unit-normalized value. +// Counts how many boundaries the value exceeds, which is the index of the nearest centroid — no divergence across threads. +fn snap_to_centroid_index(x: f32) -> u32 { + return u32(x > TQ_BOUNDARIES[0]) + u32(x > TQ_BOUNDARIES[1]) + + u32(x > TQ_BOUNDARIES[2]) + u32(x > TQ_BOUNDARIES[3]) + + u32(x > TQ_BOUNDARIES[4]) + u32(x > TQ_BOUNDARIES[5]) + + u32(x > TQ_BOUNDARIES[6]) + u32(x > TQ_BOUNDARIES[7]) + + u32(x > TQ_BOUNDARIES[8]) + u32(x > TQ_BOUNDARIES[9]) + + u32(x > TQ_BOUNDARIES[10]) + u32(x > TQ_BOUNDARIES[11]) + + u32(x > TQ_BOUNDARIES[12]) + u32(x > TQ_BOUNDARIES[13]) + + u32(x > TQ_BOUNDARIES[14]); +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_dequant.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_dequant.wgsl.template new file mode 100644 index 0000000000000..14a19217ba428 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_dequant.wgsl.template @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// TurboQuant 4-bit dequantization helper shared by the attention (dequantize) +// shaders.
Kept separate from turbo_quant_common.wgsl.template because the +// quantize-side shaders include that file but do not define the `q_value_t` / +// `q_element_t` aliases or the `tq_lut` centroid table this helper requires. +// +// The includer must define the `q_value_t` / `q_element_t` aliases and a +// `tq_lut` centroid lookup table (workgroup array preloaded from TQ_CENTROIDS). + +// Dequantize 4 consecutive nibbles from the low 16 bits of a packed u32 word (nibble k = bits 4k..4k+3 -> component k). +fn tq_unpack_nibbles(packed: u32) -> q_value_t { + return q_value_t( + q_element_t(tq_lut[(packed) & 0xFu]), + q_element_t(tq_lut[(packed >> 4u) & 0xFu]), + q_element_t(tq_lut[(packed >> 8u) & 0xFu]), + q_element_t(tq_lut[(packed >> 12u) & 0xFu])); +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_fused_rotary_hadamard.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_fused_rotary_hadamard.wgsl.template new file mode 100644 index 0000000000000..5c452b5ea188b --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_fused_rotary_hadamard.wgsl.template @@ -0,0 +1,237 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Fused TurboQuant: Split packed QKV + Rotary K + Hadamard + Quantize K/V + Rotary Q. +// +// A single dispatch handles all three components: +// Workgroups [0, num_kv_slices): K — split from packed QKV, apply rotary, WHT, quantize, write to present_key +// Workgroups [num_kv_slices, 2*num_kv_slices): V — split from packed QKV, WHT, quantize, write to present_value +// Workgroups [2*num_kv_slices, total): Q — split from packed QKV, apply rotary, write to query output +// +// K/V path: apply Walsh-Hadamard butterfly transform in shared memory, compute L2 norm, +// quantize each element to a 4-bit centroid index, pack 8 indices per u32, and store +// norm (as bitcast<u32>(f32)) followed by packed index words. +// Q path: per-element rotary embedding and return (no shared memory or barriers needed).
+// +// Output layout per KV head: [norm_u32, packed_indices_0, ..., packed_indices_(HEAD_SIZE/8 - 1)] + +#param hadamard_size_log2 +#param half_rotary_dim +#param compressed_head_size_u32 +#param past_present_share_buffer +#param use_seqlen_k +#param prepare_indirect_dispatch +#param multi_rotary_cache_concat_offset +#param use_multi_rotary_cache_concat +#use .getByIndices .getByOffset .setByOffset + +const HEAD_SIZE : u32 = 1u << hadamard_size_log2; +const COMPRESSED_HEAD_U32 : u32 = compressed_head_size_u32; +const HALF_ROTARY_DIM : u32 = half_rotary_dim; + +var<workgroup> hadamard_buffer : array<f32, HEAD_SIZE>; + +#include "bert/hadamard_transform_common.wgsl.template" +#include "bert/turbo_quant_common.wgsl.template" +#if prepare_indirect_dispatch +#include "bert/indirect_dispatch_common.wgsl.template" +#endif + +var<workgroup> scale_reduction_buffer : array<f32, workgroup_size_x>; +// Shared memory used for packing: store centroid indices (0-15) as u32 per element. +var<workgroup> index_buffer : array<u32, HEAD_SIZE>; + +$MAIN { + // Compute total_seq_length. +#if use_seqlen_k + let total_seq_length = u32(seqlen_k[0u]) + 1u; +#else + let total_seq_length = uniforms.total_sequence_length; +#endif + let past_seq_length = total_seq_length - uniforms.kv_sequence_length; + + // Base position offset for rotary embedding cos/sin cache lookup (position_id is not referenced later in this shader). + let position_id = past_seq_length; +#if use_multi_rotary_cache_concat + let base_position = select(0u, multi_rotary_cache_concat_offset, total_seq_length > multi_rotary_cache_concat_offset); +#else + let base_position = 0u; +#endif + + // Prepare indirect dispatch buffer (first workgroup, first thread only). +#if prepare_indirect_dispatch + if (workgroup_idx == 0u && local_idx == 0u) { + let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; + populate_indirect_dispatch_buffer(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); + } +#endif + + // Map flat workgroup index to logical component (Q, K, or V).
+ let num_kv_slices = uniforms.num_kv_slices; + let is_q = workgroup_idx >= 2u * num_kv_slices; + let is_value = !is_q && workgroup_idx >= num_kv_slices; + + // ============ Q WORKGROUP PATH (early return, no shared memory/barriers) ============ + if (is_q) { + let q_slice = workgroup_idx - 2u * num_kv_slices; + if (q_slice >= uniforms.num_q_slices) { return; } + + // Unflatten q_slice into (batch, head, seq); seq is the fastest-varying index. + let batch = q_slice / (uniforms.kv_sequence_length * uniforms.num_heads); + let head = (q_slice / uniforms.kv_sequence_length) % uniforms.num_heads; + let seq = q_slice % uniforms.kv_sequence_length; + + // Compute source offset in packed QKV for this Q head. + // Packed QKV layout per token: [Q_hidden | K_kv_hidden | V_kv_hidden]. + let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size; + let token_offset = (batch * uniforms.kv_sequence_length + seq) * token_size; + let q_src_base = token_offset + head * HEAD_SIZE; + + // Destination offset in query output (BSD layout: batch * seq * hidden_size). + let q_dst_base = (batch * uniforms.kv_sequence_length + seq) * uniforms.hidden_size + head * HEAD_SIZE; + + // Position index for rotary embedding cos/sin lookup. + let seq_position_id = past_seq_length + seq; + + // Apply rotary embedding: rotate first HALF_ROTARY_DIM pairs.
+ for (var i = local_idx; i < HALF_ROTARY_DIM; i += workgroup_size_x) { + let cos_v = cos_cache.getByIndices(vec2(base_position + seq_position_id, i)); + let sin_v = sin_cache.getByIndices(vec2(base_position + seq_position_id, i)); + + let q_i = packed_qkv.getByOffset(q_src_base + i); + let q_j = packed_qkv.getByOffset(q_src_base + i + HALF_ROTARY_DIM); + + let q_re = q_i * cos_v - q_j * sin_v; + let q_im = q_i * sin_v + q_j * cos_v; + + query.setByOffset(q_dst_base + i, q_re); + query.setByOffset(q_dst_base + i + HALF_ROTARY_DIM, q_im); + } + + // Copy non-rotary elements (if head_size > 2 * half_rotary_dim); assumes 2 * HALF_ROTARY_DIM <= HEAD_SIZE. + for (var i = local_idx; i < HEAD_SIZE - 2u * HALF_ROTARY_DIM; i += workgroup_size_x) { + let actual_idx = 2u * HALF_ROTARY_DIM + i; + query.setByOffset(q_dst_base + actual_idx, packed_qkv.getByOffset(q_src_base + actual_idx)); + } + return; + } + + // ============ K/V WORKGROUP PATH (WHT + quantize) ============ + let kv_slice = select(workgroup_idx, workgroup_idx - num_kv_slices, is_value); + if (kv_slice >= num_kv_slices) { return; } + + // Unflatten kv_slice into (batch, head, seq). + let batch = kv_slice / (uniforms.kv_num_heads * uniforms.kv_sequence_length); + let head = (kv_slice / uniforms.kv_sequence_length) % uniforms.kv_num_heads; + let seq = kv_slice % uniforms.kv_sequence_length; + + // Compute destination offset in present_key/present_value (u32 packed, BNSH layout). +#if past_present_share_buffer + let dest_seq = past_seq_length + seq; +#else + let dest_seq = seq; +#endif + let present_base = ((batch * uniforms.kv_num_heads + head) * uniforms.present_seq_length + dest_seq) * COMPRESSED_HEAD_U32; + + // Compute source offset in packed QKV for this KV head. + // Packed QKV layout per token: [Q_hidden | K_kv_hidden | V_kv_hidden].
+ let token_size = uniforms.hidden_size + 2u * uniforms.kv_hidden_size; + let token_offset = (batch * uniforms.kv_sequence_length + seq) * token_size; + let k_src_base = token_offset + uniforms.hidden_size + head * HEAD_SIZE; + let v_src_base = token_offset + uniforms.hidden_size + uniforms.kv_hidden_size + head * HEAD_SIZE; + let src_base = select(k_src_base, v_src_base, is_value); + + // Position index for rotary embedding cos/sin lookup (K only). + let seq_position_id = past_seq_length + seq; + + // Load from packed QKV into shared memory (f32). K gets rotary, V loads directly. + if (!is_value) { + // K path: apply rotary embedding during load. + for (var i = local_idx; i < HALF_ROTARY_DIM; i += workgroup_size_x) { + let cos_v = f32(cos_cache.getByIndices(vec2(base_position + seq_position_id, i))); + let sin_v = f32(sin_cache.getByIndices(vec2(base_position + seq_position_id, i))); + + let k_i = f32(packed_qkv.getByOffset(src_base + i)); + let k_j = f32(packed_qkv.getByOffset(src_base + i + HALF_ROTARY_DIM)); + + hadamard_buffer[i] = k_i * cos_v - k_j * sin_v; + hadamard_buffer[i + HALF_ROTARY_DIM] = k_i * sin_v + k_j * cos_v; + } + // Copy non-rotary elements (if head_size > 2 * half_rotary_dim); assumes 2 * HALF_ROTARY_DIM <= HEAD_SIZE. + for (var i = local_idx; i < HEAD_SIZE - 2u * HALF_ROTARY_DIM; i += workgroup_size_x) { + let actual_idx = 2u * HALF_ROTARY_DIM + i; + hadamard_buffer[actual_idx] = f32(packed_qkv.getByOffset(src_base + actual_idx)); + } + } else { + // V path: straight copy, no rotary. + for (var i = local_idx; i < HEAD_SIZE; i += workgroup_size_x) { + hadamard_buffer[i] = f32(packed_qkv.getByOffset(src_base + i)); + } + } + workgroupBarrier(); + + // Walsh-Hadamard Transform butterfly passes. + wht_butterfly(local_idx, workgroup_size_x); + + // WHT normalization: scaling by 1/sqrt(HEAD_SIZE) makes the transform orthonormal. + let wht_norm = 1.0f / sqrt(f32(HEAD_SIZE)); + for (var i = local_idx; i < HEAD_SIZE; i += workgroup_size_x) { + hadamard_buffer[i] *= wht_norm; + } + workgroupBarrier(); + + // Compute L2 norm via parallel reduction.
+ var partial_sq_sum = 0.0f; + for (var i = local_idx; i < HEAD_SIZE; i += workgroup_size_x) { + partial_sq_sum += hadamard_buffer[i] * hadamard_buffer[i]; + } + scale_reduction_buffer[local_idx] = partial_sq_sum; + workgroupBarrier(); + // Tree reduction: each iteration halves active threads, accumulating the total + // sum of squares into scale_reduction_buffer[0] in log2(workgroup_size) steps (workgroup_size_x must be a power of two). + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride >>= 1u) { + if (local_idx < stride) { + scale_reduction_buffer[local_idx] += scale_reduction_buffer[local_idx + stride]; + } + workgroupBarrier(); + } + let l2_norm = sqrt(scale_reduction_buffer[0]); + let inv_l2 = select(0.0f, 1.0f / l2_norm, l2_norm > 0.0f); + + // Quantize: compute centroid index for each element and store in index_buffer. + for (var i = local_idx; i < HEAD_SIZE; i += workgroup_size_x) { + let unit_val = hadamard_buffer[i] * inv_l2; + index_buffer[i] = snap_to_centroid_index(unit_val); + } + workgroupBarrier(); + + // Pack 8 indices per u32 word and write to output. + // Thread 0 writes the norm word (the f32 bit pattern stored as u32). + if (local_idx == 0u) { + if (!is_value) { + present_key.setByOffset(present_base, bitcast<u32>(l2_norm)); + } else { + present_value.setByOffset(present_base, bitcast<u32>(l2_norm)); + } + } + + // Each thread packs one or more u32 words (HEAD_SIZE/8 words total).
+ let num_packed_words = HEAD_SIZE >> 3u; + for (var w = local_idx; w < num_packed_words; w += workgroup_size_x) { + let base_elem = w << 3u; + var packed = 0u; + packed |= (index_buffer[base_elem + 0u] & 0xFu); + packed |= (index_buffer[base_elem + 1u] & 0xFu) << 4u; + packed |= (index_buffer[base_elem + 2u] & 0xFu) << 8u; + packed |= (index_buffer[base_elem + 3u] & 0xFu) << 12u; + packed |= (index_buffer[base_elem + 4u] & 0xFu) << 16u; + packed |= (index_buffer[base_elem + 5u] & 0xFu) << 20u; + packed |= (index_buffer[base_elem + 6u] & 0xFu) << 24u; + packed |= (index_buffer[base_elem + 7u] & 0xFu) << 28u; + if (!is_value) { + present_key.setByOffset(present_base + 1u + w, packed); + } else { + present_value.setByOffset(present_base + 1u + w, packed); + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc new file mode 100644 index 0000000000000..ab251d1a033fb --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.cc @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
+ +#include "contrib_ops/webgpu/bert/turbo_quant_hadamard.h" +#include "contrib_ops/webgpu/bert/hadamard_transform.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/common/logging/logging.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status TurboQuantHadamardProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& key = shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + const auto& value = shader.AddInput("value", ShaderUsage::UseUniform); + // present_key/present_value are u32 arrays (packed 4-bit quantized data) + const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform); + const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform); + + if (use_seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::None); + } + if (prepare_indirect_dispatch_) { + shader.AddOutput("indirect_buffer", ShaderUsage::None); + } + + // Past KV cache is already u32-packed — add as uniform only (no type aliases needed). + // The variable bindings are always passed to the template; when has_past_ is false the + // template never references them (guarded by #if has_past), so binding them to placeholder + // variables (key/value) is harmless and avoids passing a null variable pointer.
+ const ShaderVariableHelper* past_key = &key; + const ShaderVariableHelper* past_value = &value; + if (has_past_) { + past_key = &shader.AddInput("past_key", ShaderUsage::UseUniform); + past_value = &shader.AddInput("past_value", ShaderUsage::UseUniform); + } + + return WGSL_TEMPLATE_APPLY(shader, "bert/turbo_quant_hadamard.wgsl.template", + WGSL_TEMPLATE_PARAMETER(components, components_), + WGSL_TEMPLATE_PARAMETER(compressed_head_size_u32, compressed_head_size_u32_), + WGSL_TEMPLATE_PARAMETER(hadamard_size_log2, head_size_log2_), + WGSL_TEMPLATE_PARAMETER(has_past, has_past_), + WGSL_TEMPLATE_PARAMETER(kv_BNSH, kv_BNSH_), + WGSL_TEMPLATE_PARAMETER(past_present_share_buffer, past_present_share_buffer_), + WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), + WGSL_TEMPLATE_VARIABLE(key, key), + WGSL_TEMPLATE_VARIABLE(past_key, *past_key), + WGSL_TEMPLATE_VARIABLE(past_value, *past_value), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(value, value)); +} + +Status TurboQuantCopyToQuantizedKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, + const Tensor* K, const Tensor* past_key, Tensor* present_key, + const Tensor* V, const Tensor* past_value, Tensor* present_value, + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, + uint32_t num_q_tiles) { + const int head_size = parameters.head_size_; + const int components = head_size % 4 == 0 ? 4 : (head_size % 2 == 0 ? 2 : 1); + ORT_ENFORCE((head_size & (head_size - 1)) == 0 && head_size >= 8, + "head_size must be a power of 2 >= 8 for Hadamard transform, got ", head_size); + + const int head_size_log2 = Log2OfPowerOfTwo(head_size); + + // Compressed KV cache: 1 u32 for norm + head_size/8 u32s for packed 4-bit indices.
+ const int compressed_head_size_u32 = head_size / 8 + 1; + + bool has_past = !parameters.past_present_share_buffer_ && past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0; + int kv_num_heads = parameters.is_gqa_ ? parameters.kv_num_heads_ : parameters.num_heads_; + int copy_sequence_length = parameters.past_present_share_buffer_ ? parameters.kv_sequence_length_ : parameters.total_sequence_length_; + uint32_t num_slices_per_kv = static_cast<uint32_t>(parameters.batch_size_ * kv_num_heads * copy_sequence_length); + uint32_t total_workgroups = 2 * num_slices_per_kv; // K + V + + const uint32_t workgroup_size = std::min(static_cast<uint32_t>(head_size / 2), 64u); + + bool prepare_indirect_dispatch = (indirect_buffer != nullptr); + bool use_seqlen_k = (seqlen_k != nullptr); + ORT_RETURN_IF_ERROR( + (!use_seqlen_k || parameters.batch_size_ == 1) + ? Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "TurboQuant graph-capture decode path reads seqlen_k[0] for all batches and " + "currently supports batch_size == 1 only; got batch_size = ", + parameters.batch_size_)); + bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH; + + TurboQuantHadamardProgram program{"TurboQuantCopyToQuantizedKVCache", has_past, kv_BNSH, + parameters.past_present_share_buffer_, + head_size_log2, components, + compressed_head_size_u32, + prepare_indirect_dispatch, use_seqlen_k}; + // Inputs: K and V in their original format (fp16/fp32, vectorized). + if (kv_BNSH) { + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); + } else { + ORT_RETURN_IF_ERROR( + (parameters.qkv_format_ == Q_K_V_BSNH) + ? 
Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "qkv format ", parameters.qkv_format_, " is not supported yet.")); + TensorShape reshaped_KV_shape{parameters.batch_size_, parameters.kv_sequence_length_, kv_num_heads, head_size / components}; + program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}, + {V, ProgramTensorMetadataDependency::TypeAndRank, reshaped_KV_shape, components}}); + } + + if (use_seqlen_k) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); + } + + // Past KV cache is already u32-packed (no vectorization). + if (has_past) { + program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank}, + {past_value, ProgramTensorMetadataDependency::TypeAndRank}}); + } + + // Output: present KV cache as u32 (packed 4-bit quantized). + program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank}, + {present_value, ProgramTensorMetadataDependency::Rank}}); + + if (prepare_indirect_dispatch) { + program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None}); + } + + // present_key has shape (batch, kv_num_heads, present_seq_length, compressed_head_size_u32) + uint32_t present_seq_length = static_cast<uint32_t>(present_key->Shape()[2]); + + program.SetDispatchGroupSize(total_workgroups) + .SetWorkgroupSize(workgroup_size) + .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, + prepare_indirect_dispatch, use_seqlen_k, head_size_log2, components, compressed_head_size_u32) + .AddUniformVariables({{static_cast<uint32_t>(parameters.batch_size_)}, + {static_cast<uint32_t>(compressed_head_size_u32)}, + {static_cast<uint32_t>(kv_num_heads)}, + {static_cast<uint32_t>(parameters.kv_sequence_length_)}, + {static_cast<uint32_t>(parameters.num_heads_)}, + {num_q_tiles}, + {num_slices_per_kv}, + {present_seq_length}, + {tile_size}, + {static_cast<uint32_t>(parameters.total_sequence_length_)}}); + + return context.RunProgram(program); +} + +Status 
TurboQuantFusedRotaryProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& packed_qkv = shader.AddInput("packed_qkv", ShaderUsage::UseUniform); + const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform); + const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform); + + if (use_seqlen_k_) { + shader.AddInput("seqlen_k", ShaderUsage::None); + } + + const auto& query = shader.AddOutput("query", ShaderUsage::UseUniform); + // present_key/present_value are u32 arrays (packed 4-bit quantized data) + const auto& present_key = shader.AddOutput("present_key", ShaderUsage::UseUniform); + const auto& present_value = shader.AddOutput("present_value", ShaderUsage::UseUniform); + + if (prepare_indirect_dispatch_) { + shader.AddOutput("indirect_buffer", ShaderUsage::None); + } + + return WGSL_TEMPLATE_APPLY(shader, "bert/turbo_quant_fused_rotary_hadamard.wgsl.template", + WGSL_TEMPLATE_PARAMETER(compressed_head_size_u32, compressed_head_size_u32_), + WGSL_TEMPLATE_PARAMETER(hadamard_size_log2, head_size_log2_), + WGSL_TEMPLATE_PARAMETER(half_rotary_dim, half_rotary_dim_), + WGSL_TEMPLATE_PARAMETER(multi_rotary_cache_concat_offset, multi_rotary_cache_concat_offset_), + WGSL_TEMPLATE_PARAMETER(past_present_share_buffer, past_present_share_buffer_), + WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(use_multi_rotary_cache_concat, multi_rotary_cache_concat_offset_ > 0), + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_), + WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache), + WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(query, query), + WGSL_TEMPLATE_VARIABLE(sin_cache, sin_cache)); +} + +Status TurboQuantApplyRotaryAndCopyToQuantizedKVCache(onnxruntime::webgpu::ComputeContext& context, + const WebgpuAttentionParameters& parameters, + 
const Tensor* packedQKV, + const Tensor* seqlen_k, + const Tensor* cos_cache, + const Tensor* sin_cache, + Tensor* query, + Tensor* present_key, + Tensor* present_value, + Tensor* indirect_buffer, + uint32_t tile_size, + uint32_t num_q_tiles) { + const int head_size = parameters.head_size_; + ORT_ENFORCE((head_size & (head_size - 1)) == 0 && head_size >= 8, + "head_size must be a power of 2 >= 8 for TurboQuant fused rotary, got ", head_size); + + const int head_size_log2 = Log2OfPowerOfTwo(head_size); + + const int compressed_head_size_u32 = head_size / 8 + 1; + const int kv_num_heads = parameters.is_gqa_ ? parameters.kv_num_heads_ : parameters.num_heads_; + const int half_rotary_dim = static_cast<int>(cos_cache->Shape()[1]); + + // Dispatch: K slices + V slices + Q slices + uint32_t num_kv_slices = static_cast<uint32_t>(parameters.batch_size_ * kv_num_heads * parameters.kv_sequence_length_); + uint32_t num_q_slices = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_ * parameters.kv_sequence_length_); + uint32_t total_workgroups = 2 * num_kv_slices + num_q_slices; + + const uint32_t workgroup_size = std::min(static_cast<uint32_t>(head_size / 2), 64u); + + bool prepare_indirect_dispatch = (indirect_buffer != nullptr); + bool use_seqlen_k = (seqlen_k != nullptr); + ORT_RETURN_IF_ERROR( + (!use_seqlen_k || parameters.batch_size_ == 1) + ? 
Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "TurboQuant graph-capture decode path reads seqlen_k[0] for all batches and " + "currently supports batch_size == 1 only; got batch_size = ", + parameters.batch_size_)); + const uint32_t multi_rotary_cache_concat_offset = context.MultiRotaryCacheConcatOffset(); + + TurboQuantFusedRotaryProgram program{"TurboQuantFusedRotary", head_size_log2, + half_rotary_dim, + compressed_head_size_u32, + parameters.past_present_share_buffer_, + prepare_indirect_dispatch, use_seqlen_k, + multi_rotary_cache_concat_offset}; + + program.AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}); + program.AddInputs({ + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}, + }); + + if (use_seqlen_k) { + program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); + } + + program.AddOutputs({{query, ProgramTensorMetadataDependency::None}, + {present_key, ProgramTensorMetadataDependency::Rank}, + {present_value, ProgramTensorMetadataDependency::Rank}}); + + if (prepare_indirect_dispatch) { + program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None}); + } + + uint32_t present_seq_length = static_cast<uint32_t>(present_key->Shape()[2]); + + program.SetDispatchGroupSize(total_workgroups) + .SetWorkgroupSize(workgroup_size) + .CacheHint(parameters.past_present_share_buffer_, + prepare_indirect_dispatch, use_seqlen_k, head_size_log2, + half_rotary_dim, compressed_head_size_u32, multi_rotary_cache_concat_offset) + .AddUniformVariables({{static_cast<uint32_t>(parameters.batch_size_)}, + {static_cast<uint32_t>(compressed_head_size_u32)}, + {static_cast<uint32_t>(parameters.hidden_size_)}, + {static_cast<uint32_t>(parameters.kv_hidden_size_)}, + {static_cast<uint32_t>(kv_num_heads)}, + {static_cast<uint32_t>(parameters.kv_sequence_length_)}, + {static_cast<uint32_t>(parameters.num_heads_)}, + {num_kv_slices}, + {num_q_slices}, + {num_q_tiles}, + {present_seq_length}, + {tile_size}, + 
{static_cast<uint32_t>(parameters.total_sequence_length_)}}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h new file mode 100644 index 0000000000000..4bea6f8237094 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.h @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +// Avoid `using namespace` in headers. Pull in only what we need. +using onnxruntime::webgpu::Program; +using onnxruntime::webgpu::ProgramUniformVariableDataType; +using onnxruntime::webgpu::ShaderHelper; + +// Fused TurboQuant copy-to-KV-cache with Hadamard rotation and 4-bit quantization. +// Applies the Walsh-Hadamard transform to new K/V tokens, quantizes to 4-bit +// centroid indices packed into u32 words with fp32 L2 norm, then writes into +// the present KV cache (stored as u32). +// Each workgroup handles one (batch, head, seq) slice for either K or V.
+class TurboQuantHadamardProgram final : public Program<TurboQuantHadamardProgram> { + public: + TurboQuantHadamardProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, + bool past_present_share_buffer, int head_size_log2, int components, + int compressed_head_size_u32, + bool prepare_indirect_dispatch = false, bool use_seqlen_k = false) + : Program{kernel_name}, + has_past_(has_past), + kv_BNSH_(kv_BNSH), + past_present_share_buffer_(past_present_share_buffer), + head_size_log2_(head_size_log2), + components_(components), + compressed_head_size_u32_(compressed_head_size_u32), + prepare_indirect_dispatch_(prepare_indirect_dispatch), + use_seqlen_k_(use_seqlen_k) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, + {"compressed_head_size_u32", ProgramUniformVariableDataType::Uint32}, + {"kv_num_heads", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"num_q_tiles", ProgramUniformVariableDataType::Uint32}, + {"num_slices_per_kv", ProgramUniformVariableDataType::Uint32}, + {"present_seq_length", ProgramUniformVariableDataType::Uint32}, + {"tile_size", ProgramUniformVariableDataType::Uint32}, + {"total_sequence_length", ProgramUniformVariableDataType::Uint32}); + + private: + bool has_past_; + bool kv_BNSH_; + bool past_present_share_buffer_; + int head_size_log2_; + int components_; + int compressed_head_size_u32_; + bool prepare_indirect_dispatch_; + bool use_seqlen_k_; +}; + +// --------------------------------------------------------------------------- +// TurboQuant present-KV allocator contract (IMPORTANT for pre-allocated outputs) +// +// When TurboQuant is active the present_key/present_value buffers do NOT store +// fp16/fp32 head vectors. 
Each head is compressed to: +// compressed_u32_words = head_size / 8 + 1 +// u32 words (one fp32 L2-norm scale + head_size 4-bit indices packed 8-per-u32), +// i.e. (head_size * 4 + 32) bits per head. Expressed in the tensor's element +// type the last dimension is compressed_u32_words * (4 / sizeof(element)). +// +// ONNX shape inference is provider-agnostic and still reports the uncompressed +// head_size for these outputs, because TurboQuant is a WebGPU provider option +// resolved at runtime — there is no static graph metadata channel to advertise +// the compressed layout. Therefore any caller that PRE-ALLOCATES the present +// buffers (IO-binding, graph capture, ORT GenAI) MUST size the last dimension +// to the compressed length above, keyed off the same `turboQuant` provider +// option, NOT off the model-reported head_size. +// +// This formula is duplicated in the GenAI allocator +// (onnxruntime-genai/src/models/kv_cache.cpp, ComputeTurboQuantHeadSize) and +// must be kept in sync. As a safety net, ApplyFlashAttention validates the +// supplied present buffer's last-dim byte size and fails with INVALID_ARGUMENT +// on a mismatch rather than corrupting memory. +// --------------------------------------------------------------------------- +Status TurboQuantCopyToQuantizedKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, + const Tensor* K, const Tensor* past_key, Tensor* present_key, + const Tensor* V, const Tensor* past_value, Tensor* present_value, + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, + uint32_t num_q_tiles); + +// Fused TurboQuant: Split packed QKV + Rotary K + Hadamard + Quantize K/V + Rotary Q. +// Single dispatch handles all Q/K/V processing from packed QKV input. 
+class TurboQuantFusedRotaryProgram final : public Program<TurboQuantFusedRotaryProgram> { + public: + TurboQuantFusedRotaryProgram(const std::string& kernel_name, int head_size_log2, + int half_rotary_dim, + int compressed_head_size_u32, + bool past_present_share_buffer, + bool prepare_indirect_dispatch, bool use_seqlen_k, + uint32_t multi_rotary_cache_concat_offset) + : Program{kernel_name}, + head_size_log2_(head_size_log2), + half_rotary_dim_(half_rotary_dim), + compressed_head_size_u32_(compressed_head_size_u32), + past_present_share_buffer_(past_present_share_buffer), + prepare_indirect_dispatch_(prepare_indirect_dispatch), + use_seqlen_k_(use_seqlen_k), + multi_rotary_cache_concat_offset_(multi_rotary_cache_concat_offset) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"batch_size", ProgramUniformVariableDataType::Uint32}, + {"compressed_head_size_u32", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_num_heads", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"num_kv_slices", ProgramUniformVariableDataType::Uint32}, + {"num_q_slices", ProgramUniformVariableDataType::Uint32}, + {"num_q_tiles", ProgramUniformVariableDataType::Uint32}, + {"present_seq_length", ProgramUniformVariableDataType::Uint32}, + {"tile_size", ProgramUniformVariableDataType::Uint32}, + {"total_sequence_length", ProgramUniformVariableDataType::Uint32}); + + private: + int head_size_log2_; + int half_rotary_dim_; + int compressed_head_size_u32_; + bool past_present_share_buffer_; + bool prepare_indirect_dispatch_; + bool use_seqlen_k_; + uint32_t multi_rotary_cache_concat_offset_; +}; + +Status TurboQuantApplyRotaryAndCopyToQuantizedKVCache(onnxruntime::webgpu::ComputeContext& context, + const WebgpuAttentionParameters& 
parameters, + const Tensor* packedQKV, + const Tensor* seqlen_k, + const Tensor* cos_cache, + const Tensor* sin_cache, + Tensor* query, + Tensor* present_key, + Tensor* present_value, + Tensor* indirect_buffer, + uint32_t tile_size, + uint32_t num_q_tiles); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.wgsl.template new file mode 100644 index 0000000000000..81b5ff2997515 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/turbo_quant_hadamard.wgsl.template @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Fused TurboQuant copy-to-KV-cache with Walsh-Hadamard Transform and 4-bit quantization. +// Each workgroup handles one (batch, head, seq) slice for either K or V. +// Workgroup layout: [0, num_slices_per_kv) -> K, [num_slices_per_kv, 2*num_slices_per_kv) -> V. +// +// For new tokens: apply Walsh-Hadamard butterfly transform in shared memory, +// compute L2 norm, quantize each element to a 4-bit centroid index, pack 8 indices per u32, +// and store norm (as bitcast<u32>(f32)) followed by packed index words. +// For past tokens (has_past): simple u32-word copy from past to present.
+// +// Output layout per head: [norm_u32, packed_indices_0, packed_indices_1, ..., packed_indices_(HEAD_SIZE/8 - 1)] + +#param has_past +#param kv_BNSH +#param past_present_share_buffer +#param use_seqlen_k +#param prepare_indirect_dispatch +#param hadamard_size_log2 +#param components +#param compressed_head_size_u32 +#use .indicesToOffset .getByOffset .setByOffset + +const HEAD_SIZE : u32 = 1u << hadamard_size_log2; +const HEAD_SIZE_VEC : u32 = HEAD_SIZE / components; +const COMPRESSED_HEAD_U32 : u32 = compressed_head_size_u32; + +var hadamard_buffer : array; + +#include "bert/hadamard_transform_common.wgsl.template" +#include "bert/turbo_quant_common.wgsl.template" +#if prepare_indirect_dispatch +#include "bert/indirect_dispatch_common.wgsl.template" +#endif + +var scale_reduction_buffer : array; +// Reuse shared memory for packing: store centroid indices (0-15) as u32 per element. +var index_buffer : array; + +$MAIN { + // Map flat workgroup index to logical (K-or-V, batch, head, seq) coordinates. + let is_value = workgroup_idx >= uniforms.num_slices_per_kv; + let kv_slice = select(workgroup_idx, workgroup_idx - uniforms.num_slices_per_kv, is_value); + if (kv_slice >= uniforms.num_slices_per_kv) { return; } + + // Compute total_seq_length +#if use_seqlen_k + let total_seq_length = u32(seqlen_k[0u]) + 1u; +#else + let total_seq_length = uniforms.total_sequence_length; +#endif + + // uniforms.kv_sequence_length is the sequence length of the new key/values. + let past_seq_length = total_seq_length - uniforms.kv_sequence_length; + + // Unflatten kv_slice into (batch, head, seq). The seq extent depends on whether + // past tokens also need processing (copy) or just new tokens. 
+#if past_present_share_buffer + let copy_seq_length = uniforms.kv_sequence_length; +#elif has_past + let copy_seq_length = total_seq_length; +#else + let copy_seq_length = uniforms.kv_sequence_length; +#endif + let batch = kv_slice / (uniforms.kv_num_heads * copy_seq_length); + let head = (kv_slice / copy_seq_length) % uniforms.kv_num_heads; + let seq = kv_slice % copy_seq_length; + + // Prepare indirect dispatch buffer (first workgroup, first thread only). +#if prepare_indirect_dispatch + if (workgroup_idx == 0u && local_idx == 0u) { + let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; + populate_indirect_dispatch_buffer(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); + } +#endif + + // Compute destination offset in present_key/present_value (u32 packed, always BNSH layout). +#if past_present_share_buffer + let dest_seq = past_seq_length + seq; +#else + let dest_seq = seq; +#endif + let present_base = ((batch * uniforms.kv_num_heads + head) * uniforms.present_seq_length + dest_seq) * COMPRESSED_HEAD_U32; + + // Handle past tokens: simple u32 word copy (already quantized). +#if has_past + if (seq < past_seq_length) { + let past_base = ((batch * uniforms.kv_num_heads + head) * past_seq_length + seq) * COMPRESSED_HEAD_U32; + for (var i = local_idx; i < COMPRESSED_HEAD_U32; i += workgroup_size_x) { + if (!is_value) { + present_key.setByOffset(present_base + i, past_key.getByOffset(past_base + i)); + } else { + present_value.setByOffset(present_base + i, past_value.getByOffset(past_base + i)); + } + } + return; + } + let new_seq = seq - past_seq_length; +#else + let new_seq = seq; +#endif + + // Compute source offset in K/V for new tokens. +#if kv_BNSH + let src_base = key.indicesToOffset(key_indices_t(batch, head, new_seq, 0u)); +#else + let src_base = key.indicesToOffset(key_indices_t(batch, new_seq, head, 0u)); +#endif + + // Load from K or V into shared memory (f32). 
+ for (var i = local_idx; i < HEAD_SIZE_VEC; i += workgroup_size_x) { + var val : key_value_t; + if (!is_value) { + val = key.getByOffset(src_base + i); + } else { + val = value.getByOffset(src_base + i); + } +#if components == 4 + hadamard_buffer[i * 4u + 0u] = f32(val[0]); + hadamard_buffer[i * 4u + 1u] = f32(val[1]); + hadamard_buffer[i * 4u + 2u] = f32(val[2]); + hadamard_buffer[i * 4u + 3u] = f32(val[3]); +#elif components == 2 + hadamard_buffer[i * 2u + 0u] = f32(val[0]); + hadamard_buffer[i * 2u + 1u] = f32(val[1]); +#else + hadamard_buffer[i] = f32(val); +#endif + } + workgroupBarrier(); + + // Walsh-Hadamard Transform butterfly passes. + wht_butterfly(local_idx, workgroup_size_x); + + // WHT normalization. + let wht_norm = 1.0f / sqrt(f32(HEAD_SIZE)); + for (var i = local_idx; i < HEAD_SIZE; i += workgroup_size_x) { + hadamard_buffer[i] *= wht_norm; + } + workgroupBarrier(); + + // Compute L2 norm via parallel reduction. + var partial_sq_sum = 0.0f; + for (var i = local_idx; i < HEAD_SIZE; i += workgroup_size_x) { + partial_sq_sum += hadamard_buffer[i] * hadamard_buffer[i]; + } + scale_reduction_buffer[local_idx] = partial_sq_sum; + workgroupBarrier(); + // Tree reduction: each iteration halves active threads, accumulating the total + // sum of squares into scale_reduction_buffer[0] in log2(workgroup_size) steps. + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride >>= 1u) { + if (local_idx < stride) { + scale_reduction_buffer[local_idx] += scale_reduction_buffer[local_idx + stride]; + } + workgroupBarrier(); + } + let l2_norm = sqrt(scale_reduction_buffer[0]); + let inv_l2 = select(0.0f, 1.0f / l2_norm, l2_norm > 0.0f); + + // Quantize: compute centroid index for each element and store in index_buffer. 
+ for (var i = local_idx; i < HEAD_SIZE; i += workgroup_size_x) { + let unit_val = hadamard_buffer[i] * inv_l2; + index_buffer[i] = snap_to_centroid_index(unit_val); + } + workgroupBarrier(); + + // Pack 8 indices per u32 word and write to output. + // Thread 0 writes the norm word. + if (local_idx == 0u) { + if (!is_value) { + present_key.setByOffset(present_base, bitcast(l2_norm)); + } else { + present_value.setByOffset(present_base, bitcast(l2_norm)); + } + } + + // Each thread packs one or more u32 words (HEAD_SIZE/8 words total). + let num_packed_words = HEAD_SIZE >> 3u; + for (var w = local_idx; w < num_packed_words; w += workgroup_size_x) { + let base_elem = w << 3u; + var packed = 0u; + packed |= (index_buffer[base_elem + 0u] & 0xFu); + packed |= (index_buffer[base_elem + 1u] & 0xFu) << 4u; + packed |= (index_buffer[base_elem + 2u] & 0xFu) << 8u; + packed |= (index_buffer[base_elem + 3u] & 0xFu) << 12u; + packed |= (index_buffer[base_elem + 4u] & 0xFu) << 16u; + packed |= (index_buffer[base_elem + 5u] & 0xFu) << 20u; + packed |= (index_buffer[base_elem + 6u] & 0xFu) << 24u; + packed |= (index_buffer[base_elem + 7u] & 0xFu) << 28u; + if (!is_value) { + present_key.setByOffset(present_base + 1u + w, packed); + } else { + present_value.setByOffset(present_base + 1u + w, packed); + } + } +} diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 632e04a36c7bf..ae4b63ef293e1 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -103,6 +103,20 @@ class ComputeContextBase { return ep_.MultiRotaryCacheConcatOffset(); } + // + // Get the KV cache quantization bits (0 = disabled, 4 = 4-bit). + // + inline uint32_t KvCacheQuantizationBits() const { + return ep_.KvCacheQuantizationBits(); + } + + // + // Get whether KV cache quantization is enabled. 
+ // + inline bool KvCacheQuantizationEnabled() const { + return ep_.KvCacheQuantizationEnabled(); + } + // // Get the logger. // diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 0992dd62d73c3..600591d3e3d98 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -582,6 +582,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, // enable_int64_ is always true when enable_graph_capture_ is true enable_int64_{config.enable_graph_capture || config.enable_int64}, multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, + kv_cache_quantization_bits_{config.kv_cache_quantization_bits}, prepack_allocator_{std::make_shared( [this]() -> const webgpu::BufferManager& { return context_.InitializerBufferManager(); }, false)} { if (enable_graph_capture_ && config.session_buffer_pool_generations > 0) { diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 20fab4b3361a7..cf0c19e2e44f8 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -51,6 +51,7 @@ struct WebGpuExecutionProviderConfig { // across captured-graph lifetimes. 0 disables pooling. Default 1 caches one // generator's worth of intermediate buffers. 
size_t session_buffer_pool_generations{1}; + uint32_t kv_cache_quantization_bits{0}; // KV cache quantization bits (0 = off, 4 = 4-bit) std::vector force_cpu_node_names{}; }; @@ -112,6 +113,8 @@ class WebGpuExecutionProvider : public IExecutionProvider { AllocatorPtr PrepackAllocator() const { return prepack_allocator_; } std::span GetForceCpuNodeNames() const { return force_cpu_node_names_; } uint32_t MultiRotaryCacheConcatOffset() const { return multi_rotary_cache_concat_offset_; } + uint32_t KvCacheQuantizationBits() const { return kv_cache_quantization_bits_; } + bool KvCacheQuantizationEnabled() const { return kv_cache_quantization_bits_ != 0; } #if defined(ORT_USE_EP_API_ADAPTERS) inline onnxruntime::ep::adapter::Logger& GetEpLogger() const { @@ -135,6 +138,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool graph_buffer_mgr_active_ = false; bool enable_int64_ = false; uint32_t multi_rotary_cache_concat_offset_ = 0; + uint32_t kv_cache_quantization_bits_ = 0; std::unordered_map graph_id_to_run_count_; // Required regular runs before graph capture for any necessary allocations. 
const int min_num_runs_before_graph_capture_ = 0; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index ebb15c014e05d..e9a2343be1b8d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -98,6 +98,17 @@ WebGpuExecutionProviderConfig ParseEpConfig(const ConfigOptions& config_options) } } + std::string kv_cache_quantization_bits_str; + if (config_options.TryGetConfigEntry(kKvCacheQuantizationBits, kv_cache_quantization_bits_str)) { + if (kv_cache_quantization_bits_str == kKvCacheQuantizationBits_OFF) { + webgpu_ep_config.kv_cache_quantization_bits = 0; + } else if (kv_cache_quantization_bits_str == kKvCacheQuantizationBits_4Bit) { + webgpu_ep_config.kv_cache_quantization_bits = 4; + } else { + ORT_THROW("Invalid kvCacheQuantizationBits value: ", kv_cache_quantization_bits_str, ". Must be \"0\" or \"4\"."); + } + } + // parse force CPU node names // The force CPU node names are separated by EOL (\n or \r\n) in the config entry. // each line is a node name that will be forced to run on CPU. 
diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 3f6a49a4ab1e0..390d325d66ec5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -17,6 +17,7 @@ constexpr const char* kEnableGraphCapture = "ep.webgpuexecutionprovider.enableGr constexpr const char* kSessionBufferPoolGenerations = "ep.webgpuexecutionprovider.sessionBufferPoolGenerations"; constexpr const char* kEnableInt64 = "ep.webgpuexecutionprovider.enableInt64"; constexpr const char* kMultiRotaryCacheConcatOffset = "ep.webgpuexecutionprovider.multiRotaryCacheConcatOffset"; +constexpr const char* kKvCacheQuantizationBits = "ep.webgpuexecutionprovider.kvCacheQuantizationBits"; constexpr const char* kDawnProcTable = "ep.webgpuexecutionprovider.dawnProcTable"; @@ -67,6 +68,12 @@ constexpr const char* kEnablePIXCapture_OFF = "0"; constexpr const char* kPreserveDevice_ON = "1"; constexpr const char* kPreserveDevice_OFF = "0"; +// kKvCacheQuantizationBits value is the number of quantization bits as a string. +// "0" disables quantization; "4" enables 4-bit KV cache quantization. +// (Future: "8" for 8-bit.) 
+constexpr const char* kKvCacheQuantizationBits_OFF = "0"; +constexpr const char* kKvCacheQuantizationBits_4Bit = "4"; + constexpr const char* kBufferCacheMode_Disabled = "disabled"; constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; constexpr const char* kBufferCacheMode_Simple = "simple"; diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 28c41b5bf5ed4..c6bcd3ec447e1 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -13,6 +13,9 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#ifdef USE_WEBGPU +#include "core/providers/webgpu/webgpu_provider_options.h" +#endif namespace onnxruntime { namespace test { @@ -2872,5 +2875,778 @@ TEST(GroupQueryAttentionTest, BatchedRightPaddedRotaryPrefillNonFlashAttention_W RunBatchedRightPaddedRotaryPrefillForEP(GqaTargetEp::kWebGpu, {4, 2, 6}, /*smooth_softmax=*/true); } +#ifdef USE_WEBGPU +// --------------------------------------------------------------------------- +// TurboQuant KV cache quantization tests. +// Tests exercise the TQ4 code paths in GroupQueryAttention + FlashAttention. +// The helpers below reference webgpu::options::* constants, which are only +// available when USE_WEBGPU is defined; guard the whole section so non-WebGPU +// test builds (CPU/CUDA) still compile the rest of this file. +// --------------------------------------------------------------------------- + +// Helper: creates a WebGPU EP with TurboQuant 4-bit enabled. 
+static std::unique_ptr WebGpuEPWithTurboQuant4() { + ConfigOptions config_options{}; + ORT_THROW_IF_ERROR(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + webgpu::options::kBufferCacheMode_Disabled)); + ORT_THROW_IF_ERROR(config_options.AddConfigEntry(webgpu::options::kKvCacheQuantizationBits, + webgpu::options::kKvCacheQuantizationBits_4Bit)); + return WebGpuExecutionProviderWithOptions(config_options); +} + +// Helper to run a GQA op with TurboQuant enabled, using separate Q/K/V or packed QKV, with optional rotary. +// past_seq_len is the number of cached (past) tokens; sequence_length selects decode (1) vs prefill (>1). +// Returns the output tensor data on success, or an empty vector otherwise (expected failure / no WebGPU EP). +static std::vector RunGQATurboQuant( + int batch_size, + int sequence_length, + int past_seq_len, + int num_heads, + int kv_num_heads, + int head_size, + bool do_rotary, + bool is_packed_qkv, + OpTester::ExpectResult expect = OpTester::ExpectResult::kExpectSuccess, + const std::string& expected_error = "") { + const int hidden_size = num_heads * head_size; + const int kv_hidden_size = kv_num_heads * head_size; + const int total_sequence_length = past_seq_len + sequence_length; + + // TQ4 compressed KV head dim: (head_size * 4 + 32) / 32 for float32 + const int kv_head_dim = (head_size * 4 + 32) / 32; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + if (do_rotary) { + tester.AddAttribute("do_rotary", static_cast(1)); + } + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + if (is_packed_qkv) { + const int packed_dim = hidden_size + 2 * kv_hidden_size; + std::vector packed_data(batch_size * sequence_length * packed_dim); + for (auto& v : packed_data) v = dist(rng); + tester.AddInput("query", {batch_size, sequence_length, packed_dim}, packed_data); + tester.AddOptionalInputEdge(); // key + tester.AddOptionalInputEdge(); // value + } else { +
std::vector query_data(batch_size * sequence_length * hidden_size); + std::vector key_data(batch_size * sequence_length * kv_hidden_size); + std::vector value_data(batch_size * sequence_length * kv_hidden_size); + for (auto& v : query_data) v = dist(rng); + for (auto& v : key_data) v = dist(rng); + for (auto& v : value_data) v = dist(rng); + tester.AddInput("query", {batch_size, sequence_length, hidden_size}, query_data); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data); + } + + // Past KV in compressed TQ4 format (float payload whose raw bits are interpreted as u32-packed data). + const int past_kv_size = batch_size * kv_num_heads * past_seq_len * kv_head_dim; + std::vector past_key_data(past_kv_size); + std::vector past_value_data(past_kv_size); + for (auto& v : past_key_data) v = dist(rng); + for (auto& v : past_value_data) v = dist(rng); + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, kv_head_dim}, past_key_data); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, kv_head_dim}, past_value_data); + + std::vector tq_seqlens_k(batch_size, total_sequence_length - 1); + tester.AddInput("seqlens_k", {batch_size}, tq_seqlens_k); + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true); + + if (do_rotary) { + const int max_seq_len = total_sequence_length + 8; + const int half_rotary = head_size / 2; + std::vector cos_cache(max_seq_len * half_rotary); + std::vector sin_cache(max_seq_len * half_rotary); + for (int pos = 0; pos < max_seq_len; ++pos) { + for (int d = 0; d < half_rotary; ++d) { + float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / static_cast(head_size)); + cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq); + sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq); + } + } + tester.AddInput("cos_cache",
{max_seq_len, half_rotary}, cos_cache); + tester.AddInput("sin_cache", {max_seq_len, half_rotary}, sin_cache); + } else { + tester.AddOptionalInputEdge(); // cos_cache + tester.AddOptionalInputEdge(); // sin_cache + } + + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, 0.0f)); + const int present_seq_len = total_sequence_length; + const int present_size = batch_size * kv_num_heads * present_seq_len * kv_head_dim; + tester.AddOutput("present_key", {batch_size, kv_num_heads, present_seq_len, kv_head_dim}, + std::vector(present_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, present_seq_len, kv_head_dim}, + std::vector(present_size, 0.0f)); + + // TurboQuant present_key/present_value are u32-packed quantized data reinterpreted as float. + // Values can be astronomically large, so skip value checks via custom verifier.
+ tester.SetOutputTolerance(1e6f); + tester.SetCustomOutputVerifier([batch_size, sequence_length, hidden_size, kv_num_heads, present_seq_len, kv_head_dim]( + const std::vector& fetches, + const std::string& /*provider_type*/) { + ASSERT_EQ(fetches.size(), 3u); + ASSERT_TRUE(fetches[0].IsTensor()); + ASSERT_TRUE(fetches[1].IsTensor()); + ASSERT_TRUE(fetches[2].IsTensor()); + + const auto& out_tensor = fetches[0].Get(); + EXPECT_EQ(out_tensor.Shape().NumDimensions(), 3); + EXPECT_EQ(out_tensor.Shape()[0], batch_size); + EXPECT_EQ(out_tensor.Shape()[1], sequence_length); + EXPECT_EQ(out_tensor.Shape()[2], hidden_size); + + const auto& pk = fetches[1].Get(); + EXPECT_EQ(pk.Shape().NumDimensions(), 4); + EXPECT_EQ(pk.Shape()[0], batch_size); + EXPECT_EQ(pk.Shape()[1], kv_num_heads); + EXPECT_EQ(pk.Shape()[2], present_seq_len); + EXPECT_EQ(pk.Shape()[3], kv_head_dim); + + const auto& pv = fetches[2].Get(); + EXPECT_EQ(pv.Shape().NumDimensions(), 4); + EXPECT_EQ(pv.Shape()[0], batch_size); + EXPECT_EQ(pv.Shape()[1], kv_num_heads); + EXPECT_EQ(pv.Shape()[2], present_seq_len); + EXPECT_EQ(pv.Shape()[3], kv_head_dim); + }); + + std::vector> execution_providers; + auto ep = WebGpuEPWithTurboQuant4(); + if (!ep) { + return {};  // GTEST_SKIP() expands to `return <void>` and is only valid in void functions; every caller skips before calling. + } + execution_providers.push_back(std::move(ep)); + tester.Run(expect, expected_error, {}, nullptr, &execution_providers); + + if (expect == OpTester::ExpectResult::kExpectSuccess) { + auto fetches = tester.GetFetches(); + const float* out_data = fetches[0].Get().Data(); + return std::vector(out_data, out_data + output_size); + } + return {}; +} + +// --- Error path: TurboQuant with smooth_softmax (non-flash attention) --- +TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_RejectsNonFlashAttention) { + auto ep = WebGpuEPWithTurboQuant4(); + if (!ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + + constexpr int batch_size = 1; + constexpr int sequence_length = 1; + constexpr int past_seq_len = 8; + constexpr int
num_heads = 2; + constexpr int kv_num_heads = 1; + constexpr int head_size = 128; + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + constexpr int kv_head_dim = (head_size * 4 + 32) / 32; + constexpr int total_sequence_length = past_seq_len + sequence_length; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + tester.AddAttribute("smooth_softmax", static_cast(1)); // Forces non-flash path + + std::vector query_data(batch_size * sequence_length * hidden_size, 0.1f); + std::vector key_data(batch_size * sequence_length * kv_hidden_size, 0.1f); + std::vector value_data(batch_size * sequence_length * kv_hidden_size, 0.1f); + tester.AddInput("query", {batch_size, sequence_length, hidden_size}, query_data); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data); + + std::vector past_key_data(batch_size * kv_num_heads * past_seq_len * kv_head_dim, 0.0f); + std::vector past_value_data(batch_size * kv_num_heads * past_seq_len * kv_head_dim, 0.0f); + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, kv_head_dim}, past_key_data); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, kv_head_dim}, past_value_data); + + tester.AddInput("seqlens_k", {batch_size}, {total_sequence_length - 1}); + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true); + tester.AddOptionalInputEdge(); // cos_cache + tester.AddOptionalInputEdge(); // sin_cache + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", 
{batch_size, sequence_length, hidden_size}, + std::vector(output_size, 0.0f)); + const int present_size = batch_size * kv_num_heads * total_sequence_length * kv_head_dim; + tester.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, kv_head_dim}, + std::vector(present_size, 0.0f)); + tester.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, kv_head_dim}, + std::vector(present_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(std::move(ep)); + tester.Run(OpTester::ExpectResult::kExpectFailure, + "KV cache quantization requires flash attention", + {}, nullptr, &execution_providers); +} + +// --- Error path: TurboQuant with invalid head_size (not power of 2) --- +TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_RejectsNonPowerOf2HeadSize) { + auto ep = WebGpuEPWithTurboQuant4(); + if (!ep) { + GTEST_SKIP() << "WebGPU EP not available"; + } + + constexpr int batch_size = 1; + constexpr int sequence_length = 1; + constexpr int past_seq_len = 8; + constexpr int num_heads = 2; + constexpr int kv_num_heads = 1; + constexpr int head_size = 96; // Not a power of 2 + constexpr int hidden_size = num_heads * head_size; + constexpr int kv_hidden_size = kv_num_heads * head_size; + // Even with invalid head_size, we still need a valid past shape for the test to reach the check. + // Use the would-be compressed dim (though it won't actually be used since the op errors out). 
+ constexpr int kv_head_dim = (head_size * 4 + 32) / 32; + constexpr int total_sequence_length = past_seq_len + sequence_length; + + OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + + std::vector query_data(batch_size * sequence_length * hidden_size, 0.1f); + std::vector key_data(batch_size * sequence_length * kv_hidden_size, 0.1f); + std::vector value_data(batch_size * sequence_length * kv_hidden_size, 0.1f); + tester.AddInput("query", {batch_size, sequence_length, hidden_size}, query_data); + tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data); + tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data); + + std::vector past_key_data(batch_size * kv_num_heads * past_seq_len * kv_head_dim, 0.0f); + std::vector past_value_data(batch_size * kv_num_heads * past_seq_len * kv_head_dim, 0.0f); + tester.AddInput("past_key", {batch_size, kv_num_heads, past_seq_len, kv_head_dim}, past_key_data); + tester.AddInput("past_value", {batch_size, kv_num_heads, past_seq_len, kv_head_dim}, past_value_data); + + tester.AddInput("seqlens_k", {batch_size}, {total_sequence_length - 1}); + tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true); + tester.AddOptionalInputEdge(); // cos_cache + tester.AddOptionalInputEdge(); // sin_cache + tester.AddOptionalInputEdge(); // position_ids + tester.AddOptionalInputEdge(); // attention_bias + tester.AddOptionalInputEdge(); // head_sink + + const int output_size = batch_size * sequence_length * hidden_size; + tester.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(output_size, 0.0f)); + const int present_size = batch_size * kv_num_heads * total_sequence_length * kv_head_dim; + tester.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, kv_head_dim}, + 
std::vector(present_size, 0.0f));
+  tester.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, kv_head_dim},
+                   std::vector(present_size, 0.0f));
+
+  std::vector> execution_providers;
+  execution_providers.push_back(std::move(ep));
+  tester.Run(OpTester::ExpectResult::kExpectFailure,
+             "KV cache quantization requires head_size >= 8 and a power of 2",
+             {}, nullptr, &execution_providers);
+}
+
+// --- Success paths: TurboQuant with flash attention at various K sizes ---
+// K=1 (decode with minimal past), K=24 (moderate), K=128 (large)
+// These exercise the split-reduce decode path (QKV + VxReduce kernels) for seq_len=1,
+// and the prefill path (single FlashAttentionProgram kernel) for seq_len>1.
+//
+// Every test in this group has the same shape: skip when WebGpuEPWithTurboQuant4() returns a
+// null EP, call RunGQATurboQuant with the listed configuration, and assert only that the
+// returned output is not entirely zero. These are smoke checks; none of them compares the
+// output against reference values.
+
+// Decode (sequence_length=1) with separate K/V, no rotary. past_seq_len controls k_size.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Decode_K1) {
+  // ep is created only to probe availability; it is not passed to RunGQATurboQuant.
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/1, /*past_seq_len=*/1,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/false, /*is_packed_qkv=*/false);
+  // Exact comparison with 0.0f: a single non-zero element is enough to pass.
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant decode K=1 output should not be all zeros";
+}
+
+// Same as Decode_K1 with past_seq_len=24.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Decode_K24) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/1, /*past_seq_len=*/24,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/false, /*is_packed_qkv=*/false);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant decode K=24 output should not be all zeros";
+}
+
+// Same as Decode_K1 with past_seq_len=128.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Decode_K128) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/1, /*past_seq_len=*/128,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/false, /*is_packed_qkv=*/false);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant decode K=128 output should not be all zeros";
+}
+
+// Prefill (sequence_length > 1) with separate K/V, no rotary.
+// The three prefill tests use sequence_length=4 and differ only in past_seq_len (1, 24, 128).
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Prefill_K1) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/4, /*past_seq_len=*/1,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/false, /*is_packed_qkv=*/false);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant prefill K=1 output should not be all zeros";
+}
+
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Prefill_K24) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/4, /*past_seq_len=*/24,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/false, /*is_packed_qkv=*/false);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant prefill K=24 output should not be all zeros";
+}
+
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Prefill_K128) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/4, /*past_seq_len=*/128,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/false, /*is_packed_qkv=*/false);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant prefill K=128 output should not be all zeros";
+}
+
+// Decode with rotary embedding (separate K/V path).
+// Differs from Decode_K24 only in do_rotary=true.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Decode_Rotary_K24) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/1, /*past_seq_len=*/24,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/true, /*is_packed_qkv=*/false);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant decode rotary K=24 output should not be all zeros";
+}
+
+// Decode with packed QKV + rotary (fused split+rotary+Hadamard+quantize path).
+// Differs from Decode_Rotary_K24 only in is_packed_qkv=true.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Decode_PackedRotary_K24) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/1, /*past_seq_len=*/24,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/true, /*is_packed_qkv=*/true);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant decode packed+rotary K=24 output should not be all zeros";
+}
+
+// Prefill with packed QKV + rotary (fused path, sequence_length > 1).
+// Smoke check: skips when the TurboQuant WebGPU EP is unavailable, then asserts only that the
+// output returned for sequence_length=4 / past_seq_len=24 with do_rotary and is_packed_qkv both
+// enabled is not entirely zero. No reference values are compared.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Prefill_PackedRotary_K24) {
+  // ep is created only to probe availability; it is not passed to RunGQATurboQuant.
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  auto output = RunGQATurboQuant(/*batch_size=*/1, /*sequence_length=*/4, /*past_seq_len=*/24,
+                                 /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                                 /*do_rotary=*/true, /*is_packed_qkv=*/true);
+  bool all_zero = std::all_of(output.begin(), output.end(), [](float v) { return v == 0.0f; });
+  EXPECT_FALSE(all_zero) << "TurboQuant prefill packed+rotary K=24 output should not be all zeros";
+}
+
+// --- Error path: multi-batch with per-batch seqlens_k is rejected with TurboQuant ---
+// WebGPU flash attention does not implement right-padded per-batch prefill, so
+// CanApplyFlashAttention() returns false whenever batch_size > 1 and per-batch
+// seqlens_k are supplied with a non-empty KV cache. KV cache quantization requires
+// flash attention, so this configuration is rejected rather than silently falling
+// back to the (uncompressed-only) non-flash path. genai decode runs batch_size==1,
+// so this is not a supported production path.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_RejectsMultiBatch) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+  // The returned vector is ignored. The two trailing arguments are the expected result and the
+  // expected error text (presumably forwarded to OpTester::Run by the helper), so the test
+  // passes only when the batch_size=2 run fails with that message.
+  RunGQATurboQuant(/*batch_size=*/2, /*sequence_length=*/1, /*past_seq_len=*/24,
+                   /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128,
+                   /*do_rotary=*/false, /*is_packed_qkv=*/false,
+                   OpTester::ExpectResult::kExpectFailure,
+                   "KV cache quantization requires flash attention");
+}
+
+// ---------------------------------------------------------------------------
+// TurboQuant cross-validation tests: compare TQ output vs non-TQ reference.
+// With past_seq_len=0 (no past KV), both versions receive identical Q/K/V.
+// The Hadamard transform is orthogonal (preserves dot products), so TQ output
+// should approximate non-TQ output within 4-bit quantization error bounds.
+// ---------------------------------------------------------------------------
+
+// Helper: runs GQA without TurboQuant (standard uncompressed KV cache) and returns the output.
+//
+// Builds a GroupQueryAttention OpTester with:
+//   - query shaped {batch_size, sequence_length, num_heads * head_size},
+//   - key/value shaped {batch_size, sequence_length, kv_num_heads * head_size},
+//   - no past_key/past_value, so total_sequence_length == sequence_length,
+//   - cos/sin caches only when do_rotary is true (the do_rotary attribute is set in that case),
+//   - position_ids, attention_bias and head_sink left as unset optional edges.
+// The run uses DefaultWebGpuExecutionProvider() and is expected to succeed.
+//
+// Returns the first fetched output copied into a vector of
+// batch_size * sequence_length * num_heads * head_size floats.
+static std::vector RunGQAReference(
+    int batch_size,
+    int sequence_length,
+    int num_heads,
+    int kv_num_heads,
+    int head_size,
+    const std::vector& query_data,
+    const std::vector& key_data,
+    const std::vector& value_data,
+    bool do_rotary) {
+  const int hidden_size = num_heads * head_size;
+  const int kv_hidden_size = kv_num_heads * head_size;
+  const int total_sequence_length = sequence_length;  // no past
+
+  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute("num_heads", static_cast(num_heads));
+  tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads));
+  if (do_rotary) {
+    tester.AddAttribute("do_rotary", static_cast(1));
+  }
+
+  tester.AddInput("query", {batch_size, sequence_length, hidden_size}, query_data);
+  tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data);
+  tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data);
+
+  tester.AddOptionalInputEdge();  // past_key
+  tester.AddOptionalInputEdge();  // past_value
+
+  // One seqlens_k entry per batch, each set to total_sequence_length - 1.
+  std::vector seqlens_k(batch_size, total_sequence_length - 1);
+  tester.AddInput("seqlens_k", {batch_size}, seqlens_k);
+  tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true);
+
+  if (do_rotary) {
+    // Cache tables have total_sequence_length + 8 rows and head_size / 2 columns.
+    // Entry (pos, d) holds cos / sin of pos * 10000^(-2d / head_size).
+    const int max_seq_len = total_sequence_length + 8;
+    const int half_rotary = head_size / 2;
+    std::vector cos_cache(max_seq_len * half_rotary);
+    std::vector sin_cache(max_seq_len * half_rotary);
+    for (int pos = 0; pos < max_seq_len; ++pos) {
+      for (int d = 0; d < half_rotary; ++d) {
+        float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / static_cast(head_size));
+        cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq);
+        sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq);
+      }
+    }
+    tester.AddInput("cos_cache", {max_seq_len, half_rotary}, cos_cache);
+    tester.AddInput("sin_cache", {max_seq_len, half_rotary}, sin_cache);
+  } else {
+    tester.AddOptionalInputEdge();  // cos_cache
+    tester.AddOptionalInputEdge();  // sin_cache
+  }
+
+  tester.AddOptionalInputEdge();  // position_ids
+  tester.AddOptionalInputEdge();  // attention_bias
+  tester.AddOptionalInputEdge();  // head_sink
+
+  // Expected outputs are zero-filled placeholders: they declare the output shapes only.
+  const int output_size = batch_size * sequence_length * hidden_size;
+  tester.AddOutput("output", {batch_size, sequence_length, hidden_size},
+                   std::vector(output_size, 0.0f));
+  const int present_size = batch_size * kv_num_heads * total_sequence_length * head_size;
+  tester.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, head_size},
+                   std::vector(present_size, 0.0f));
+  tester.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, head_size},
+                   std::vector(present_size, 0.0f));
+
+  // Very loose tolerance plus a no-op custom verifier: presumably so the zero placeholders above
+  // are never compared against the real results. The caller inspects the fetched output instead.
+  tester.SetOutputTolerance(1e6f);
+  tester.SetCustomOutputVerifier([](const std::vector&, const std::string&) {});
+
+  std::vector> execution_providers;
+  execution_providers.push_back(DefaultWebGpuExecutionProvider());
+  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+
+  // fetches[0] is taken as the "output" tensor (presumably because it is the first output
+  // declared above); output_size floats are copied out of it.
+  auto fetches = tester.GetFetches();
+  const float* out_data = fetches[0].Get().Data();
+  return std::vector(out_data, out_data + output_size);
+}
+
+// Helper: runs GQA with TurboQuant4, past_seq_len=0, returns the output.
+// Takes the same arguments as RunGQAReference and builds the same query/key/value, seqlens_k,
+// total_sequence_length and optional rotary-cache inputs. It differs in three places:
+//   - past_key/past_value are supplied as zero-length tensors whose last dimension is kv_head_dim,
+//   - present_key/present_value are declared with kv_head_dim as their last dimension,
+//   - the run uses the EP returned by WebGpuEPWithTurboQuant4().
+// Returns the first fetched output copied into a vector of
+// batch_size * sequence_length * num_heads * head_size floats.
+static std::vector RunGQATurboQuantNoPast(
+    int batch_size,
+    int sequence_length,
+    int num_heads,
+    int kv_num_heads,
+    int head_size,
+    const std::vector& query_data,
+    const std::vector& key_data,
+    const std::vector& value_data,
+    bool do_rotary) {
+  const int hidden_size = num_heads * head_size;
+  const int kv_hidden_size = kv_num_heads * head_size;
+  const int total_sequence_length = sequence_length;  // no past
+  // For head_size = 128 this evaluates to (512 + 32) / 32 = 17.
+  // NOTE(review): presumably 4 bits per element plus 32 extra bits, counted in 32-bit units;
+  // confirm against the KV cache shape check in group_query_attention_helper.h.
+  const int kv_head_dim = (head_size * 4 + 32) / 32;
+
+  OpTester tester("GroupQueryAttention", 1, onnxruntime::kMSDomain);
+  tester.AddAttribute("num_heads", static_cast(num_heads));
+  tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads));
+  if (do_rotary) {
+    tester.AddAttribute("do_rotary", static_cast(1));
+  }
+
+  tester.AddInput("query", {batch_size, sequence_length, hidden_size}, query_data);
+  tester.AddInput("key", {batch_size, sequence_length, kv_hidden_size}, key_data);
+  tester.AddInput("value", {batch_size, sequence_length, kv_hidden_size}, value_data);
+
+  // Empty past with compressed head dim so shape inference derives correct present shape.
+  tester.AddInput("past_key", {batch_size, kv_num_heads, static_cast(0), kv_head_dim}, {});
+  tester.AddInput("past_value", {batch_size, kv_num_heads, static_cast(0), kv_head_dim}, {});
+
+  // One seqlens_k entry per batch, each set to total_sequence_length - 1.
+  std::vector seqlens_k(batch_size, total_sequence_length - 1);
+  tester.AddInput("seqlens_k", {batch_size}, seqlens_k);
+  tester.AddInput("total_sequence_length", {1}, {total_sequence_length}, /*is_initializer=*/true);
+
+  if (do_rotary) {
+    // Cache tables have total_sequence_length + 8 rows and head_size / 2 columns.
+    // Entry (pos, d) holds cos / sin of pos * 10000^(-2d / head_size).
+    const int max_seq_len = total_sequence_length + 8;
+    const int half_rotary = head_size / 2;
+    std::vector cos_cache(max_seq_len * half_rotary);
+    std::vector sin_cache(max_seq_len * half_rotary);
+    for (int pos = 0; pos < max_seq_len; ++pos) {
+      for (int d = 0; d < half_rotary; ++d) {
+        float freq = 1.0f / std::pow(10000.0f, 2.0f * static_cast(d) / static_cast(head_size));
+        cos_cache[pos * half_rotary + d] = std::cos(static_cast(pos) * freq);
+        sin_cache[pos * half_rotary + d] = std::sin(static_cast(pos) * freq);
+      }
+    }
+    tester.AddInput("cos_cache", {max_seq_len, half_rotary}, cos_cache);
+    tester.AddInput("sin_cache", {max_seq_len, half_rotary}, sin_cache);
+  } else {
+    tester.AddOptionalInputEdge();  // cos_cache
+    tester.AddOptionalInputEdge();  // sin_cache
+  }
+
+  tester.AddOptionalInputEdge();  // position_ids
+  tester.AddOptionalInputEdge();  // attention_bias
+  tester.AddOptionalInputEdge();  // head_sink
+
+  // Expected outputs are zero-filled placeholders: they declare the output shapes only.
+  const int output_size = batch_size * sequence_length * hidden_size;
+  tester.AddOutput("output", {batch_size, sequence_length, hidden_size},
+                   std::vector(output_size, 0.0f));
+  const int present_size = batch_size * kv_num_heads * total_sequence_length * kv_head_dim;
+  tester.AddOutput("present_key", {batch_size, kv_num_heads, total_sequence_length, kv_head_dim},
+                   std::vector(present_size, 0.0f));
+  tester.AddOutput("present_value", {batch_size, kv_num_heads, total_sequence_length, kv_head_dim},
+                   std::vector(present_size, 0.0f));
+
+  // Very loose tolerance plus a no-op custom verifier: presumably so the zero placeholders above
+  // are never compared against the real results. The caller inspects the fetched output instead.
+  tester.SetOutputTolerance(1e6f);
+  tester.SetCustomOutputVerifier([](const std::vector&, const std::string&) {});
+
+  // NOTE(review): the EP is pushed without a null check; every caller in this file checks
+  // WebGpuEPWithTurboQuant4() and skips before calling this helper.
+  std::vector> execution_providers;
+  execution_providers.push_back(WebGpuEPWithTurboQuant4());
+  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+
+  // fetches[0] is taken as the "output" tensor (presumably because it is the first output
+  // declared above); output_size floats are copied out of it.
+  auto fetches = tester.GetFetches();
+  const float* out_data = fetches[0].Get().Data();
+  return std::vector(out_data, out_data + output_size);
+}
+
+// Cross-validate TQ vs non-TQ: Prefill with 4 tokens, no past, no rotary.
+// With 4-bit quantization (16 centroids), expect bounded error.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_CrossValidate_Prefill_NoRotary) {
+  // ep is created only to probe availability; the helpers below create their own EPs.
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+
+  constexpr int batch_size = 1;
+  constexpr int sequence_length = 4;
+  constexpr int num_heads = 2;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 128;
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+
+  // Deterministic data
+  // Fixed seed; query, key and value are drawn in that order from one generator.
+  std::mt19937 rng(123);
+  std::uniform_real_distribution dist(-0.5f, 0.5f);
+  std::vector query_data(batch_size * sequence_length * hidden_size);
+  std::vector key_data(batch_size * sequence_length * kv_hidden_size);
+  std::vector value_data(batch_size * sequence_length * kv_hidden_size);
+  for (auto& v : query_data) v = dist(rng);
+  for (auto& v : key_data) v = dist(rng);
+  for (auto& v : value_data) v = dist(rng);
+
+  // Both helpers receive the same Q/K/V buffers.
+  auto ref_output = RunGQAReference(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                    query_data, key_data, value_data, /*do_rotary=*/false);
+  auto tq_output = RunGQATurboQuantNoPast(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                          query_data, key_data, value_data, /*do_rotary=*/false);
+
+  ASSERT_EQ(ref_output.size(), tq_output.size());
+  // Single pass: track max |ref - tq| and accumulate squared error and squared reference.
+  float max_abs_err = 0.0f;
+  float sum_sq_err = 0.0f;
+  float sum_sq_ref = 0.0f;
+  for (size_t i = 0; i < ref_output.size(); i++) {
+    float err = std::abs(ref_output[i] - tq_output[i]);
+    max_abs_err = std::max(max_abs_err, err);
+    sum_sq_err += (ref_output[i] - tq_output[i]) * (ref_output[i] - tq_output[i]);
+    sum_sq_ref += ref_output[i] * ref_output[i];
+  }
+  // relative_rmse = sqrt(sum_sq_err / sum_sq_ref); the plain RMSE is used only as the fallback
+  // when sum_sq_ref is not positive.
+  float rmse = std::sqrt(sum_sq_err / static_cast(ref_output.size()));
+  float relative_rmse = (sum_sq_ref > 0) ? std::sqrt(sum_sq_err / sum_sq_ref) : rmse;
+
+  // 4-bit quantization with 16 MSE-optimal centroids: expect relative RMSE < 20%
+  // and max absolute error bounded (values are in [-0.5, 0.5] range).
+  EXPECT_LT(relative_rmse, 0.2f) << "TurboQuant relative RMSE too large: " << relative_rmse;
+  EXPECT_LT(max_abs_err, 0.3f) << "TurboQuant max absolute error too large: " << max_abs_err;
+}
+
+// Cross-validate TQ vs non-TQ: Prefill with rotary embedding.
+// Same configuration and thresholds as the no-rotary prefill test, with seed 456 and
+// do_rotary=true passed to both helpers.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_CrossValidate_Prefill_Rotary) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+
+  constexpr int batch_size = 1;
+  constexpr int sequence_length = 4;
+  constexpr int num_heads = 2;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 128;
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+
+  std::mt19937 rng(456);
+  std::uniform_real_distribution dist(-0.5f, 0.5f);
+  std::vector query_data(batch_size * sequence_length * hidden_size);
+  std::vector key_data(batch_size * sequence_length * kv_hidden_size);
+  std::vector value_data(batch_size * sequence_length * kv_hidden_size);
+  for (auto& v : query_data) v = dist(rng);
+  for (auto& v : key_data) v = dist(rng);
+  for (auto& v : value_data) v = dist(rng);
+
+  auto ref_output = RunGQAReference(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                    query_data, key_data, value_data, /*do_rotary=*/true);
+  auto tq_output = RunGQATurboQuantNoPast(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                          query_data, key_data, value_data, /*do_rotary=*/true);
+
+  ASSERT_EQ(ref_output.size(), tq_output.size());
+  float max_abs_err = 0.0f;
+  float sum_sq_err = 0.0f;
+  float sum_sq_ref = 0.0f;
+  for (size_t i = 0; i < ref_output.size(); i++) {
+    float err = std::abs(ref_output[i] - tq_output[i]);
+    max_abs_err = std::max(max_abs_err, err);
+    sum_sq_err += (ref_output[i] - tq_output[i]) * (ref_output[i] - tq_output[i]);
+    sum_sq_ref += ref_output[i] * ref_output[i];
+  }
+  // Plain RMSE is only the fallback for a reference whose squared sum is not positive.
+  float rmse = std::sqrt(sum_sq_err / static_cast(ref_output.size()));
+  float relative_rmse = (sum_sq_ref > 0) ? std::sqrt(sum_sq_err / sum_sq_ref) : rmse;
+
+  EXPECT_LT(relative_rmse, 0.2f) << "TurboQuant+rotary relative RMSE too large: " << relative_rmse;
+  EXPECT_LT(max_abs_err, 0.3f) << "TurboQuant+rotary max absolute error too large: " << max_abs_err;
+}
+
+// Cross-validate: single decode token (sequence_length=1, past_seq_len=0).
+// This exercises the split-reduce decode kernel path.
+// Uses seed 789 and tighter thresholds (0.15 relative RMSE, 0.25 max absolute error).
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_CrossValidate_Decode) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+
+  constexpr int batch_size = 1;
+  constexpr int sequence_length = 1;
+  constexpr int num_heads = 2;
+  constexpr int kv_num_heads = 1;
+  constexpr int head_size = 128;
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+
+  std::mt19937 rng(789);
+  std::uniform_real_distribution dist(-0.5f, 0.5f);
+  std::vector query_data(batch_size * sequence_length * hidden_size);
+  std::vector key_data(batch_size * sequence_length * kv_hidden_size);
+  std::vector value_data(batch_size * sequence_length * kv_hidden_size);
+  for (auto& v : query_data) v = dist(rng);
+  for (auto& v : key_data) v = dist(rng);
+  for (auto& v : value_data) v = dist(rng);
+
+  auto ref_output = RunGQAReference(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                    query_data, key_data, value_data, /*do_rotary=*/false);
+  auto tq_output = RunGQATurboQuantNoPast(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                          query_data, key_data, value_data, /*do_rotary=*/false);
+
+  ASSERT_EQ(ref_output.size(), tq_output.size());
+  float max_abs_err = 0.0f;
+  float sum_sq_err = 0.0f;
+  float sum_sq_ref = 0.0f;
+  for (size_t i = 0; i < ref_output.size(); i++) {
+    float err = std::abs(ref_output[i] - tq_output[i]);
+    max_abs_err = std::max(max_abs_err, err);
+    sum_sq_err += (ref_output[i] - tq_output[i]) * (ref_output[i] - tq_output[i]);
+    sum_sq_ref += ref_output[i] * ref_output[i];
+  }
+  // Fallback here is sqrt(sum_sq_err), not the per-element RMSE used by the prefill tests above.
+  float relative_rmse = (sum_sq_ref > 0) ? std::sqrt(sum_sq_err / sum_sq_ref) : std::sqrt(sum_sq_err);
+
+  // Single token: attention is just softmax(Q*K^T/sqrt(d)) * V with a single KV pair.
+  // Quantization error on one pair should be small.
+  EXPECT_LT(relative_rmse, 0.15f) << "TurboQuant decode relative RMSE too large: " << relative_rmse;
+  EXPECT_LT(max_abs_err, 0.25f) << "TurboQuant decode max absolute error too large: " << max_abs_err;
+}
+
+// Cross-validate: longer prefill (8 tokens) with multiple KV heads.
+// Uses num_heads=4 and kv_num_heads=2 with seed 1001; thresholds match the 4-token prefill tests.
+TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_CrossValidate_Prefill8_MultiKVHead) {
+  auto ep = WebGpuEPWithTurboQuant4();
+  if (!ep) {
+    GTEST_SKIP() << "WebGPU EP not available";
+  }
+
+  constexpr int batch_size = 1;
+  constexpr int sequence_length = 8;
+  constexpr int num_heads = 4;
+  constexpr int kv_num_heads = 2;
+  constexpr int head_size = 128;
+  constexpr int hidden_size = num_heads * head_size;
+  constexpr int kv_hidden_size = kv_num_heads * head_size;
+
+  std::mt19937 rng(1001);
+  std::uniform_real_distribution dist(-0.5f, 0.5f);
+  std::vector query_data(batch_size * sequence_length * hidden_size);
+  std::vector key_data(batch_size * sequence_length * kv_hidden_size);
+  std::vector value_data(batch_size * sequence_length * kv_hidden_size);
+  for (auto& v : query_data) v = dist(rng);
+  for (auto& v : key_data) v = dist(rng);
+  for (auto& v : value_data) v = dist(rng);
+
+  auto ref_output = RunGQAReference(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                    query_data, key_data, value_data, /*do_rotary=*/false);
+  auto tq_output = RunGQATurboQuantNoPast(batch_size, sequence_length, num_heads, kv_num_heads, head_size,
+                                          query_data, key_data, value_data, /*do_rotary=*/false);
+
+  ASSERT_EQ(ref_output.size(), tq_output.size());
+  float max_abs_err = 0.0f;
+  float sum_sq_err = 0.0f;
+  float sum_sq_ref = 0.0f;
+  for (size_t i = 0; i < ref_output.size(); i++) {
+    float err = std::abs(ref_output[i] - tq_output[i]);
+    max_abs_err = std::max(max_abs_err, err);
+    sum_sq_err += (ref_output[i] - tq_output[i]) * (ref_output[i] - tq_output[i]);
+    sum_sq_ref += ref_output[i] * ref_output[i];
+  }
+  float relative_rmse = (sum_sq_ref > 0) ? std::sqrt(sum_sq_err / sum_sq_ref) : std::sqrt(sum_sq_err);
+
+  EXPECT_LT(relative_rmse, 0.2f) << "TurboQuant 8-token multi-head relative RMSE too large: " << relative_rmse;
+  EXPECT_LT(max_abs_err, 0.3f) << "TurboQuant 8-token multi-head max absolute error too large: " << max_abs_err;
+}
+
+#endif  // USE_WEBGPU
+
 } // namespace test
 } // namespace onnxruntime
From 22f1b1ad47d203b7295c2ba03d9e955026829303 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 26 Jun 2026 16:37:59 -0700 Subject: [PATCH 2/3] rebase - fixup --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 2 +- ...it_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 45f1af8939a2c..a3a3ea0d7f5d2 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -135,7 +135,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << " if (global_idx == 0u) {\n" << " let global_total_seq_length = u32(total_sequence_length_input[0]);\n" << " let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size
- 1u) / uniforms.tile_size;\n" - << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" + << " populate_indirect_dispatch_buffer(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index 91a087801a252..6d88f883a5abb 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -49,7 +49,7 @@ $MAIN { if (global_idx == 0u) { let global_total_seq_length = u32(total_sequence_length_input[0]); let num_total_seq_length_tile = (global_total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size; - normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); + populate_indirect_dispatch_buffer(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif From b1b0bdf4eb7b71b15cef8f0c000cd7a90597203c Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Fri, 26 Jun 2026 17:29:37 -0700 Subject: [PATCH 3/3] fix test --- .../webgpu/bert/flash_attention.cc | 2 +- .../group_query_attention_op_test.cc | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index a3a3ea0d7f5d2..ff6938c0e9bf4 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -457,7 +457,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const bool 
turbo_quant_enabled = context.KvCacheQuantizationEnabled(); if (turbo_quant_enabled && (parameters.head_size_ < 8 || (parameters.head_size_ & (parameters.head_size_ - 1)) != 0)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "KV cache quantization requires head_size >= 8 and a power of 2. Got head_size=", + "KV cache quantization requires head_size >= 8 and a power of 2. Got head_size=", parameters.head_size_); } diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index c6bcd3ec447e1..ef7a04bc67d4f 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -3025,7 +3025,10 @@ static std::vector RunGQATurboQuant( std::vector> execution_providers; auto ep = WebGpuEPWithTurboQuant4(); if (!ep) { - GTEST_SKIP() << "WebGPU EP not available"; + // GTEST_SKIP() cannot be used in a value-returning helper (it expands to a + // void `return`). Callers already GTEST_SKIP() when the EP is unavailable, so + // this branch is unreachable in practice; return empty to keep the helper valid. + return {}; } execution_providers.push_back(std::move(ep)); tester.Run(expect, expected_error, {}, nullptr, &execution_providers); @@ -3276,12 +3279,12 @@ TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_Prefill_PackedRotary_K24) { } // --- Error path: multi-batch with per-batch seqlens_k is rejected with TurboQuant --- -// WebGPU flash attention does not implement right-padded per-batch prefill, so -// CanApplyFlashAttention() returns false whenever batch_size > 1 and per-batch -// seqlens_k are supplied with a non-empty KV cache. KV cache quantization requires -// flash attention, so this configuration is rejected rather than silently falling -// back to the (uncompressed-only) non-flash path. genai decode runs batch_size==1, -// so this is not a supported production path. 
+// The TurboQuant copy-to-quantized-KV-cache kernel reads seqlen_k[0] for every +// batch on the graph-capture decode path, so it only supports batch_size == 1 and +// explicitly rejects batch_size > 1 rather than silently corrupting batches 1..N-1. +// (The non-quantized flash-attention copy path does support per-batch seqlens_k, so +// this restriction is specific to KV cache quantization.) genai decode runs +// batch_size==1, so multi-batch is not a supported production path. TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_RejectsMultiBatch) { auto ep = WebGpuEPWithTurboQuant4(); if (!ep) { @@ -3291,7 +3294,7 @@ TEST(GroupQueryAttentionTest, WebGPU_TurboQuant_RejectsMultiBatch) { /*num_heads=*/2, /*kv_num_heads=*/1, /*head_size=*/128, /*do_rotary=*/false, /*is_packed_qkv=*/false, OpTester::ExpectResult::kExpectFailure, - "KV cache quantization requires flash attention"); + "supports batch_size == 1 only"); } // ---------------------------------------------------------------------------