MoE: Grouped Triton GEMM for TTFT improvements#970
Draft
mgehre-amd wants to merge 3 commits into
Draft
Conversation
The existing Triton prefill MoE tune (commits 1490142 + 8af2e37) was derived from Qwen3.5-A3B's shape (group_size=128, E=256). At small group_size the kernel wrapper caps BLOCK_K to group_size, so the narrow-BN/small-BM strategy that wins for group_size=128 becomes severely under-occupied -- regressing Qwen3-Omni-30B-A3B-AWQ-4bit (group_size=32, E=128) TTFT by ~20% on Strix Halo (gfx1151). Add a second tuned branch in _triton_config / _select_block_size_m gated by `self._group_size <= 64` (a hard cutoff between the only two production group_sizes we see, 32 and 128). The group_size > 64 path is byte-identical to before, so the existing Qwen3.5-A3B 2x speedup is preserved. Configs derived from a kernel-only sweep at Qwen3-Omni shape (Strix Halo gfx1151, M=2048 N=768 K=2048 E=128 top_k=8 group_size=32) via the new benchmarks/kernels/sweep_hybrid_w4a16_moe_triton.py tool: gemm1 (K=2048 N=1536): BM=128 BN=64 GM=1 nw=8 ns=1 -> 6.79 ms vs 11.27 ms at the Qwen3.5 tune (1.66x). gemm2 (K=768 N=2048): BM=64 BN=64 GM=1 nw=4 ns=1 -> 2.58 ms vs 3.18 ms at the Qwen3.5 tune (1.23x). Alignment block_size_m = lcm(128, 64) = 128 (the new TRITON_BLOCK_SIZE_M_SMALL_GS); gemm2 uses the existing _expert_ids_for repeat_interleave path so each 64-row sub-block sees the right expert id. This is the first caller to actually exercise BLOCK_M != alignment -- the infrastructure was already in place from 8af2e37. End-to-end vLLM serving TTFT on cyankiwi/Qwen3-Omni-30B-A3B-Instruct- AWQ-4bit (--num-prompts 10 --input-len 4096 --output-len 1 --max-num-seqs 1): before this patch: 2228 ms (bad baseline, current gfx11 tip) after this patch: 1867 ms (-16%) Changes: - bench_hybrid_w4a16_moe.py's per-call timings are dominated by host setup (weight quant, sort, _resize_cache), so it can't resolve the per-config differences this tune relies on. The new sweep tool calls invoke_fused_moe_kernel_hybrid_triton directly to time only the kernel. - Bumped atol 2e-2 -> 3e-2 on test_hybrid_w4a16_moe_force_triton. The new tune's larger BM and num_warps change the partial-sum reduction order, which pushes the n=k=256 stress shape's max abs diff to ~0.027. Real-model TTFT precision is unaffected (large K averages out per-tile rounding). Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
The hybrid Triton kernel's shuffle_w4a16 path loaded one scale per K-tile, which forced ``invoke_fused_moe_kernel_hybrid_triton`` to cap BLOCK_SIZE_K at group_size. At group_size=32 (Qwen3-Omni) this means 64 inner-loop iterations for K=2048 -- many short matmuls instead of a few large ones. Add a constexpr-gated multi-scale path to the kernel: when BLOCK_SIZE_K > group_size, load a per-K-row scale tensor [BLOCK_K, BLOCK_N] (mirrors what the non-shuffle wna16 path already does); otherwise keep the original [BLOCK_N] one-scale-per-tile path unchanged. The wrapper's cap is replaced with an assertion that BK either divides group_size or is a multiple of it. Net effect at the Qwen3-Omni shape on Strix Halo: BK=32 still wins the per-gemm sweep (register pressure caps BK gains; this kernel is already near peak at small BK on gfx1151), so the production _triton_config selection is unchanged. The lift unblocks future tuning at other (K, group_size) combinations. Verified: tests/kernels/moe/test_hybrid_w4a16_moe.py 70/70 pass. Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
The production path uses moe_align_block_size + sorted_token_ids with the kernel doing a virtual gather of A. This pads num_slots up to M*top_k + E*(BM-1), launching ~50% padded blocks that early-exit but still consume kernel-launch slots. Add an alternative apply path that uses the existing moe_permute / moe_unpermute ops (already supported by cutlass_moe + exllama paths) to lay activations out in expert-contiguous order, and a new Triton kernel (fused_moe_kernel_hybrid_w4a16_grouped) that reads them contiguously, indexed by a per-block (expert_id, m_start, m_count) table. No padding, no expert_ids[block]==-1 sentinel, no virtual gather. End-to-end TTFT on cyankiwi/Qwen3-Omni-30B-A3B-Instruct-AWQ-4bit (--num-prompts 10 --input-len 4096 --output-len 1 --max-num-seqs 1 on Strix Halo gfx1151, back-to-back same-session A/B): grouped OFF (production path): 1867 ms grouped ON: 1766 ms (-5.5%) Correctness checked vs the production path at 4 shapes (Qwen3-Omni m=128 and m=2048 with group_size=32, Qwen3.5-A3B m=2048 with group_size=128, and a tiny edge case m=16 n=k=256 group_size=32 with force_triton): grouped vs padded max abs diff = 0.0 in every case (bit-identical). Changes: - Gated behind VLLM_HYBRID_W4A16_GROUPED=1 (off by default) for the initial landing; only triggers for prefill (num_tokens > 5) and only when expert_map is None and apply_router_weight_on_input is False. Other config combinations fall through to the existing path. - The new grouped kernel reuses the constexpr-gated multi-scale-per- tile path from fused_moe_kernel_gptq_awq (so BLOCK_K > group_size is supported for free if it ever becomes useful). - Linear pid mapping rather than the GROUP_SIZE_M-swizzled mapping from the original kernel: consecutive blocks in the grouped layout belong to different experts so the B-reuse heuristic does not apply. GROUP_SIZE_M is accepted in the config but ignored. Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
0e3d214 to
424135e
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Hybrid W4A16 MoE: BK cap-lift + experimental grouped-GEMM prefill path
Stacks on top of #962 (group_size≤64 dispatch
tune for Qwen3-Omni-30B-A3B-AWQ). Adds two follow-up commits that explore further
prefill speedups on the same shape.
What
1. Allow
BLOCK_K > group_size(b7c9d59a2d)fused_moe_kernel_gptq_awq's shuffle_w4a16 path used to load one scale per K-tile,which forced
invoke_fused_moe_kernel_hybrid_tritonto capBLOCK_SIZE_Katgroup_size. Atgroup_size=32(Qwen3-Omni) that meant 64 inner-loop iterationsfor
K=2048— many short matmuls instead of a few large ones.Add a constexpr-gated multi-scale path: when
BLOCK_SIZE_K > group_size, load aper-K-row scale tensor
[BLOCK_K, BLOCK_N](mirrors the non-shuffle wna16 paththat already does this); otherwise keep the original
[BLOCK_N]one-scale-per-tilepath unchanged.
Net effect at the Qwen3-Omni shape: BK=32 still wins the per-gemm sweep
(register pressure caps gains at larger BK; this kernel is already near peak at
small BK on gfx1151). The lift is preserved for future tuning at other
(K, group_size)combinations.2. Experimental grouped-GEMM prefill path (
84660b1119)The production path uses
moe_align_block_size+sorted_token_idswith thekernel doing a virtual gather of A. This pads
num_slotsup toM*top_k + E*(BM-1), launching ~50% padded blocks that early-exit but stillconsume kernel-launch slots.
Add an alternative apply path that uses the existing
moe_permute/moe_unpermuteops (already used bycutlass_moeandexllama_moe) to layactivations out in expert-contiguous order, and a new Triton kernel
(
fused_moe_kernel_hybrid_w4a16_grouped) that reads them contiguously, indexedby a per-block
(expert_id, m_start, m_count)table. No padding, noexpert_ids[block]==-1sentinel, no virtual gather.Gated behind
VLLM_HYBRID_W4A16_GROUPED=1(off by default). Only triggers forprefill (
num_tokens > 5) and only whenexpert_map is Noneandapply_router_weight_on_input=False. Other config combinations fall throughto the existing path.
TTFT numbers (Qwen3-Omni-30B-A3B-AWQ, Strix Halo gfx1151)
Bench:
--num-prompts 10 --input-len 4096 --output-len 1 --max-num-seqs 1.All same-session, back-to-back runs.
.so)b7c9d59a2d)VLLM_HYBRID_W4A16_GROUPED=1)The grouped path beats the user-measured good baseline by ~22 ms.
Correctness
tests/kernels/moe/test_hybrid_w4a16_moe.py: 70/70 pass on both the productionand grouped paths.
Standalone A/B sanity (production vs grouped, max abs diff in output):
Reproducer
Open questions / follow-ups (not in this PR)
Currently env-gated to keep the production path the safe choice.
.item()sync per gemm. Couldbe optimized into a single fused kernel if a profile shows it on the
critical path (current TTFT improvement suggests it isn't).
apply_router_weight_on_inputandexpert_map—extending it is straightforward but not needed for the workloads tuned here.
BK > group_sizeis now legal in the kernel but not yet emitted by any_triton_configbranch; the lift is preserved for future shape tuning.