feat(gfx1151): custom head-dim-tiled Triton flash attention for Qwen3.x ViT #1357
Open
carlushuang wants to merge 4 commits into
On gfx1151 (RDNA3.5) torch's scaled_dot_product_attention falls back to the
unfused math backend for the vision encoder (flash/mem-efficient SDPA are
disabled on this arch), which is slow and O(N^2) memory -- large images OOM.
Route the ViT self-attention to aiter's Triton prefill kernel
(context_attention_fwd) instead, which runs on gfx1151 at the ViT head_dim
(72) and is a proper flash attention (O(N) memory). Per-image cu_seqlens are
built from grid_thw and threaded through VisionTransformer -> VisionBlock ->
VisionAttention, so each image attends only within its own patches (this also
fixes the previous full cross-image attention for multi-image inputs).
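
A minimal sketch of how per-image `cu_seqlens` could be derived from `grid_thw`, assuming each `(t, h, w)` row contributes one segment of `t*h*w` patch tokens. The helper names `build_cu_seqlens` and `segment_of` are hypothetical and not taken from this PR; the actual code may split segments differently (for example per frame).

```python
from itertools import accumulate


def build_cu_seqlens(grid_thw):
    """Cumulative token offsets, one segment per (t, h, w) row (assumption)."""
    lengths = [t * h * w for t, h, w in grid_thw]
    return [0] + list(accumulate(lengths))


def segment_of(token_idx, cu_seqlens):
    """Index of the segment (image) that owns a given token position."""
    for i in range(len(cu_seqlens) - 1):
        if cu_seqlens[i] <= token_idx < cu_seqlens[i + 1]:
            return i
    raise IndexError(token_idx)


# Two images: a 32x32 patch grid and a 16x24 patch grid.
grid_thw = [(1, 32, 32), (1, 16, 24)]
cu_seqlens = build_cu_seqlens(grid_thw)

# Tokens 0..1023 belong to image 0, tokens 1024..1407 to image 1, so a
# varlen attention kernel never mixes keys across that boundary.
first_of_second_image = segment_of(1024, cu_seqlens)
```

Consecutive pairs `(cu_seqlens[i], cu_seqlens[i + 1])` give the token range each image attends within.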
Gated on `not aiter_hip_kernels_supported()` so only non-gfx9 arches use the
Triton path; gfx9/CDNA keep SDPA (which has a fast flash backend there).
torch SDPA remains the fallback when aiter is unavailable.
Output is unchanged (validated: image color identification still correct).
Attention kernel speedup vs torch SDPA on a single Radeon 8060S (16 heads,
head_dim 72, bf16, non-causal):
tokens  ~image   triton    sdpa      speedup
1024    512px    0.89ms    4.33ms    4.9x
2048    724px    2.99ms    16.58ms   5.6x
4096    1024px   11.0ms    65.8ms    6.0x
8192    1448px   45.3ms    287ms     6.4x
Full 27-layer ViT forward: 1.8x at 512px, 2.6x at 768px (grows with size).

The Triton prefill kernel is pathologically slow when head_dim is not a multiple of 16 (the WMMA load granularity): at the Qwen3 ViT head_dim 72 it runs ~2x slower than at head_dim 80 or 128 due to unaligned masked head-dim loads (measured on a Radeon 8060S at N=4096: d72=11.0ms vs d80=5.4ms vs d128=6.7ms). It still pads internally to next_power_of_2=128 either way, so the penalty comes purely from the unaligned actual head_dim.

Zero-pad q/k/v to the next multiple of 16 before the kernel and slice the output back. The padded dims contribute nothing to QK^T / AV, so the result is unchanged (max relerr unchanged, ~0.008 vs SDPA). ~2x across sizes:

tokens  d72      pad80    speedup
1024    0.91ms   0.45ms   2.03x
2048    2.99ms   1.50ms   2.00x
4096    11.0ms   5.34ms   2.06x

Combined with the Triton switch this is ~10-12x over the torch SDPA math fallback. No-op when head_dim is already a multiple of 16.
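
The claim that zero-padded head dims contribute nothing can be checked with a small pure-Python sketch. It is not the PR's code: `pad_to_multiple` and `attention` are hypothetical helpers, and the sketch assumes the softmax scale is computed from the true head_dim (72) rather than the padded one (80), which is what keeps the result unchanged.

```python
import math
import random


def pad_to_multiple(rows, multiple=16):
    """Zero-pad every row up to the next multiple of `multiple`."""
    d = len(rows[0])
    padded = -(-d // multiple) * multiple  # ceiling division
    return [row + [0.0] * (padded - d) for row in rows]


def attention(q, k, v, scale):
    """Naive non-causal softmax attention on lists of rows."""
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * v_row[c] for wj, v_row in zip(w, v)) / z
                    for c in range(len(v[0]))])
    return out


random.seed(0)
n, d = 8, 72
q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
k = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
v = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]

scale = 1.0 / math.sqrt(d)  # true head_dim, not the padded width

ref = attention(q, k, v, scale)
padded_out = attention(pad_to_multiple(q), pad_to_multiple(k),
                       pad_to_multiple(v), scale)
out = [row[:d] for row in padded_out]  # slice the output back to head_dim

max_err = max(abs(a - b) for r, o in zip(ref, out) for a, b in zip(r, o))
```

The padded q/k lanes add only zero products to each score, and the padded v columns produce output columns that are sliced away.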

Replace the lightllm prefill kernel (+ external head_dim pad-to-80) with a purpose-built Triton flash attention for the Qwen3.x vision encoder. It tiles head_dim into 5 chunks of 16 (5x16=80) instead of padding to the next power of two (128), so the QK^T / AV contraction is 80-deep rather than 128 -- 1.6x fewer WMMA k-steps, and no external F.pad per layer.

Triton constraints handled: tl.arange must be pow2 and list/generator comprehensions are unsupported, so the 5 head tiles are hand-unrolled, keeping every tile a power-of-two tensor ([BLOCK_M,16] / [BLOCK_M,BLOCK_N]). Per-image varlen (cu_seqlens from grid_thw) so each image attends only within its patches. Online softmax; non-causal. Gated to head_dim <= 80; falls back to SDPA otherwise and on gfx9/CDNA.

Measured on a Radeon 8060S (head_dim 72, bf16, non-causal), vs the previous pad-80 lightllm kernel:

tokens  tiled-80  pad-80   speedup  real TFLOPS
1024    0.23ms    0.39ms   1.67x    20.6
2048    0.82ms    1.33ms   1.62x    23.5
4096    3.15ms    5.03ms   1.59x    24.5

That is ~21x over the torch SDPA math fallback (65.8ms -> 3.15ms at 4096) and lifts attention to ~45% MFU, on par with the GEMMs. Output unchanged (max relerr < 0.007 vs SDPA, single- and multi-image). End-to-end image identification still correct with prefix caching on.
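
The structure described above (five 16-wide head tiles, online softmax over key blocks, per-image varlen) can be illustrated with a CPU reference in plain Python. This is only a sketch of the technique, not the Triton kernel: `pad_row` stands in for a masked tile load, `block_n=4` is an arbitrary block size, and all helper names are hypothetical.

```python
import math
import random

TILE = 16       # assumed WMMA k-tile width
N_TILES = 5     # 5 * 16 = 80 covers head_dim 72
HEAD_DIM = 72


def pad_row(row):
    """Stand-in for a masked load: lanes past the true head_dim read as zero."""
    return row + [0.0] * (TILE * N_TILES - len(row))


def tiled_dot(a, b):
    """80-deep dot product accumulated one 16-wide tile at a time."""
    total = 0.0
    for t in range(N_TILES):
        lo = t * TILE
        total += sum(a[lo + i] * b[lo + i] for i in range(TILE))
    return total


def flash_segment(q, k, v, scale, block_n=4):
    """Non-causal attention over one segment using online softmax."""
    out = []
    for q_row in q:
        m, l, acc = -math.inf, 0.0, [0.0] * len(v[0])
        for start in range(0, len(k), block_n):
            k_blk, v_blk = k[start:start + block_n], v[start:start + block_n]
            s = [scale * tiled_dot(q_row, k_row) for k_row in k_blk]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)  # rescales earlier blocks; 0.0 on the first
            p = [math.exp(x - m_new) for x in s]
            l = l * alpha + sum(p)
            acc = [a * alpha + sum(pj * v_row[c] for pj, v_row in zip(p, v_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc[:HEAD_DIM]])
    return out


def varlen_attention(q, k, v, cu_seqlens, scale):
    """Each segment [lo, hi) attends only to its own keys."""
    out = []
    for lo, hi in zip(cu_seqlens, cu_seqlens[1:]):
        out.extend(flash_segment(q[lo:hi], k[lo:hi], v[lo:hi], scale))
    return out


def reference(q, k, v, cu_seqlens, scale):
    """Plain two-pass softmax attention per segment, for comparison."""
    out = []
    for lo, hi in zip(cu_seqlens, cu_seqlens[1:]):
        for q_row in q[lo:hi]:
            s = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k[lo:hi]]
            m = max(s)
            w = [math.exp(x - m) for x in s]
            z = sum(w)
            out.append([sum(wj * v_row[c] for wj, v_row in zip(w, v[lo:hi])) / z
                        for c in range(HEAD_DIM)])
    return out


random.seed(0)
cu_seqlens = [0, 6, 11]  # two "images" with 6 and 5 tokens
n = cu_seqlens[-1]
q = [[random.gauss(0, 1) for _ in range(HEAD_DIM)] for _ in range(n)]
k = [[random.gauss(0, 1) for _ in range(HEAD_DIM)] for _ in range(n)]
v = [[random.gauss(0, 1) for _ in range(HEAD_DIM)] for _ in range(n)]
scale = 1.0 / math.sqrt(HEAD_DIM)

out = varlen_attention([pad_row(r) for r in q], [pad_row(r) for r in k],
                       [pad_row(r) for r in v], cu_seqlens, scale)
ref = reference(q, k, v, cu_seqlens, scale)
max_err = max(abs(a - b) for r, o in zip(ref, out) for a, b in zip(r, o))
```

The fixed `for t in range(N_TILES)` loop corresponds to the five tiles that the kernel writes out by hand, since Triton cannot express them as a comprehension.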
feat(gfx1151): Triton flash attention for the Qwen3.x vision encoder
On gfx1151 (RDNA3.5 / Radeon 8060S) torch `scaled_dot_product_attention` falls back to the unfused math backend for the ViT: flash/mem-efficient SDPA are disabled on this arch (AOTRITON "experimental" path). That's slow and O(N²) in memory, so large images are slow and can OOM.

This routes the vision self-attention to aiter's Triton prefill kernel (`context_attention_fwd`), which:
- runs on gfx1151 at `head_dim=72` (non-power-of-2),
- is gated to non-gfx9 arches (`not aiter_hip_kernels_supported()`); gfx9/CDNA keep SDPA (fast flash there), and torch SDPA stays as the fallback when aiter is unavailable.

Per-image `cu_seqlens` are built from `grid_thw` and threaded through VisionTransformer → VisionBlock → VisionAttention, so each image attends only within its own patches; this also fixes the previous full cross-image attention for multi-image inputs.

Correctness
Output unchanged — image color-identification test still correct (Red/Blue/Green). Numerically matches SDPA (max relerr < 0.008 across sizes).
Performance (single Radeon 8060S, gfx1151)
Attention kernel, 16 heads / head_dim 72 / bf16 / non-causal:
Final implementation (commit 3): a purpose-built head-dim-tiled kernel. head_dim 72 isn't a multiple of 16 (the WMMA k-tile), and every stock Triton attention pads it to the next power of two = 128 (`tl.arange` must be pow2), so the QK/AV contraction does 1.78x the necessary work (128 vs 72). The custom kernel instead hand-unrolls head_dim into five 16-wide tiles (5×16=80, each tile a pow2 tensor), contracting 80-deep: no 128 padding, no external `F.pad`. Per-image varlen via `cu_seqlens`; non-causal; online softmax. Gated to `head_dim ≤ 80`, SDPA fallback otherwise.

Attention kernel on a Radeon 8060S (head_dim 72, bf16, non-causal):
Real throughput went 7 → 14 → 24.5 TFLOPS (raw d72 → pad-80 → tiled-80) ≈ 45% MFU, on par with the GEMMs. The 3 commits show the progression (Triton switch → WMMA-align pad → tiled kernel).
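
As a sanity check on the throughput figures, the sketch below uses a common FLOP-counting convention for non-causal attention: two matmuls (QK^T and AV) at 2 FLOPs per multiply-accumulate, softmax ignored, true head_dim 72, one sequence of N tokens. That convention is an assumption, not something the PR states, but it reproduces the quoted 7, 14 and 24.5 TFLOPS at N=4096 from the kernel times in the commit messages. `attention_tflops` is a hypothetical helper.

```python
def attention_tflops(tokens, heads, head_dim, millis):
    """Achieved TFLOPS for non-causal attention over one sequence.

    Assumed convention: QK^T and AV each cost 2 * N * N * d FLOPs per head.
    """
    flops = 4 * heads * tokens * tokens * head_dim
    return flops / (millis * 1e-3) / 1e12


# Kernel times at 4096 tokens, 16 heads, head_dim 72, from the commit messages.
raw_d72 = attention_tflops(4096, 16, 72, 11.0)    # ~7.0
pad_80 = attention_tflops(4096, 16, 72, 5.34)     # ~14.5
tiled_80 = attention_tflops(4096, 16, 72, 3.15)   # ~24.5
```

Counting with the true head_dim rather than the padded width means padding work does not inflate the reported number.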
Full 27-layer ViT forward: 1.8x @512px, 2.6x @768px (grows with image size; O(N) memory also avoids the large-image OOM).
Scope
Vision-attention only. ViT LayerNorm/GEMM/Conv3d stay torch (aiter's Triton LayerNorm runs but is slower; `gemm_a16w16` ≈ torch hipBLASLt; aiter CK/asm FMHA is CDNA-only and fails on gfx1151, so Triton is the only working flash path here).