Cherry-pick FP4 acts + MXF4 + DG_USE_FP8_COMBINE; wrap mega_moe_pre_dispatch in tvm-ffi #31

Closed
Fridge003 wants to merge 5 commits into release-0426 from baizhou-w4a4


@Fridge003
Collaborator

Summary

Cherry-picks two FP4 mega-MoE features from dev-0426 (authored by @pranjalssh) onto release-0426, wraps the new public API with tvm-ffi, and adds a zero-numel-tensor fix uncovered while wiring this through sglang DP-attention.

Cherry-picked

Internal CUDA kernels are unchanged from those commits. Cherry-pick conflicts in csrc/jit_kernels/{heuristics,impls}/mega_moe*.hpp and tests/test_mega_moe.py were resolved by taking the dev-0426 side (kernel-side); csrc/apis/mega.hpp and deep_gemm/mega/__init__.py were merged manually so they remain consistent with the existing tvm-ffi calling convention on release-0426.

New on this branch

  1. tvm-ffi: wrap mega_moe_pre_dispatch and export from sgl_deep_gemm — adds a dg_mega_moe_pre_dispatch shim in csrc/tvm_ffi_api.cpp and registers it via TVM_FFI_DLL_EXPORT_TYPED_FUNC. Re-exports mega_moe_pre_dispatch from sgl_deep_gemm/__init__.py so the new kernel is reachable through both import deep_gemm and the sgl-deep-gemm wheel.
  2. tests: add sgl_deep_gemm smoke test — single-GPU sanity check for the tvm-ffi bound mega_moe_pre_dispatch + public-symbol audit.
  3. convert_to_torch_tensor: handle zero-numel tensors with empty CUDA tensor. torch::from_blob invokes getDeviceFromPtr() to verify that the data pointer lives on the requested device. For zero-numel tensors built with new_empty((0, K)) the data pointer can be nullptr, which fails the host-vs-device probe and throws tvm.error.InternalError: pointer resides on host memory from sglang's DP-attention forward_idle path. We fall through to torch::empty when numel == 0 || data == nullptr.
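The guard in item 3 can be modeled on the host side. The following is a minimal sketch, not the actual C++ in this PR: `TensorView` and `choose_conversion` are hypothetical names, and a `data_ptr` of 0 stands in for `nullptr`.

```python
from dataclasses import dataclass


@dataclass
class TensorView:
    """Hypothetical stand-in for the tensor view passed across tvm-ffi."""
    shape: tuple
    data_ptr: int  # 0 models a nullptr data pointer

    @property
    def numel(self) -> int:
        n = 1
        for dim in self.shape:
            n *= dim
        return n


def choose_conversion(view: TensorView) -> str:
    """Return the conversion path this sketch would take for `view`."""
    # Zero-element or null-pointer views skip the device-pointer probe
    # entirely and get a fresh empty tensor instead.
    if view.numel == 0 or view.data_ptr == 0:
        return "torch::empty"
    return "torch::from_blob"


# An idle DP-attention batch: shape (0, K) with a null data pointer.
print(choose_conversion(TensorView(shape=(0, 7168), data_ptr=0)))
```

Checking both conditions means a zero-numel tensor takes the fallback even when the allocator happened to hand back a non-null pointer.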

Tests

  • tests/test_mega_moe_pre_dispatch.py (cherry-picked) — passes for both FP8 and FP4 paths on 1×B200 (verified on baizhou-v4 devbox, CUDA 13.0).
  • tests/test_sgl_deep_gemm_smoke.py (new) — pass.

End-to-end validation (4×B200, sglang dg-w4a4 companion PR)

GSM8K 8-shot, n=1319, parallel=1319:

  • Accuracy: 0.951, latency 133.7s, throughput 938 tok/s (output).

bench_one_batch_server (input_len=1024, output_len=8) — see sgl-project/sglang#25052 for the full W4A4 vs W4A8 comparison table.

Co-authored-by: pranjalssh <adkz.photos@gmail.com>

🤖 Generated with Claude Code

pranjalssh and others added 5 commits May 12, 2026 03:01
…el (#27)

Two related additions for the DeepSeek-V4-Pro mega-MoE path:

1. **FP4 (E2M1) activations + `kind::mxf4` mainloop opt-in** for `fp8_fp4_mega_moe`.
   - `DG_USE_FP4_ACTS=1` halves the symm-buffer x-slot footprint (E2M1 nibbles
     vs E4M3 bytes); SF slot unchanged (still `hidden/32` UE8M0 bytes under
     gran_k=32).
   - `use_mxf4_kind=true` switches the L1+L2 mainloops to `cta_group::2 kind::mxf4`
     (2-CTA cluster) with dense FP4 smem layout (`_ALIGN8B`, 2 nibbles/byte).
     Per-stage A/B byte footprint halves → num_stages doubles for the same
     smem budget.
   - Threads `cumulative_local_expert_recv_stats` through the public mega-MoE
     API for per-rank expert counters used by sglang's expert-distribution
     recorder.
   - Block-m heuristic: under `use_mxf4_kind`, bumps `block_m=16 → 32` for the
     smallest-tokens-per-expert bucket so `load_block_m * block_k / 2` meets
     the 1024-byte smem alignment.
   - Multi-block_m support via `kCandidateBlockM` array + LCM-aligned pool
     padding; replaces the static `block_m=192` heuristic with token-density
     dispatch (8/16/32/64/96/128/192).

2. **`mega_moe_pre_dispatch` kernel**: BF16 → quant + topk-copy + pad-fill in
   one launch, gated on `kUseFp4Acts` + `kUsePDL`. Templated on
   `(kGroupSize, kUseFp4Acts, kUsePDL)`. Uses bucketize-style E2M1 encoder for
   byte-exact match against the `per_token_cast_to_fp4` host helper.
   - New: `deep_gemm.mega_moe_pre_dispatch(x, topk_idx, topk_weights, buf_x,
     buf_x_sf, buf_topk_idx, buf_topk_weights, num_tokens, group_size, use_fp4_acts)`
   - Test: `tests/test_mega_moe_pre_dispatch.py` — single-GPU bytewise check
     against host `per_token_cast_to_fp{8,4}` + pad-fill assertion.
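For readers unfamiliar with a bucketize-style E2M1 encoder, here is a minimal pure-Python sketch. It is not the kernel or the `per_token_cast_to_fp4` helper: the function names are hypothetical, and the tie-breaking at bucket boundaries, the handling of negative values that round to zero, and the low-nibble-first packing order are all assumptions.

```python
from bisect import bisect_left

# Magnitudes representable in E2M1, indexed by their 3-bit code.
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
# Midpoints between neighbouring magnitudes serve as bucket boundaries.
BOUNDARIES = tuple((a + b) / 2 for a, b in zip(E2M1_VALUES, E2M1_VALUES[1:]))


def encode_e2m1(x: float) -> int:
    """Encode a pre-scaled float as a 4-bit E2M1 code (sign bit is 0x8)."""
    mag = min(abs(x), 6.0)  # saturate at the largest magnitude
    # Assumption: a value exactly on a boundary falls into the lower bucket.
    code = bisect_left(BOUNDARIES, mag)
    # Assumption: negative inputs that round to zero encode as +0.
    if x < 0 and code != 0:
        code |= 0x8
    return code


def decode_e2m1(code: int) -> float:
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag


def pack_nibbles(codes) -> bytes:
    """Pack two 4-bit codes per byte, low nibble first (assumed order)."""
    assert len(codes) % 2 == 0
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


codes = [encode_e2m1(v) for v in (0.3, -1.0, 2.9, 100.0)]
print(codes, pack_nibbles(codes).hex())
```

A table lookup like this is deterministic on every input, which is what makes a bytewise comparison between a device encoder and a host reference feasible.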

Validated end-to-end on 8× B300 with DeepSeek-V4-Pro at 8K input bench:
- FP4 acts + MXF4 kind path produces matching tokens vs the FP8 baseline
  (rel-RMSE ≤ 0.5 sentinel; GSM8K accuracy parity within run-to-run variance).
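Item 1 above mentions LCM-aligned pool padding across the candidate `block_m` values. As a rough illustration only, the sketch below computes that LCM and rounds a token count up to it. The helper names and the exact padding rule are assumptions; the commit message does not spell out how the kernel applies the alignment.

```python
import math

# Candidate block_m values listed in the commit message.
CANDIDATE_BLOCK_M = (8, 16, 32, 64, 96, 128, 192)


def pool_alignment(candidates=CANDIDATE_BLOCK_M) -> int:
    """Smallest size that every candidate block_m divides evenly."""
    return math.lcm(*candidates)


def pad_pool_tokens(num_tokens: int, candidates=CANDIDATE_BLOCK_M) -> int:
    """Round num_tokens up to a multiple of the LCM (hypothetical rule)."""
    align = pool_alignment(candidates)
    return -(-num_tokens // align) * align  # ceiling division, then scale


print(pool_alignment(), pad_pool_tokens(1000))
```

Padding to the LCM means the padded pool splits into whole tiles whichever `block_m` the heuristic picks at dispatch time.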

PR also includes existing FP4-mega-MoE supporting changes that are required
by the kernel:
- `cluster_sync_with_relaxed_arrive` helper (used twice in `sm100_fp8_fp4_mega_moe.cuh`).
- `cvt_pack_f32_to_e2m1x2` / `cvt_pack_f32x4_to_e2m1x4` PTX wrappers.
- `SM100_MMA_MXF4_2x1SM_SS` 2-CTA cluster MMA wrapper.
- Generalized `red_add(int*, int)` for the `cumulative_local_expert_recv_stats`
  counter.
- `st.L1::no_allocate.relaxed.sys.global.u64` (correctness fix: previous
  generic-address variant could miss the global state space).

Co-authored-by: pranjalssh <adkz.photos@gmail.com>
(cherry picked from commit bca278e)
…bine path) (#28)

* Add DG_USE_FP8_COMBINE: FP8 + per-row UE8M0 SF on the second a2a (combine path)

The mega-MoE second all-to-all (combine) currently ships BF16 over NVLink:
kHidden * 2 bytes per token per topk slot. This commit adds an env-gated
FP8 path that ships FP8 E4M3 plus one UE8M0 SF byte per (token, N=128)
group: kHidden + kHidden/128 bytes per token per slot, roughly half the
NVLink bytes.
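The byte counts in this paragraph can be checked with a few lines of arithmetic. In this sketch the function name is hypothetical and `hidden=7168` is an illustrative value, not a figure taken from this PR.

```python
def combine_bytes_per_slot(hidden: int, use_fp8_combine: bool, sf_group: int = 128) -> int:
    """NVLink payload in bytes for one (token, topk slot) on the combine a2a."""
    if not use_fp8_combine:
        return hidden * 2  # BF16: 2 bytes per element
    assert hidden % sf_group == 0
    # FP8 E4M3: 1 byte per element, plus one UE8M0 SF byte per 128 elements.
    return hidden + hidden // sf_group


hidden = 7168  # illustrative hidden size (assumption)
bf16 = combine_bytes_per_slot(hidden, use_fp8_combine=False)
fp8 = combine_bytes_per_slot(hidden, use_fp8_combine=True)
print(bf16, fp8, fp8 / bf16)
```

The ratio works out to 1/2 + 1/256 for any hidden size divisible by 128, so the SF bytes add under 1% on top of the halved payload.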

Wiring:
- New `kUseFp8Combine` template flag (default false → keeps BF16 path
  byte-identical when off).
- New `combine_sf_buffer` symm-buffer slot, sized kHidden/128 bytes per
  (token, slot) when on, zero when off.
- Host: `DG_USE_FP8_COMBINE=1` env flag in `mega.hpp`. Independent of
  `DG_USE_FP4_ACTS` / `DG_USE_MXF4_KIND` (those control the dispatch a2a +
  mainloops; this controls the combine a2a only).

Producer side (L2 epilogue write-back, sm100_fp8_fp4_mega_moe.cuh):
- Read 8 BF16 from smem (existing STSM target).
- Compute per-row amax via `__shfl_xor_sync` reduction over the 16 lanes
  that share each row tile. Use a 16-lane mask (NOT 0xffffffff) — the
  outer `if (m_idx_in_block >= valid_m) break` may cause the OTHER half-
  warp to exit on padding rows, and a full-warp shfl would deadlock.
- Compute UE8M0 SF (E4M3 finfo_max=448, mirrors `get_e4m3_sf_and_sf_inv`).
- Cast 8 BF16 → 8 FP8 via `__nv_fp8x4_e4m3(float4)` ×2; pack into uint64.
- Write 8 FP8 bytes to remote (vs 16 BF16 bytes). Lane 0 of the 16-lane
  group writes the SF byte to `combine_sf_buffer`.
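The amax reduction and SF computation described above can be modeled on the host. The following is a sketch under stated assumptions, not the kernel code: the function names are hypothetical, the xor offsets (8, 4, 2, 1) are the usual butterfly pattern for 16 lanes, the `1e-4` amax floor is assumed, and nothing here claims to match `get_e4m3_sf_and_sf_inv` bit for bit.

```python
import math


def half_warp_amax(lane_vals):
    """Simulate a butterfly max-reduction within each 16-lane half of a warp."""
    assert len(lane_vals) == 32
    vals = [abs(v) for v in lane_vals]
    # Offsets below 16 only flip the low 4 bits of the lane id, so a lane
    # never exchanges data with the other half-warp.
    for offset in (8, 4, 2, 1):
        vals = [max(vals[lane], vals[lane ^ offset]) for lane in range(32)]
    return vals


def ue8m0_sf(amax: float, fmt_max: float = 448.0):
    """Smallest power-of-two scale with amax / sf <= fmt_max, and its UE8M0 byte."""
    amax = max(amax, 1e-4)  # assumed floor so that amax == 0 stays finite
    mant, exp = math.frexp(amax / fmt_max)  # ratio == mant * 2**exp, 0.5 <= mant < 1
    if mant == 0.5:
        exp -= 1  # ratio is an exact power of two
    byte = min(max(exp + 127, 0), 254)  # UE8M0 stores the exponent with bias 127
    return 2.0 ** (byte - 127), byte


row = [float(i) for i in range(32)]
reduced = half_warp_amax(row)
print(reduced[0], reduced[16], ue8m0_sf(reduced[16]))
```

Because the exchange partners in this model always sit in the same half-warp, a 16-lane participation mask covers every lane the shuffle touches.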

Consumer side (combine reduce):
- Per-slot SF base ptr cached at slot start.
- TMA-load FP8 chunk (kNumChunkBytes / 2 bytes when kUseFp8Combine).
- Per uint4 (16 FP8): __ldg the SF byte for the segment; FP8 → FP16x2
  via `cvt.rn.f16x2.e4m3x2`, FP16 → FP32 via `cvt.f32.f16`, then
  `__fmaf_rn(val, sf, acc)` for the accumulate-with-dequant.
- BF16 store-buffer layout for FP8 path: 2 BF16 uint4 per input uint4
  (16 elements → 2 × 8 BF16 stripes), at indices (j*32+lane)*2 + {0,1}.
  Total store uint4/lane same as BF16 path (kNumChunkUint4Bf16 / 32).
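The store-buffer indexing in the last bullet can be sanity-checked with a small model. This sketch only enumerates the index formula `(j*32+lane)*2 + {0,1}` quoted above; the function name and the per-lane iteration count are hypothetical.

```python
LANES = 32


def fp8_store_indices(num_input_uint4_per_lane: int) -> dict:
    """Map (j, lane, half) to the BF16 uint4 store index (j*32 + lane)*2 + half."""
    return {
        (j, lane, half): (j * LANES + lane) * 2 + half
        for j in range(num_input_uint4_per_lane)
        for lane in range(LANES)
        for half in (0, 1)
    }


idx = fp8_store_indices(2)
# Each FP8 uint4 (16 elements) expands into two adjacent BF16 uint4 stripes.
print(idx[(0, 0, 0)], idx[(0, 0, 1)], idx[(1, 3, 1)])
```

In this model the indices form a dense range with no gaps or collisions, and each lane issues two stores per input uint4.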

Validation:
- Microbench (`ptx/d_combine_reduce_v{1,2}_*`):
  - v1 BF16 baseline: 6,895 cycles/token, max_abs=0 (perfect).
  - v2 FP8 + UE8M0 SF: correctness PASS (max_abs=0 vs host reference
    that uses the same FP8 quant), 50% NVLink bytes savings.
- Single-GPU iso bench (8x B300, fp4_mxf4 vs fp4_mxf4+combine):
  - b=128:  364 us → 359 us (+1.5%)
  - b=512:  377 us → 386 us (-2.2%)
  - b=2048: 710 us → 739 us (-3.9%)
  Single-GPU is compute-bound, so there is no NVLink saving to measure
  here; the change targets multi-GPU production runs.
- E2E DeepSeek-V4-Pro on 8x B300 (b=8192 input, 1024 output):
  - b=512:  91.92 s (FP8) → 78.37 s (FP4+MXF4+FP8combine) — +17.3%
  - b=2048: 259.4 s (FP8) → 238.2 s — +8.9%
  - b=4096: 489.5 s (FP8) → 444.2 s — +10.2%
  Sentinel test (FP4 acts vs FP8 acts): rel-RMSE <= 0.5 still passes.
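The E2E percentages above appear to be computed as baseline time divided by new time, minus one. A short sketch that reproduces them from the reported wall-clock figures (the function name is hypothetical):

```python
def speedup_pct(baseline_s: float, new_s: float) -> float:
    """Speedup of `new_s` over `baseline_s`, in percent."""
    return (baseline_s / new_s - 1.0) * 100.0


# (FP8 baseline seconds, FP4+MXF4+FP8combine seconds) from the list above.
runs = {512: (91.92, 78.37), 2048: (259.4, 238.2), 4096: (489.5, 444.2)}
for batch, (fp8_s, fp4_s) in runs.items():
    print(f"b={batch}: +{speedup_pct(fp8_s, fp4_s):.1f}%")
# prints +17.3%, +8.9% and +10.2%
```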

Numerical: rel-RMSE on synthetic random init = 0.027 (combine FP8 vs
BF16 baseline, w/o SwiGLU clamping → tail outliers). Real activations
post-SwiGLU + topk-weighting are bounded; production accuracy parity
preserved (same GSM8K results as FP4 baseline).

* Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)

Switch the FP8 combine reduce inner loop from FP32 accumulator + scalar
fma to FP16x2 accumulator + hfma.f16x2. This cuts the op count from 5 to
2 per FP8x2 pair and halves the accumulator register pressure (94 regs
vs 138 regs).

Inner loop, before:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  cvt.f32.f16  ×2     (FP16 → FP32)
  fma.rn.f32   ×2     (acc += sf_f32 * f32_val)
  = 5 ops per FP8x2 (= 2 elements)

After:
  cvt.rn.f16x2.e4m3x2 (FP8x2 → FP16x2)
  fma.rn.f16x2        (acc_fp16x2 += sf_pair * f16x2)
  = 2 ops per FP8x2

SF in FP16: UE8M0 byte → 1.0 * 2^(byte-127), packed as FP16 with bias 15.
Out-of-range SFs (byte < 112 or > 142) clamp to 0 / FP16-max — production
activations post-SwiGLU + topk-weighting fit comfortably in FP16 range.
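The UE8M0-to-FP16 scale conversion can be sketched on the host as below. This is an illustration of the stated rule (bias 127 in, bias 15 out, clamp outside bytes 112 to 142), not the kernel code: the function names are hypothetical, and encoding byte 112 as the FP16 subnormal for 2^-15 is an assumption.

```python
import struct

FP16_MAX_BITS = 0x7BFF  # bit pattern of 65504.0, the largest finite FP16


def ue8m0_to_fp16_bits(byte: int) -> int:
    """Bit pattern of 2**(byte - 127) as FP16, clamped outside [112, 142]."""
    if byte < 112:
        return 0x0000  # below FP16 range: clamp to zero
    if byte > 142:
        return FP16_MAX_BITS  # above FP16 range: clamp to FP16 max
    biased = byte - 127 + 15  # re-bias the exponent from 127 to 15
    if biased == 0:
        return 0x0200  # 2**-15 is an FP16 subnormal (assumed handling)
    return biased << 10  # mantissa bits are zero for a pure power of two


def fp16_bits_to_float(bits: int) -> float:
    # struct's "e" format is IEEE 754 half precision.
    return struct.unpack("<e", struct.pack("<H", bits))[0]


for b in (111, 112, 127, 142, 143):
    print(b, fp16_bits_to_float(ue8m0_to_fp16_bits(b)))
```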

End cast: FP16x2 → __half22float2 → __float22bfloat162_rn for the gmem
write-back (BF16 output unchanged).

Microbench (`ptx/d_combine_reduce_v3_fp8_hfma`):
  v1 BF16 baseline: 6,895 cycles/token
  v2 FP8 + FP32 acc: 10,797 cycles/token (+57% vs v1)
  v3 FP8 + FP16 HFMA: **5,799 cycles/token (-16% vs v1, -46% vs v2)**

E2E DeepSeek-V4-Pro 8x B300, 8K input + 1024 output:
  | batch | FP4+MXF4 | combine FP32 | combine HFMA |
  |------:|---------:|-------------:|-------------:|
  | 512   | —        | 7,526        | 7,350        |
  | 2048  | 9,814    | 9,903        | **9,992**    |
  | 4096  | 10,418   | 10,622       | **10,699**   |

HFMA wins at 2048/4096; ~tie at 512. Worth keeping as the default.

Numerical: v3 microbench correctness max_abs=0.0625, rel_rmse=3.8e-4
vs the FP32 reference. Production activations: still within sentinel
tolerance (rel-RMSE ≤ 0.5 vs FP8 baseline).

* Revert "Combine reduce: HFMA path (FP16 accumulator + fma.f16x2)"

This reverts commit 48e8101.

---------

Co-authored-by: pranjalssh <adkz.photos@gmail.com>
(cherry picked from commit 8fc78b4)
…nsor

torch::from_blob() runs getDeviceFromPtr() to verify the data pointer is on
the requested device. For zero-numel tensors created via new_empty((0, K))
the data pointer can be nullptr, which fails the host-vs-device probe and
crashes with InternalError from sglang forward_idle / DP attention idle batch.

Fall through to torch::empty when numel == 0 or data == nullptr — the caller
must already guard zero-size dimensions in any kernel that reads through it.
@Fridge003 Fridge003 closed this May 12, 2026
@Fridge003 Fridge003 deleted the baizhou-w4a4 branch May 12, 2026 22:53
2 participants