Skip to content

Add decode (flash-decoding) attention kernels to contributed/#129

Open
varuntej07 wants to merge 2 commits into
aws-neuron:mainfrom
varuntej07:contributed/decode-attention-gqa
Open

Add decode (flash-decoding) attention kernels to contributed/#129
varuntej07 wants to merge 2 commits into
aws-neuron:mainfrom
varuntej07:contributed/decode-attention-gqa

Conversation

@varuntej07

@varuntej07 varuntej07 commented Jun 11, 2026

Copy link
Copy Markdown

Description of changes:

Adds a new community kernel file, contributed/decode_attention.py, implementing the decode (single-query) step of autoregressive attention. This is the memory-bound complement to the compute-bound prefill kernel in contributed/pipelined_attention.py.

Two kernels are included:

  1. decode_attention_fwd - simplest correct version: one head, one KV tile, seqlen_kv <= 128. A single query position (seqlen_q = 1) attends over the full cached K/V, running QK -> scale -> softmax -> PV with the softmax scale applied.

  2. decode_attention_gqa_fwd - lifts the length limit with KV tiling and a running online softmax (state carried across tiles), and adds grouped-query attention so query heads sharing a KV head also share its K/V loads (loaded once per group).

Both use the d-on-partition-axis layout from the existing attention tutorials.

IO layouts:

  • decode_attention_fwd: q (d, 1), k (d, seqlen_kv), v (d, seqlen_kv) -> o (1, d)
  • decode_attention_gqa_fwd: q (d, n_q_heads), k/v (n_kv_heads, d, seqlen_kv) -> o (n_q_heads, d)

Current limits: d <= 128; GQA requires n_q_heads % n_kv_heads == 0 and seqlen_kv % TILE_KV == 0 (no padding yet); single batch element. Split-KV flash-decoding for very long context is noted as planned future work.

This kernel targets contributed/, not src/reference/, so the baremetal /benchmark / integration requirements (which CONTRIBUTING.md scopes to src/reference/) do not apply. Correctness is validated on CPU via nki.simulate_kernel against NumPy reference implementations included in the file (numpy_decode_reference and numpy_decode_gqa_reference, the repeat_kv +per-head softmax oracle). Run python contributed/decode_attention.py to execute both checks; the GQA case exercises 4 KV tiles with group = 4 and asserts agreement within atol = rtol = 1e-2. The kernels are not yet run on Neuron hardware; this is stated in the module docstring WARNING.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant