Skip to content

feat(gfx1151): custom head-dim-tiled Triton flash attention for Qwen3.x ViT#1357

Open
carlushuang wants to merge 4 commits into
mainfrom
carhuang/gfx1151_vit_triton_attn
Open

feat(gfx1151): custom head-dim-tiled Triton flash attention for Qwen3.x ViT#1357
carlushuang wants to merge 4 commits into
mainfrom
carhuang/gfx1151_vit_triton_attn

Conversation

@carlushuang

@carlushuang carlushuang commented Jun 25, 2026

Copy link
Copy Markdown
Collaborator

feat(gfx1151): Triton flash attention for the Qwen3.x vision encoder

On gfx1151 (RDNA3.5 / Radeon 8060S) torch scaled_dot_product_attention falls back to the unfused math backend for the ViT — flash/mem-efficient SDPA are disabled on this arch (AOTRITON "experimental" path). That's slow and O(N²) in memory, so large images are slow and can OOM.

This routes the vision self-attention to aiter's Triton prefill kernel (context_attention_fwd), which:

  • runs correctly on gfx1151 at the ViT head_dim=72 (non-power-of-2),
  • is a proper flash attention (O(N) memory),
  • is selected only on non-gfx9 arches (not aiter_hip_kernels_supported()); gfx9/CDNA keep SDPA (fast flash there), and torch SDPA stays as the fallback when aiter is unavailable.

Per-image cu_seqlens are built from grid_thw and threaded VisionTransformer → VisionBlock → VisionAttention, so each image attends only within its own patches — this also fixes the previous full cross-image attention for multi-image inputs.

Correctness

Output unchanged — image color-identification test still correct (Red/Blue/Green). Numerically matches SDPA (max relerr < 0.008 across sizes).

Performance (single Radeon 8060S, gfx1151)

Attention kernel, 16 heads / head_dim 72 / bf16 / non-causal:

Final implementation (commit 3): a purpose-built head-dim-tiled kernel. head_dim 72 isn't a multiple of 16 (the WMMA k-tile), and every stock Triton attention pads it to the next power of two = 128 (tl.arange must be pow2), wasting 1.78x of the QK/AV contraction. The custom kernel instead hand-unrolls head_dim into 5×16=80 tiles (each tile a pow2 tensor), contracting 80-deep — no 128 padding, no external F.pad. Per-image varlen via cu_seqlens; non-causal; online softmax. Gated to head_dim ≤ 80, SDPA fallback otherwise.

Attention kernel on a Radeon 8060S (head_dim 72, bf16, non-causal):

tokens ~image tiled-80 (final) torch SDPA vs SDPA real TFLOPS
1024 512px 0.23 ms 4.33 ms ~19x 20.6
2048 724px 0.82 ms 16.6 ms ~20x 23.5
4096 1024px 3.15 ms 65.8 ms ~21x 24.5

Real throughput went 7 → 14 → 24.5 TFLOPS (raw d72 → pad-80 → tiled-80) ≈ 45% MFU, on par with the GEMMs. The 3 commits show the progression (Triton switch → WMMA-align pad → tiled kernel).

Full 27-layer ViT forward: 1.8x @512px, 2.6x @768px (grows with image size; O(N) memory also avoids the large-image OOM).

Scope

Vision-attention only. ViT LayerNorm/GEMM/Conv3d stay torch (aiter's Triton LayerNorm runs but is slower; gemm_a16w16 ≈ torch hipBLASLt; aiter CK/asm FMHA is CDNA-only and fails on gfx1151 — Triton is the only working flash path here).

On gfx1151 (RDNA3.5) torch's scaled_dot_product_attention falls back to the
unfused math backend for the vision encoder (flash/mem-efficient SDPA are
disabled on this arch), which is slow and O(N^2) memory -- large images OOM.

Route the ViT self-attention to aiter's Triton prefill kernel
(context_attention_fwd) instead, which runs on gfx1151 at the ViT head_dim
(72) and is a proper flash attention (O(N) memory). Per-image cu_seqlens are
built from grid_thw and threaded through VisionTransformer -> VisionBlock ->
VisionAttention, so each image attends only within its own patches (this also
fixes the previous full cross-image attention for multi-image inputs).

Gated on `not aiter_hip_kernels_supported()` so only non-gfx9 arches use the
Triton path; gfx9/CDNA keep SDPA (which has a fast flash backend there).
torch SDPA remains the fallback when aiter is unavailable.

Output is unchanged (validated: image color identification still correct).

Attention kernel speedup vs torch SDPA on a single Radeon 8060S (16 heads,
head_dim 72, bf16, non-causal):

  tokens  ~image   triton    sdpa   speedup
    1024   512px   0.89ms   4.33ms    4.9x
    2048   724px   2.99ms  16.58ms    5.6x
    4096  1024px  11.0ms   65.8ms     6.0x
    8192  1448px  45.3ms   287ms      6.4x

Full 27-layer ViT forward: 1.8x at 512px, 2.6x at 768px (grows with size).
The Triton prefill kernel is pathologically slow when head_dim is not a
multiple of 16 (the WMMA load granularity): at the Qwen3 ViT head_dim 72 it
runs ~2x slower than head_dim 80 or 128 due to unaligned masked head-dim
loads (measured on a Radeon 8060S: d72=11.0ms vs d80=5.4ms vs d128=6.7ms at
N=4096). It still pads internally to next_power_of_2=128 either way, so the
penalty is purely the unaligned actual head_dim.

Zero-pad q/k/v to the next multiple of 16 before the kernel and slice the
output back. The padded dims contribute nothing to QK^T / AV, so the result
is unchanged (max relerr unchanged, ~0.008 vs SDPA). ~2x across sizes:

  tokens   d72      pad80   speedup
   1024   0.91ms   0.45ms   2.03x
   2048   2.99ms   1.50ms   2.00x
   4096  11.0ms    5.34ms   2.06x

Combined with the Triton switch this is ~10-12x over the torch SDPA math
fallback. No-op when head_dim is already a multiple of 16.
Replace the lightllm prefill kernel (+ external head_dim pad-to-80) with a
purpose-built Triton flash attention for the Qwen3.x vision encoder. It tiles
head_dim into 5x16=80 chunks instead of padding to the next power of two (128),
so the QK^T / AV contraction is 80-deep rather than 128 -- 1.6x fewer WMMA
k-steps, and no external F.pad per layer.

Triton constraints handled: tl.arange must be pow2 and list/generator
comprehensions are unsupported, so the 5 head tiles are hand-unrolled, keeping
every tile a power-of-two tensor ([BLOCK_M,16] / [BLOCK_M,BLOCK_N]). Per-image
varlen (cu_seqlens from grid_thw) so each image attends only within its patches.
Online softmax; non-causal. Gated to head_dim <= 80; falls back to SDPA
otherwise and on gfx9/CDNA.

Measured on a Radeon 8060S (head_dim 72, bf16, non-causal), vs the previous
pad-80 lightllm kernel:

  tokens   tiled-80   pad-80    speedup   real TFLOPS
   1024     0.23ms    0.39ms     1.67x       20.6
   2048     0.82ms    1.33ms     1.62x       23.5
   4096     3.15ms    5.03ms     1.59x       24.5

That is ~21x over the torch SDPA math fallback (65.8ms -> 3.15ms at 4096) and
lifts attention to ~45% MFU, on par with the GEMMs. Output unchanged (max
relerr < 0.007 vs SDPA, single- and multi-image). End-to-end image
identification still correct with prefix caching on.
@carlushuang carlushuang changed the title feat(gfx1151): Triton flash attention for Qwen3.x vision encoder feat(gfx1151): custom head-dim-tiled Triton flash attention for Qwen3.x ViT Jun 26, 2026
@zufayu zufayu requested review from ZhangLirong-amd and removed request for ZhangLirong-amd June 26, 2026 06:12
@zufayu zufayu requested a review from yhl-amd June 26, 2026 14:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant