Skip to content

MoE: Grouped Triton GEMM for TTFT improvements#970

Draft
mgehre-amd wants to merge 3 commits into
matthias.hybrid-w4a16-qwen3-omni-tunefrom
matthias.hybrid-w4a16-grouped-gemm
Draft

MoE: Grouped Triton GEMM for TTFT improvements#970
mgehre-amd wants to merge 3 commits into
matthias.hybrid-w4a16-qwen3-omni-tunefrom
matthias.hybrid-w4a16-grouped-gemm

Conversation

@mgehre-amd
Copy link
Copy Markdown

@mgehre-amd mgehre-amd commented May 26, 2026

Hybrid W4A16 MoE: BK cap-lift + experimental grouped-GEMM prefill path

Stacks on top of #962 (group_size≤64 dispatch
tune for Qwen3-Omni-30B-A3B-AWQ). Adds two follow-up commits that explore further
prefill speedups on the same shape.

What

1. Allow BLOCK_K > group_size (b7c9d59a2d)

fused_moe_kernel_gptq_awq's shuffle_w4a16 path used to load one scale per K-tile,
which forced invoke_fused_moe_kernel_hybrid_triton to cap BLOCK_SIZE_K at
group_size. At group_size=32 (Qwen3-Omni) that meant 64 inner-loop iterations
for K=2048 — many short matmuls instead of a few large ones.

Add a constexpr-gated multi-scale path: when BLOCK_SIZE_K > group_size, load a
per-K-row scale tensor [BLOCK_K, BLOCK_N] (mirrors the non-shuffle wna16 path
that already does this); otherwise keep the original [BLOCK_N] one-scale-per-tile
path unchanged.

Net effect at the Qwen3-Omni shape: BK=32 still wins the per-gemm sweep
(register pressure caps gains at larger BK; this kernel is already near peak at
small BK on gfx1151). The lift is preserved for future tuning at other
(K, group_size) combinations.

2. Experimental grouped-GEMM prefill path (84660b1119)

The production path uses moe_align_block_size + sorted_token_ids with the
kernel doing a virtual gather of A. This pads num_slots up to
M*top_k + E*(BM-1), launching ~50% padded blocks that early-exit but still
consume kernel-launch slots.

Add an alternative apply path that uses the existing moe_permute /
moe_unpermute ops (already used by cutlass_moe and exllama_moe) to lay
activations out in expert-contiguous order, and a new Triton kernel
(fused_moe_kernel_hybrid_w4a16_grouped) that reads them contiguously, indexed
by a per-block (expert_id, m_start, m_count) table. No padding, no
expert_ids[block]==-1 sentinel, no virtual gather.

Gated behind VLLM_HYBRID_W4A16_GROUPED=1 (off by default). Only triggers for
prefill (num_tokens > 5) and only when expert_map is None and
apply_router_weight_on_input=False. Other config combinations fall through
to the existing path.

TTFT numbers (Qwen3-Omni-30B-A3B-AWQ, Strix Halo gfx1151)

Bench: --num-prompts 10 --input-len 4096 --output-len 1 --max-num-seqs 1.
All same-session, back-to-back runs.

