Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ stubs/

# Symlinks to compiled extensions
deep_gemm/*.so
deep_gemm/_C_build
deep_gemm/_C_build
sgl_deep_gemm/_C.so
sgl_deep_gemm/_C_build/
92 changes: 82 additions & 10 deletions csrc/apis/mega.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp"
#include "../jit_kernels/impls/sm100_mega_moe_pre_dispatch.hpp"
#include "../utils/math.hpp"
#include "../utils/system.hpp"

namespace deep_gemm::mega {

Expand All @@ -26,9 +29,33 @@ get_symm_buffer_size_for_mega_moe(
// Workspace bytes
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);

// Stream A0.0b: when `DG_USE_FP4_ACTS=1`, the symmetric `x` slot and the
// L1 token pool both hold packed E2M1 (FP4) instead of dense E4M3 (FP8).
// The per-token byte footprint halves; the SF slot is unchanged
// (`hidden/32` UE8M0 bytes — same `gran_k=32` for FP4 and FP8 acts under
// `kind::mxf8f6f4`). The host-side flag is read from the env so the
// existing `use_fp8_dispatch` API surface (which is hardcoded `true`
// throughout) doesn't need to change to opt in.
const bool host_use_fp4_acts = get_env<int>("DG_USE_FP4_ACTS") != 0;
const int input_token_bytes = host_use_fp4_acts ? (hidden / 2) : hidden;

// Stream B (combine path): when `DG_USE_FP8_COMBINE=1`, the combine slot
// holds FP8 E4M3 (kHidden bytes/token) + a separate combine_sf slot
// holding UE8M0 SF bytes (kHidden/128 bytes/token, gran_k=128). When off,
// the combine slot holds BF16 (kHidden*2 bytes/token) and combine_sf is
// unused (zero-sized).
const bool host_use_fp8_combine = get_env<int>("DG_USE_FP8_COMBINE") != 0;
constexpr int kCombineGranK = 128;
const int combine_token_bytes = host_use_fp8_combine ? hidden : (hidden * 2);
const int combine_sf_bytes_per_token = host_use_fp8_combine ? (hidden / kCombineGranK) : 0;

// Layouts
const auto fp8_token_layout = layout::Data(hidden);
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp8_token_layout = layout::Data(input_token_bytes);
const auto combine_token_layout = layout::Data(combine_token_bytes);
// SF layout: bytes/token may not be a multiple of 16 (e.g. hidden=7168 →
// 7168/128=56 bytes), so disable TMA alignment requirement (the writes
// are 1-byte stores via `sym_buffer.map`, not TMA).
const auto combine_sf_layout = layout::Data(combine_sf_bytes_per_token, false);
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto fp8_sf_layout = layout::Data(hidden / 32);
const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
Expand Down Expand Up @@ -79,22 +106,35 @@ get_symm_buffer_size_for_mega_moe(
fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
l2_token_buffer.get_end_ptr());

// Combine input buffer: BF16 tokens for cross-rank combine
// Combine input buffer: BF16 tokens (default) OR FP8 (when host_use_fp8_combine)
// for cross-rank combine.
const auto combine_token_buffer = layout::Buffer(
bf16_token_layout, num_topk, num_max_tokens_per_rank,
combine_token_layout, num_topk, num_max_tokens_per_rank,
l2_sf_buffer.get_end_ptr());
// Combine SF buffer: only sized when host_use_fp8_combine (otherwise zero).
// Layout matches combine_token_buffer's [num_topk][num_max_tokens_per_rank]
// outer shape, with kHidden/128 SF bytes per token.
const auto combine_sf_buffer = layout::Buffer(
combine_sf_layout, num_topk, num_max_tokens_per_rank,
combine_token_buffer.get_end_ptr());

// Check SF buffer requirements
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);

// Slice function: creates `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
// NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
// Stream A0.0b: under `host_use_fp4_acts`, the `x` and `l1_acts` views
// expose packed E2M1 (`kPackedFP4` = `torch::kInt8`, 2 elements/byte) of
// shape `[..., hidden / 2]`. Underlying buffer bytes are the same as the
// sized `fp8_token_layout` slot, just half the row width.
const auto x_dtype = host_use_fp4_acts ? kPackedFP4 : torch::kFloat8_e4m3fn;
const int x_inner_cols = host_use_fp4_acts ? (hidden / 2) : hidden;
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
{num_max_tokens_per_rank, x_inner_cols},
torch::TensorOptions().dtype(x_dtype).device(buffer.device()));
auto x_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
{num_max_tokens_per_rank, hidden / 128},
Expand All @@ -109,8 +149,8 @@ get_symm_buffer_size_for_mega_moe(
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
{num_max_pool_tokens, x_inner_cols},
torch::TensorOptions().dtype(x_dtype).device(buffer.device()));
auto l1_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
{num_max_padded_sf_pool_tokens, hidden / 128},
Expand All @@ -127,7 +167,7 @@ get_symm_buffer_size_for_mega_moe(
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
};
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
return {reinterpret_cast<int64_t>(combine_sf_buffer.get_end_ptr()), slice_input_buffers};
}

static void fp8_fp4_mega_moe(
Expand Down Expand Up @@ -200,6 +240,26 @@ static void fp8_fp4_mega_moe(
// Already registered tensors
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);

// Stream A0.1: pick up FP4-acts flag from `DG_USE_FP4_ACTS` env var.
// Default off — preserves byte-identical FP8-acts behavior. Setting
// `DG_USE_FP4_ACTS=1` flips L1's epilogue quant to E2M1 + UE8M0 SF.
const bool use_fp4_acts = get_env<int>("DG_USE_FP4_ACTS") != 0;
// Stream A0.5: when also `DG_USE_MXF4_KIND=1`, the L1 and L2 mainloops
// run `tcgen05.mma.kind::mxf4.block_scale.block32` instead of
// `kind::mxf8f6f4` — K=64 dense per call (vs K=32 with-padding), dense
// FP4 smem (`_ALIGN8B`, half the byte footprint), scale_vec::2X SF
// protocol with HALF-WORD address bits. Only honored when
// `DG_USE_FP4_ACTS=1` (kind::mxf4 is FP4-only). See A6 capstone /
// B2 standalone GEMM for the +20-22% headline.
const bool use_mxf4_kind = use_fp4_acts and get_env<int>("DG_USE_MXF4_KIND") != 0;
// Stream B (combine path): when `DG_USE_FP8_COMBINE=1`, the L2 epilogue
// ships FP8 E4M3 + per-(token, N=128) UE8M0 SF over NVLink instead of
// BF16. The combine reduce dequantizes on the fly. NVLink bytes/token
// halve (from kHidden*2 → kHidden + kHidden/128). Independent of the
// FP4-acts / MXF4-kind flags above (those control the dispatch a2a +
// mainloops; this controls the combine a2a only).
const bool use_fp8_combine = get_env<int>("DG_USE_FP8_COMBINE") != 0;

// Dispatch into different architectures
if (arch_major == 10) {
sm100_fp8_fp4_mega_moe(y,
Expand All @@ -213,7 +273,8 @@ static void fp8_fp4_mega_moe(
num_experts_per_rank,
num_tokens, num_topk,
hidden, intermediate_hidden,
activation_clamp, fast_math);
activation_clamp, fast_math,
use_fp4_acts, use_mxf4_kind, use_fp8_combine);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
Expand All @@ -230,6 +291,17 @@ static void register_apis(pybind11::module_& m) {
m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe);
m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe);
m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
m.def("mega_moe_pre_dispatch", &mega_moe_pre_dispatch,
pybind11::arg("x"),
pybind11::arg("topk_idx"),
pybind11::arg("topk_weights"),
pybind11::arg("buf_x"),
pybind11::arg("buf_x_sf"),
pybind11::arg("buf_topk_idx"),
pybind11::arg("buf_topk_weights"),
pybind11::arg("num_tokens"),
pybind11::arg("group_size") = 32,
pybind11::arg("use_fp4_acts") = false);
#endif
}

Expand Down
35 changes: 26 additions & 9 deletions csrc/jit_kernels/heuristics/mega_moe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,18 @@ struct MegaMoEConfig {
static std::tuple<int, int, int, int> get_block_config_for_mega_moe(
const int& num_ranks, const int& num_experts,
const int& num_max_tokens_per_rank, const int& num_topk,
const int& num_tokens) {
const int& num_tokens,
const bool& use_mxf4_kind = false) {
const auto& [cluster_size, block_m, store_block_m, num_epilogue_warpgroups] = [&]() -> std::tuple<int, int, int, int> {
float num_expected_tokens_per_expert = static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;
if (num_expected_tokens_per_expert <= 8.5) {
// Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m
return {2, 16, 8, 2};
// Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m.
// Under kind::mxf4, smem_a_per_stage = load_block_m * block_k / 2 must be a
// multiple of the 1024-byte smem alignment; load_block_m=8 (= block_m/2 for
// block_m=16) gives 512B which fails the static assert. Bump to block_m=32
// (load_block_m=16 → smem_a_per_stage=1024B) for the MXF4 path only.
return use_mxf4_kind ? std::tuple<int, int, int, int>{2, 32, 16, 2}
: std::tuple<int, int, int, int>{2, 16, 8, 2};
} else if (num_expected_tokens_per_expert <= 16.5) {
// Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128
return {2, 32, 16, 2};
Expand Down Expand Up @@ -127,7 +133,11 @@ static std::pair<int, int> get_pipeline_config_for_mega_moe(
const int& num_experts, const int& hidden,
const int& block_m, const int& block_n, const int& block_k, const int& store_block_m,
const int& sf_block_m, const int& sf_block_n,
const int& num_dispatch_warps, const int& num_epilogue_warps) {
const int& num_dispatch_warps, const int& num_epilogue_warps,
// Stream A0.5: under `use_mxf4_kind`, A and B smem use the dense FP4
// layout (`_ALIGN8B`, 2 nibbles/byte). Per-stage byte footprint halves
// for both A and B → num_stages doubles for the same smem budget.
const bool& use_mxf4_kind = false) {
constexpr int kSmemAlignment = 1024;
constexpr int kNumEpilogueStages = 2;
constexpr int kNumTMAStoreStages = 2;
Expand Down Expand Up @@ -162,8 +172,13 @@ static std::pair<int, int> get_pipeline_config_for_mega_moe(
const int smem_sfa_per_stage = sf_block_m * 4;
const int smem_sfb_per_stage = sf_block_n * 4;

// Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers
const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;
// Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers.
// Stream A0.5: dense FP4 (mxf4) halves both A and B byte footprints.
const int smem_a_per_stage = use_mxf4_kind ? (load_block_m * block_k / 2)
: (load_block_m * block_k);
const int smem_b_per_stage = use_mxf4_kind ? (block_n * block_k / 2)
: (block_n * block_k);
const int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;

// Fixed total
const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr;
Expand All @@ -179,10 +194,11 @@ static MegaMoEConfig get_mega_moe_config(
const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
const int& hidden, const int& intermediate_hidden,
const int& num_padded_sf_pool_tokens) {
const int& num_padded_sf_pool_tokens,
const bool& use_mxf4_kind = false) {
// Block config
const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] =
get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);
get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens, use_mxf4_kind);
const int block_n = 128;
const int block_k = 128;
const int load_block_m = block_m / 2;
Expand Down Expand Up @@ -210,7 +226,8 @@ static MegaMoEConfig get_mega_moe_config(
num_experts, hidden,
block_m, block_n, block_k, store_block_m,
sf_block_m, sf_block_n,
num_dispatch_threads / 32, num_epilogue_threads / 32);
num_dispatch_threads / 32, num_epilogue_threads / 32,
use_mxf4_kind);

const auto config = MegaMoEConfig {
block_m, block_n, block_k,
Expand Down
Loading
Loading