Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 26 additions & 32 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const

template <typename T>
Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, int kv_cache_bit_width,
int& past_sequence_length) {
int& past_sequence_length, int kv_cache_extra_bits = 0) {
const auto& past_key_dims = past_key->Shape().GetDims();
const auto& past_value_dims = past_value->Shape().GetDims();

Expand Down Expand Up @@ -140,17 +140,25 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_
// We assume all sequence in past kv are right-padded to max or past sequence length
past_sequence_length = static_cast<int>(past_key_dims[2]);

// For 4-bit quantized KV cache, actual dimension is head_size / 2 because 2 nibbles are packed into one byte.
// Note that we have checked that head_size is a multiple of 8 in Check_QKV.
int packed_head_size = (kv_cache_bit_width == 4) ? (head_size / 2) : head_size;
// Compute expected KV cache head dimension from quantization parameters.
// kv_cache_bit_width: bits per element (4 or 8). 0 means no quantization.
// kv_cache_extra_bits: additional metadata bits per head
// (e.g., 32 bits for TurboQuant to store a scale).
int packed_head_size;
if (kv_cache_bit_width == 0) {
packed_head_size = head_size;
} else {
int bits_per_element = static_cast<int>(past_key->DataType()->Size()) * 8;
packed_head_size = (head_size * kv_cache_bit_width + kv_cache_extra_bits) / bits_per_element;
}
if (past_key_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
"Input 'past_key' dimension 3 should match the packed KV head dimension, got ",
past_key_dims[3], " expected ", packed_head_size);
}
if (past_value_dims[3] != packed_head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 3 should be same as head_size, got ",
"Input 'past_value' dimension 3 should match the packed KV head dimension, got ",
past_value_dims[3], " expected ", packed_head_size);
}
return Status::OK();
Expand Down Expand Up @@ -206,7 +214,12 @@ Status CheckInputs(const T* query,
const T* total_seqlen,
float scale,
float softcap,
int kv_cache_bit_width) {
int kv_cache_bit_width,
int max_threads_per_block = 0,
int kv_cache_extra_bits = 0) {
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
// Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache
// past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
// past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr
Expand Down Expand Up @@ -246,10 +259,15 @@ Status CheckInputs(const T* query,
kv_sequence_length = sequence_length;
}

if (kv_cache_extra_bits != 0 && kv_cache_bit_width == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"kv_cache_extra_bits requires kv_cache_bit_width to be non-zero.");
}

// Check past-present KV
int32_t past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, kv_cache_bit_width, past_sequence_length));
ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, kv_cache_bit_width, past_sequence_length, kv_cache_extra_bits));
// When past KV exists, Q and K/V must have the same sequence length,
// UNLESS kv_sequence_length is 0 (shared KV: new K/V are empty, past buffer
// already contains the full shared KV cache — no append needed).
Expand Down Expand Up @@ -377,30 +395,6 @@ Status CheckInputs(const T* query,
return Status::OK();
}

// Overload of CheckInputs that additionally takes max_threads_per_block.
//
// When max_threads_per_block is positive and num_heads exceeds it, an
// INVALID_ARGUMENT status is returned. In every other case the remaining
// arguments are forwarded unchanged to the CheckInputs overload that does not
// take max_threads_per_block, and that overload's status is returned.
template <typename T = Tensor>
Status CheckInputs(const T* query,
                   const T* key,
                   const T* value,
                   const T* past_key,
                   const T* past_value,
                   const T* cos_cache,
                   const T* sin_cache,
                   void* parameters,
                   int num_heads,
                   int kv_num_heads,
                   const T* seqlens_k,
                   const T* total_seqlen,
                   float scale,
                   float softcap,
                   int kv_cache_bit_width,
                   int max_threads_per_block) {
  // A non-positive max_threads_per_block disables the head-count limit.
  const bool head_count_ok = (max_threads_per_block <= 0) || (num_heads <= max_threads_per_block);
  if (head_count_ok) {
    return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width);
  }

  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}

template <typename T = Tensor>
Status CheckCustomAttentionInputs(const T* position_ids,
const T* attention_bias,
Expand Down
Loading
Loading