Variant TTFT (ms) Δ vs bad Δ vs PR #962
Bad commit ee54c41 (before PR #962) 2228 0
User-measured good baseline (fresh .so) 1788 −20%
PR #962 tip (group_size≤64 tune only) 1867 −16% 0
+ BK cap-lift (b7c9d59a2d) 1867 −16% 0
+ Grouped GEMM ON (VLLM_HYBRID_W4A16_GROUPED=1) 1766 −21% −5.5%

The grouped path beats the user-measured good baseline by ~22 ms.

Correctness

tests/kernels/moe/test_hybrid_w4a16_moe.py: 70/70 pass on both the production
and grouped paths.

Standalone A/B sanity (production vs grouped, max abs diff in output):

Shape (m, n, k, e, topk, group_size) grouped vs padded
(128, 768, 2048, 128, 8, 32) Qwen3-Omni chunk 0.0 (bit-identical)
(2048, 768, 2048, 128, 8, 32) Qwen3-Omni full 0.0
(2048, 512, 2048, 256, 8, 128) Qwen3.5-A3B 0.0
(16, 256, 256, 8, 2, 32) tiny edge, force_triton 0.0

Reproducer

# 1. Check out the branch (depends on the gfx11 tip used here for stable .so)
git fetch origin
git checkout matthias.hybrid-w4a16-grouped-gemm

# 2. Run TTFT bench 
#
#    On this install the aiter mrope JIT module crashes at startup; revert the
#    wiring temporarily before the bench, restore afterwards.
git checkout 8e220d1ecb~1 -- vllm/model_executor/models/qwen3_moe.py vllm/_aiter_ops.py

# Production path (grouped GEMM OFF)
python vllm-bench.py \
    --model cyankiwi/Qwen3-Omni-30B-A3B-Instruct-AWQ-4bit \
    --num-prompts 10 --max-model-len 8192 --ready-check-timeout-sec 1800 \
    --input-len 4096 --output-len 1 --dtype float16 --trust-remote-code \
    --target-gpu-memory-gb 28 --max-num-seqs 1 \
    -e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 \
    -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
    -e TORCH_BLAS_PREFER_HIPBLASLT=1

# Grouped GEMM ON: append one flag
python vllm-bench.py \
    --model cyankiwi/Qwen3-Omni-30B-A3B-Instruct-AWQ-4bit \
    --num-prompts 10 --max-model-len 8192 --ready-check-timeout-sec 1800 \
    --input-len 4096 --output-len 1 --dtype float16 --trust-remote-code \
    --target-gpu-memory-gb 28 --max-num-seqs 1 \
    -e TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 \
    -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \
    -e TORCH_BLAS_PREFER_HIPBLASLT=1 \
    -e VLLM_HYBRID_W4A16_GROUPED=1

# Restore aiter wiring
git checkout HEAD -- vllm/model_executor/models/qwen3_moe.py vllm/_aiter_ops.py

# 3. Correctness:
gpu-lock .venv/bin/python -m pytest tests/kernels/moe/test_hybrid_w4a16_moe.py

Open questions / follow-ups (not in this PR)

  • Should the grouped path be made default once it has more model coverage?
    Currently env-gated to keep the production path the safe choice.
  • Block-table builder runs on device with one .item() sync per gemm. Could
    be optimized into a single fused kernel if a profile shows it on the
    critical path (current TTFT improvement suggests it isn't).
  • Grouped path bypasses apply_router_weight_on_input and expert_map
    extending it is straightforward but not needed for the workloads tuned here.
  • BK > group_size is now legal in the kernel but not yet emitted by any
    _triton_config branch; the lift is preserved for future shape tuning.

The existing Triton prefill MoE tune (commits 1490142 + 8af2e37)
was derived from Qwen3.5-A3B's shape (group_size=128, E=256).  At
small group_size the kernel wrapper caps BLOCK_K to group_size, so
the narrow-BN/small-BM strategy that wins for group_size=128 becomes
severely under-occupied -- regressing Qwen3-Omni-30B-A3B-AWQ-4bit
(group_size=32, E=128) TTFT by ~20% on Strix Halo (gfx1151).

Add a second tuned branch in _triton_config / _select_block_size_m
gated by `self._group_size <= 64` (a hard cutoff between the only
two production group_sizes we see, 32 and 128).  The
group_size > 64 path is byte-identical to before, so the existing
Qwen3.5-A3B 2x speedup is preserved.

Configs derived from a kernel-only sweep at Qwen3-Omni shape (Strix
Halo gfx1151, M=2048 N=768 K=2048 E=128 top_k=8 group_size=32) via
the new benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py tool:

  gemm1 (K=2048 N=1536):  BM=128 BN=64 GM=1 nw=8 ns=1
    -> 6.79 ms vs 11.27 ms at the Qwen3.5 tune (1.66x).
  gemm2 (K=768  N=2048):  BM=64  BN=64 GM=1 nw=4 ns=1
    -> 2.58 ms vs 3.18 ms  at the Qwen3.5 tune (1.23x).

Alignment block_size_m = lcm(128, 64) = 128 (the new
TRITON_BLOCK_SIZE_M_SMALL_GS); gemm2 uses the existing
_expert_ids_for repeat_interleave path so each 64-row sub-block
sees the right expert id.  This is the first caller to actually
exercise BLOCK_M != alignment -- the infrastructure was already in
place from 8af2e37.

End-to-end vLLM serving TTFT on cyankiwi/Qwen3-Omni-30B-A3B-Instruct-
AWQ-4bit (--num-prompts 10 --input-len 4096 --output-len 1
--max-num-seqs 1):
  before this patch:  2228 ms (bad baseline, current gfx11 tip)
  after  this patch:  1867 ms (-16%)

Changes:
- bench_hybrid_w4a16_moe.py's per-call timings are dominated by
  host setup (weight quant, sort, _resize_cache), so it can't
  resolve the per-config differences this tune relies on.  The
  new sweep tool calls invoke_fused_moe_kernel_hybrid_triton
  directly to time only the kernel.
- Bumped atol 2e-2 -> 3e-2 on test_hybrid_w4a16_moe_force_triton.
  The new tune's larger BM and num_warps change the partial-sum
  reduction order, which pushes the n=k=256 stress shape's max
  abs diff to ~0.027.  Real-model TTFT precision is unaffected
  (large K averages out per-tile rounding).

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
The hybrid Triton kernel's shuffle_w4a16 path loaded one scale per
K-tile, which forced ``invoke_fused_moe_kernel_hybrid_triton`` to cap
BLOCK_SIZE_K at group_size.  At group_size=32 (Qwen3-Omni) this means
64 inner-loop iterations for K=2048 -- many short matmuls instead of
a few large ones.

Add a constexpr-gated multi-scale path to the kernel: when
BLOCK_SIZE_K > group_size, load a per-K-row scale tensor
[BLOCK_K, BLOCK_N] (mirrors what the non-shuffle wna16 path already
does); otherwise keep the original [BLOCK_N] one-scale-per-tile path
unchanged.  The wrapper's cap is replaced with an assertion that BK
either divides group_size or is a multiple of it.

Net effect at the Qwen3-Omni shape on Strix Halo: BK=32 still wins
the per-gemm sweep (register pressure caps BK gains; this kernel is
already near peak at small BK on gfx1151), so the production
_triton_config selection is unchanged.  The lift unblocks future
tuning at other (K, group_size) combinations.

Verified: tests/kernels/moe/test_hybrid_w4a16_moe.py 70/70 pass.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
The production path uses moe_align_block_size + sorted_token_ids with
the kernel doing a virtual gather of A.  This pads num_slots up to
M*top_k + E*(BM-1), launching ~50% padded blocks that early-exit but
still consume kernel-launch slots.

Add an alternative apply path that uses the existing moe_permute /
moe_unpermute ops (already supported by cutlass_moe + exllama paths)
to lay activations out in expert-contiguous order, and a new Triton
kernel (fused_moe_kernel_hybrid_w4a16_grouped) that reads them
contiguously, indexed by a per-block (expert_id, m_start, m_count)
table.  No padding, no expert_ids[block]==-1 sentinel, no virtual gather.

End-to-end TTFT on cyankiwi/Qwen3-Omni-30B-A3B-Instruct-AWQ-4bit
(--num-prompts 10 --input-len 4096 --output-len 1 --max-num-seqs 1
on Strix Halo gfx1151, back-to-back same-session A/B):

  grouped OFF (production path): 1867 ms
  grouped ON:                    1766 ms  (-5.5%)

Correctness checked vs the production path at 4 shapes (Qwen3-Omni
m=128 and m=2048 with group_size=32, Qwen3.5-A3B m=2048 with
group_size=128, and a tiny edge case m=16 n=k=256 group_size=32 with
force_triton): grouped vs padded max abs diff = 0.0 in every case
(bit-identical).

Changes:
- Gated behind VLLM_HYBRID_W4A16_GROUPED=1 (off by default) for the
  initial landing; only triggers for prefill (num_tokens > 5) and
  only when expert_map is None and apply_router_weight_on_input is
  False.  Other config combinations fall through to the existing path.
- The new grouped kernel reuses the constexpr-gated multi-scale-per-
  tile path from fused_moe_kernel_gptq_awq (so BLOCK_K > group_size
  is supported for free if it ever becomes useful).
- Linear pid mapping rather than the GROUP_SIZE_M-swizzled mapping
  from the original kernel: consecutive blocks in the grouped layout
  belong to different experts so the B-reuse heuristic does not
  apply.  GROUP_SIZE_M is accepted in the config but ignored.

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
@mgehre-amd mgehre-amd force-pushed the matthias.hybrid-w4a16-qwen3-omni-tune branch from 0e3d214 to 424135e Compare May 26, 2026 21:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant