Skip to content

Marcusr/aiesw 32176 w4a16 ck wmma#930

Draft
marcusr-amd wants to merge 5 commits into
gfx11from
marcusr/aiesw-32176-w4a16-ck-wmma
Draft

Marcusr/aiesw 32176 w4a16 ck wmma#930
marcusr-amd wants to merge 5 commits into
gfx11from
marcusr/aiesw-32176-w4a16-ck-wmma

Conversation

@marcusr-amd
Copy link
Copy Markdown

@marcusr-amd marcusr-amd commented May 8, 2026

Purpose

Closes AIESW-32176.

Adds a CK WMMA W4A16 b_scale GEMM kernel and dispatches it from
HybridW4A16LinearKernel for the Qwen3-4B gate_up_proj prefill shape on
Strix Halo (gfx1151). This is the largest GEMM by FLOPs in Qwen3-4B prefill
(M=3968, N=19456, K=2560, group_size=128) — roughly 2x the next largest — so
optimizing it gives the biggest TTFT win for AWQ / W4A16 Qwen3-4B.

Two variants land in this PR:

  • Symmetric (torch.ops._rocm_C.ck_w4a16_b_scale_gemm) for
    compressed_tensors uint4b8 checkpoints (e.g.
    RedHatAI/Qwen3-4B-quantized.w4a16).
  • Asymmetric (torch.ops._rocm_C.ck_w4a16_b_scale_zp_gemm) for AWQ
    checkpoints with per-group zero points (e.g. Qwen/Qwen3-4B-AWQ — the
    AC#3 model). Reuses the symmetric kernel via the identity
    (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale, with
    scaled_zp = (zp - 8) * scale precomputed once at weight load. Adds one
    fp16 fma per dequant pack — both variants share all tile sizing,
    scheduler, and threadmap config.

Build is gated on -DVLLM_CK_INCLUDE_DIR + -DVLLM_CK_BUILD_INCLUDE_DIR
(CK's ck/config.h is generated by CK's own configure step). Without those
flags csrc/rocm/ck_w4a16.cu is skipped and the dispatcher falls through to
Triton. VLLM_DISABLE_CK_W4A16=1 forces fall-through at runtime for A/B
benchmarking.

Dispatch decision lives inside the existing hybrid_w4a16_apply custom op
(extended with optional w_q_ck, ck_target_m, w_scaled_zp_ck args),
keeping it opaque to dynamo. CK-format weights are precomputed once in
process_weights_after_loading; the runtime M check is a plain int compare
against a per-layer cached value.

Caveats

  • Depends on six matching CK header changes that are NOT in this PR.
    vLLM consumes CK via include paths at build time; the new
    csrc/rocm/ck_w4a16.cu references a DequantPack8WithZp element-op
    and an optional p_b_zero_point arg on
    DeviceGemm_BScale_Wmma_CShuffleV3::MakeArgument that don't exist in
    upstream CK yet. The CK changes are additive and if constexpr-gated
    (zero impact on existing symmetric CK callers, verified bit-identical),
    but they need to land somewhere this build can find them before this PR
    will compile.

    Options being discussed: (a) open a companion PR against
    ROCm/composable_kernel; (b) vendor the modified CK headers under
    csrc/rocm/external/. Either way, this PR is draft/WIP awaiting the
    CK side
    .

    Build behavior in the meantime:

    • VLLM_CK_INCLUDE_DIR unset (default): csrc/rocm/ck_w4a16.cu is
      skipped, ops are unregistered, dispatcher falls through to Triton.
      Build succeeds; no behavior change vs gfx11 baseline.
    • VLLM_CK_INCLUDE_DIR set, pointing at a CK with the matching
      changes applied: kernel compiles, ops register, dispatch fires for
      the target shape.
    • VLLM_CK_INCLUDE_DIR set, pointing at upstream CK without the
      changes: compile error in csrc/rocm/ck_w4a16.cu.
  • Single-shape (column) specialization: only the Qwen3-4B gate_up_proj
    column (N=19456, K=2560) on gfx1151 routes to CK; the M dimension is
    validated at M=2048 (default chunked-prefill chunk size) and M=3968 (full
    prompt with --max-num-batched-tokens 4096), both ~30 TFLOPS standalone.
    Other columns (qkv_proj, o_proj, down_proj, other Qwen variants) stay on
    Triton — left for follow-up tuning. Per-layer dispatch carries a small
    list of validated M values (ck_target_ms) so adding shapes is a
    one-line change to _CK_W4A16_TARGET_SHAPES.

Test Plan

  1. Standalone CK kernel benchmark at the target shape, comparing against
    Triton W4A16 (via benchmarks/kernels/benchmark_hybrid_w4a16_gemm.py,
    edited to add batch_size=3968 + bf16 providers) and the hipBLASLt fp16
    TN roofline (hipblaslt-bench).
  2. Numerical correctness smoke tests: random fp16 weights → ExLlama
    pack → repack to CK layout → call new op → compare to torch fp16
    reference using the dequantized weights. Both symmetric and asymmetric
    variants.
  3. E2E A/B with VLLM_DISABLE_CK_W4A16 env-var toggle on the same
    vllm bench serve invocation:
    • Qwen/Qwen3-4B-AWQ (asymmetric, AC#3 model)
    • RedHatAI/Qwen3-4B-quantized.w4a16 (symmetric reference)
  4. lm_eval MMLU on Qwen3-4B-AWQ comparing CK kernel vs Triton baseline
    per AC#1 — TBD.
  5. Numerical equivalence at full vLLM model output (AC#3 verbatim:
    "produces the same output as with the triton kernel") — currently
    confirmed via the per-layer fp16 tolerance smoke tests; full-model
    output diff TBD.

Test Result

Measured on Strix Halo (gfx1151, Radeon 8060S) against ROCm 7.13.

Standalone GEMM (N=19456, K=2560, group=128, fp16, gfx1151)

At the target shape M=3968:

Path TFLOPS vs roofline vs Triton
hipBLASLt fp16 TN (roofline) 38.6 100%
Triton W4A16 (baseline) 15.6 41% 1.00×
CK b_scale (symmetric) 30.0 78% 1.92×
CK b_scale_zp (asymmetric) ~26.4 ~68% ~1.69×

The same kernel config holds across the M dimension on this column —
both M=2048 (chunked-prefill default chunk size) and M=3968 dispatch
to CK and hit ~30 TFLOPS:

M CK TFLOPS % of roofline
2048 (default chunk) 30.86 80%
3968 (full prompt @ chunk=4096) 30.00 78%
4096 30.10 78%

Asymmetric kernel is ~3.7% slower than symmetric per call — exactly the
design promise of one extra fp16 fma per dequant pack.

E2E (vllm bench serve, max_num_seqs=1, fp16, num_prompts=10)

Qwen/Qwen3-4B-AWQ (asymmetric, AC#3 model). The CK dispatch is
keyed per-layer by a min-M threshold (ck_min_m=256 for the gate_up
column, where standalone sweep shows the same kernel config holds
28-31 TFLOPS uniformly across M=256-16384). The default
chunked-prefill chunk (max_num_batched_tokens=2048)
now hits CK
on both prefill chunks (M=2048 and M=1920) — the prior
--max-num-batched-tokens 4096 workaround is not required:

Config TTFT mean TPOT
chunk=2048 (default), CK ON (min_m=256: M=2048+M=1920 both CK) 1922 ms 16.3
chunk=2048 (default), CK OFF (Triton) 2139 ms 17.1
chunk=4096 (override), CK ON (M=3968 single dispatch) 2166 ms 16.4
chunk=4096 (override), CK OFF (Triton) 2682 ms 17.5

The default chunk + CK ON config (1922 ms) wins both A/Bs:

  • vs VLLM_DISABLE_CK_W4A16=1 at the same chunk: −217 ms (−10%)
  • vs the prior chunk=4096 + CK ON config: −244 ms (−11%)
    splitting the prefill into two chunks (both CK) is faster than one
    big CK call. Per-dispatch profile shows Triton W4A16's qkv/o_proj/
    down_proj scale super-linearly with M, so chunked prefill saves
    ~270 ms of downstream Triton GEMM time per prompt by feeding those
    ops smaller M dimensions.

RedHatAI/Qwen3-4B-quantized.w4a16 (symmetric reference, prior
chunk=4096 measurement):

Config TTFT mean TPOT
chunk=4096, CK ON 2226 ms 16.0
chunk=4096, CK OFF 2630 ms 15.9
Delta −404 ms (−15%) ~0

Decode (TPOT) is unchanged in all A/Bs — CK only fires on prefill
chunks at the registered M values, leaving the decode path on the
existing skinny / Triton kernels.

Numerical correctness smoke tests

Variant Max abs / max ref
Symmetric (vs torch fp16 ref using dequantized weights) 5.4e−04
Asymmetric (vs torch fp16 ref using dequantized weights) 5.4e−04

Both well within the fp16 GEMM tolerance class used by
tests/kernels/quantization/.

MMLU regression check (AC#1)

TBD — will post lm_eval MMLU before/after on Qwen/Qwen3-4B-AWQ.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Adds csrc/rocm/ck_w4a16.cu wrapping CK's DeviceGemm_BScale_Wmma_CShuffleV3
as torch.ops._rocm_C.ck_w4a16_b_scale_gemm. Targets the Qwen3-4B gate_up_proj
prefill shape (M=3968, N=19456, K=2560, group=128, fp16) on Strix Halo, where
the Triton path leaves significant performance on the table.

Build is gated on -DVLLM_CK_INCLUDE_DIR + -DVLLM_CK_BUILD_INCLUDE_DIR (CK's
ck/config.h is generated by CK's own configure step, so both source and build
include dirs are needed). Without those flags csrc/rocm/ck_w4a16.cu is skipped
and the dispatcher falls through to Triton.

Dispatch lives inside the existing hybrid_w4a16_apply custom op (extended with
optional w_q_ck + ck_target_m args), keeping it opaque to dynamo. CK-format
weights are precomputed once in process_weights_after_loading; the runtime M
check is a plain int compare against a per-layer cached value. Set
VLLM_DISABLE_CK_W4A16=1 to force fall-through for A/B testing.

The vLLM ExLlama [N, K//8] int32 weight layout maps to CK's [K0, N, K1//2]
int8 via a single reshape + axis swap (no nibble re-packing). Scales pass
through unchanged.

Symmetric (uint4b8) only -- the kernel deliberately skips when zero-points
are present, so asymmetric AWQ checkpoints fall through to Triton until
follow-up work adds zero-point support.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Adds csrc/rocm/ck_w4a16.cu::ck_w4a16_b_scale_zp_gemm and wires the dispatch
in HybridW4A16LinearKernel so AWQ checkpoints with per-group zero points
(e.g. Qwen/Qwen3-4B-AWQ) reach the CK kernel. Symmetric callers are unchanged.

Implementation reuses the existing symmetric kernel via the identity
  (nibble - zp) * scale = (nibble - 8) * scale - (zp - 8) * scale
The caller precomputes scaled_zp = (zp - 8) * scale per group at weight load
and passes it to the new op; the CK kernel subtracts it inline during
dequant. All tile sizing, scheduler, and threadmap config is shared with
the symmetric path so future tuning benefits both kernels in one place.

Requires the matching CK header changes in
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp
which add an optional zero-point pointer threaded through to the dequant
inner loop. Build is gated on -DVLLM_CK_INCLUDE_DIR/-DVLLM_CK_BUILD_INCLUDE_DIR
exactly as before; without those flags the new op is skipped and asymmetric
callers stay on the Triton path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
… default)

Standalone sweep on the same EXP1_FINAL kernel config shows it holds
~30 TFLOPS on M=2048 (default chunked-prefill chunk size on the Qwen3-4B
gate_up_proj N=19456 K=2560 column), within ~3% of the M=3968 number it
was tuned for. Adds M=2048 to the per-layer target-M list so users no
longer need --max-num-batched-tokens 4096 to hit the CK path on that shape.

Generalizes the dispatch table from a single per-layer M to a small list,
threaded through the hybrid_w4a16_apply custom op as SymInt[]?. Membership
test stays inside the opaque custom op; runtime check is still a plain
Python int compare against a 1-2 element list per layer.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
…hreshold

Standalone sweep on the gate_up shape (N=19456, K=2560) shows the same
EXP1_FINAL kernel config holds 28-31 TFLOPS uniformly across M=256-16384,
so a min-M threshold is more accurate than enumerating measured values.
Below ~256 the kernel's fixed launch overhead (~0.4 ms) dominates and
Triton is comparable; the threshold avoids that range.

In particular this dispatches the M=1920 second-chunk case for
chunk=2048+prompt=3968 (which the discrete list missed and fell back
to Triton). E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048:

  before: 1987 ms TTFT (discrete list, M=1920 -> Triton fallback)
  after:  1922 ms TTFT (min_m=256, M=1920 -> CK)

  Per-layer kernel time at M=1920:
    Triton 8.48 ms -> CK 6.14 ms (-2.3 ms/layer * 36 layers = -84 ms;
    realized -65 ms after chunked-prefill overhead).

Also generalizes to arbitrary chunk sizes -- users no longer need to
match a specific M for CK to fire.

Custom op signature: SymInt[]? ck_target_ms -> SymInt ck_min_m=0.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
The same EXP1_FINAL kernel binary handles all four Qwen3-4B prefill
linear columns -- M/N/K are runtime args, only KPerBlock=32 is templated,
and all three additional K values (2560, 4096, 9728) are multiples of
both KPerBlock and Scale_Block_K=128. Standalone CK at the relevant
M (1920, 2048) and CPU verify pass:

  Layer       M     N     K     CK ms   Triton ms (profile)   delta/layer
  qkv         1920  6144  2560  1.96    2.70                  -0.74
  qkv         2048  6144  2560  2.17    2.87                  -0.70
  o_proj      1920  2560  4096  1.73    2.17                  -0.44
  o_proj      2048  2560  4096  1.95    2.14                  -0.19
  down_proj   1920  2560  9728  3.32    5.16                  -1.84
  down_proj   2048  2560  9728  3.76    5.27                  -1.51

E2E on Qwen/Qwen3-4B-AWQ at default chunk=2048 (5 reps each, num_prompts=10):

  gate_up only: TTFT mean 1966 ms, median 1954 ms
  all four:     TTFT mean 1894 ms, median 1868 ms

  -72 ms mean (-86 ms median); run-to-run noise ~+/-40 ms.

Each wired layer adds one CK-format weight copy (~0.92 GB total for
the four Qwen3-4B columns; roughly +6% of available memory on a
16 GB iGPU). Falls through to Triton if VLLM_DISABLE_CK_W4A16=1.

Signed-off-by: Marcus Rosen <marcus.rosen@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant