diff --git a/csrc/apis/mega.hpp b/csrc/apis/mega.hpp index efc3a780d1..dd8ace1b44 100644 --- a/csrc/apis/mega.hpp +++ b/csrc/apis/mega.hpp @@ -8,6 +8,7 @@ #endif #include "../jit/device_runtime.hpp" #include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp" +#include "../jit_kernels/impls/sm90_fp8_mega_moe.hpp" namespace deep_gemm::mega { @@ -23,6 +24,15 @@ get_symm_buffer_size_for_mega_moe( const bool& use_fp8_dispatch, const std::string& activation) { DG_HOST_ASSERT(num_experts % num_ranks == 0); + // Architecture-dependent SF dtype for the user-facing tensor view: + // * SM100: per-32 UE8M0 packed 4-into-int (`torch::kInt`). + // * SM90 : per-128 channel float (`torch::kFloat32`). + // Both use the same number of bytes per token (hidden / 32), so the symmetric + // buffer layout is shared; only the slice view dtype changes. + const auto arch_major = device_runtime->get_arch_major(); + const bool is_sm90 = arch_major == 9; + const auto sf_dtype = is_sm90 ? torch::kFloat32 : torch::kInt; + // Workspace bytes const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); @@ -31,7 +41,16 @@ get_symm_buffer_size_for_mega_moe( const auto bf16_token_layout = layout::Data(hidden * 2); const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); const auto fp8_sf_layout = layout::Data(hidden / 32); - const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32); + // L2 acts SF granularity differs by arch: + // * SM100 packs 4 UE8M0 bytes per int along K, so each token uses + // `intermediate_hidden / 32` bytes (per-32 K). + // * SM90 stores per-64 K floats so that each L1 epilogue block (which + // produces 64 post-SwiGLU columns) can write its own SF independently + // without cross-CTA amax synchronisation; bytes per token become + // `intermediate_hidden / 64 * sizeof(float) = intermediate_hidden / 16`. + const int fp8_intermediate_sf_bytes_per_token = + is_sm90 ? 
(intermediate_hidden / 16) : (intermediate_hidden / 32); + const auto fp8_intermediate_sf_layout = layout::Data(fp8_intermediate_sf_bytes_per_token); const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); @@ -86,10 +105,14 @@ get_symm_buffer_size_for_mega_moe( // Check SF buffer requirements DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); - DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); + // SM100 packs 4 UE8M0 bytes per int along K, so the padded SF token count + // must be divisible by 4. SM90 stores per-128 floats and has no such constraint. + if (not is_sm90) + DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); // Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major + // Dtype is per-arch (see `sf_dtype` above): float on SM90, int (packed UE8M0) on SM100. 
auto slice_input_buffers = [=](const torch::Tensor& buffer) { auto x = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), @@ -98,7 +121,7 @@ get_symm_buffer_size_for_mega_moe( auto x_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), {num_max_tokens_per_rank, hidden / 128}, - torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + torch::TensorOptions().dtype(sf_dtype).device(buffer.device())); auto topk_idx = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), {num_max_tokens_per_rank, num_topk}, @@ -115,16 +138,16 @@ get_symm_buffer_size_for_mega_moe( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), {num_max_padded_sf_pool_tokens, hidden / 128}, {1, num_max_padded_sf_pool_tokens}, - torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + torch::TensorOptions().dtype(sf_dtype).device(buffer.device())); auto l2_acts = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), {num_max_pool_tokens, intermediate_hidden}, torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); auto l2_acts_sf = torch::from_blob( math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), - {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, + {num_max_padded_sf_pool_tokens, is_sm90 ? intermediate_hidden / 64 : intermediate_hidden / 128}, {1, num_max_padded_sf_pool_tokens}, - torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + torch::TensorOptions().dtype(sf_dtype).device(buffer.device())); return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); }; return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers}; @@ -224,11 +247,117 @@ static void fp8_fp4_mega_moe( sym_buffer.zero_(); } +// SM90 (Hopper) FP8 MegaMoE entry point. 
+// +// Mirrors `fp8_fp4_mega_moe` but expects FP8 (e4m3) weights with per-128 channel +// float scale factors. Top-level routing (which entry to call) is the caller's +// responsibility (see `deep_gemm/mega/__init__.py`). +static void fp8_mega_moe( + const torch::Tensor& y, + const std::tuple& l1_weights_tuple, + const std::tuple& l2_weights_tuple, + const std::optional& cumulative_local_expert_recv_stats, + const torch::Tensor& sym_buffer, + const std::vector& sym_buffer_ptrs, const int& rank_idx, + const int& num_max_tokens_per_rank, + const int& num_experts, const int& num_topk, + const std::tuple& recipe, + const std::string& activation, + const std::optional& activation_clamp_opt, + const bool& fast_math +) { + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; + + // Architecture check + const auto arch_major = device_runtime->get_arch_major(); + DG_HOST_ASSERT(arch_major == 9); + + // Config checks: SM90 uses block (128, 128) float SF for weights, + // per-token per-128-K float SF for activations. 
+ const auto num_tokens = static_cast(y.size(0)); + const auto [rm, rn, rk] = recipe; + DG_HOST_ASSERT(rm == 128 and rn == 128 and rk == 128); + DG_HOST_ASSERT(activation == "swiglu"); + + // Activation checks + const auto activation_clamp = + activation_clamp_opt.value_or(std::numeric_limits::infinity()); + DG_HOST_ASSERT(activation_clamp >= 0); + + // Tensor checks: SM90 weights must be FP8 e4m3, K-major + DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(l1_weights.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(l2_weights.scalar_type() == torch::kFloat8_e4m3fn); + const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = get_shape<3>(l1_weights); + const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = get_shape<3>(l2_weights); + DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank); + DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); + DG_HOST_ASSERT(hidden == hidden_); + DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); + DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); + + // Shape constraints required by the SM90 kernel: + // * Hidden dims must be multiples of 128 (per-128 SF + scheduler integer-tiling). + // * `l2_arrival_mask` is uint64, with one bit per L1-output N-block of size 64 in the + // intermediate dim, so `kNumL1BlockNs = intermediate_hidden / 64` must be ≤ 64. + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); + DG_HOST_ASSERT(intermediate_hidden / 64 <= 64); + + // Check weight SF layout (block (128, 128) float, MN-major; not TMA-loaded + // so no TMA-stride alignment is required, but we do require contiguity in + // the K-direction within each expert). 
+ constexpr int kGranMN = 128, kGranK = 128; + check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, + num_experts_per_rank, false, true, torch::kFloat); + check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, + num_experts_per_rank, false, true, torch::kFloat); + + // Check stats counter + if (cumulative_local_expert_recv_stats.has_value()) { + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); + } + + // Check buffer bytes + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts_ = num_experts_per_rank * num_ranks; + const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe( + num_ranks, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + true, activation); + DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); + DG_HOST_ASSERT(num_experts == num_experts_); + + // Already registered tensors + const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); + + sm90_fp8_mega_moe(y, + l1_acts, l1_acts_sf, + l2_acts, l2_acts_sf, + l1_weights, l2_weights, + l1_weights_sf, l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer_ptrs, + rank_idx, num_max_tokens_per_rank, + num_experts_per_rank, + num_tokens, num_topk, + hidden, intermediate_hidden, + activation_clamp, fast_math); + + if (get_env("DG_COMM_KERNEL_DEBUG")) + sym_buffer.zero_(); +} + static void register_apis(pybind11::module_& m) { #if DG_TENSORMAP_COMPATIBLE m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe); m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe); m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe); + m.def("fp8_mega_moe", &fp8_mega_moe); #endif } diff --git 
a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index 7d85a5f556..fb58a496ea 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -59,6 +59,12 @@ class Compiler { flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage"; if (get_env("DG_JIT_WITH_LINEINFO", 0)) flags += " -Xcompiler -rdynamic -lineinfo"; + // NOTES: `--device-debug` (-G) emits full device DWARF so that cuda-gdb + // can resolve `__device__` global variables / line numbers in JIT + // kernels. It DISABLES device-side optimization and will tank perf, so + // it is gated behind an explicit env var. + if (get_env("DG_JIT_WITH_DEVICE_DEBUG", 0)) + flags += " --device-debug"; } virtual ~Compiler() = default; diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index be3bc31c07..f073646507 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -120,8 +120,11 @@ using KernelHandle = CUfunction; using LaunchConfigHandle = CUlaunchConfig; using LaunchAttrHandle = CUlaunchAttribute; -// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4 -#if CUDA_VERSION >= 12040 +// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4. +// Define `DG_JIT_FORCE_LEGACY_LOAD` to force the older `cuModuleLoad` path +// (useful when building against a newer CUDA SDK but running with an older +// driver that lacks the `cuLibrary*` symbols). 
+#if CUDA_VERSION >= 12040 && !defined(DG_JIT_FORCE_LEGACY_LOAD) #define DG_JIT_USE_LIBRARY_ENUM_KERNELS DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels); diff --git a/csrc/jit_kernels/heuristics/mega_moe.hpp b/csrc/jit_kernels/heuristics/mega_moe.hpp index b1ba6bd70c..06b52b48d2 100644 --- a/csrc/jit_kernels/heuristics/mega_moe.hpp +++ b/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -9,6 +9,7 @@ #include "../../utils/math.hpp" #include "../../utils/system.hpp" #include "sm100.hpp" +#include "sm90.hpp" namespace deep_gemm { @@ -237,4 +238,206 @@ static MegaMoEConfig get_mega_moe_config( return config; } +// ============================================================================ +// SM90 (Hopper) MegaMoE configuration +// ---------------------------------------------------------------------------- +// SM90 differs from SM100 in: +// - No tensor memory (TMEM): WGMMA accumulators live in registers. +// - No FP4: weights are FP8 e4m3, scales are per-128 channel float. +// - No 2-CTA cluster MMA: TMA multicast cluster=2 may still be used. +// - SF for activations is float (not UE8M0 int): per-128 K for L1, per-64 K for L2 (not per-32). +// The kernel lives in `deep_gemm/impls/sm90_fp8_mega_moe.cuh`; this config is +// what the host runtime reads. +// ============================================================================ + +struct MegaMoESM90Config { + // Block tiling (no STORE_BLOCK_M / SF_BLOCK_M concept on SM90) + int block_m, block_n, block_k; + + // Cluster size for TMA multicast (1 or 2). Multicast is on A. + int cluster_size; + + // Pool capacity and SF-padded token count (SF is per-128 float on SM90) + int num_max_pool_tokens; + int num_padded_sf_pool_tokens; + + // Swizzle modes for TMA descriptors (acts/weights). Both are 128B on FP8 K-major. 
+ int swizzle_acts_mode, swizzle_weights_mode; + + // Number of experts to process per wave + int num_experts_per_wave; + + // Pipeline stages and shared memory + int num_stages, smem_size; + + // Thread layout: dispatch + non-epilogue (TMA) + epilogue (math) + int num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads; + + friend std::ostream& operator << (std::ostream& os, const MegaMoESM90Config& config) { + os << "MegaMoESM90Config(" + << "block_m=" << config.block_m << ", block_n=" << config.block_n << ", block_k=" << config.block_k + << ", cluster_size=" << config.cluster_size + << ", num_max_pool_tokens=" << config.num_max_pool_tokens + << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens + << ", swizzle_acts_mode=" << config.swizzle_acts_mode << ", swizzle_weights_mode=" << config.swizzle_weights_mode + << ", num_experts_per_wave=" << config.num_experts_per_wave + << ", num_stages=" << config.num_stages << ", smem_size=" << config.smem_size + << ", num_dispatch_threads=" << config.num_dispatch_threads + << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads + << ", num_epilogue_threads=" << config.num_epilogue_threads << ")"; + return os; + } +}; + +static std::tuple get_block_config_for_mega_moe_sm90( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& num_tokens) { + // Pick block_m and number of math (epilogue) warpgroups. WGMMA::M = 64 is + // the hard floor on Hopper, so each warpgroup needs at least 64 rows; + // i.e. (block_m / num_epilogue_warpgroups) >= 64. + // + // The 2-WG BLOCK_M=128 path lowers the number of CTAs, but on SM90 it also + // reduces pipeline depth and leaves large-batch fused L2/epilogue/combine + // throughput behind the legacy grouped-GEMM baseline. The 1-WG BLOCK_M=64 + // path has finer scheduling granularity and was the best default across the + // DeepSeek-V4-Flash batch sweep. 
+ constexpr int block_m = 64; + constexpr int num_epilogue_warpgroups = 1; + + DG_HOST_ASSERT(std::any_of( + layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs, + [=](const auto& candidate) { return candidate == block_m; }) + ); + return {block_m, num_epilogue_warpgroups * 128}; +} + +static int get_num_experts_per_wave_for_mega_moe_sm90( + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) { + // SM90 (Hopper) wave heuristic. + // + // The generic heuristic is useful in the middle of the block_m=64 band, but + // very sparse routing and large batches both do better as a single all-expert + // wave: sparse cases avoid extra L1->L2 wave transitions, while large cases + // keep enough work resident without fragmenting expert scheduling. + const float expected_tokens_per_expert = + static_cast(num_tokens) * num_topk / num_experts_per_rank; + if (block_m == 64 and (expected_tokens_per_expert < 1.0f or expected_tokens_per_expert > 4.0f)) { + return num_experts_per_rank; + } + return get_num_experts_per_wave_for_mega_moe( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n, num_sms); +} + +static std::pair get_pipeline_config_for_mega_moe_sm90( + const int& smem_capacity, + const int& num_experts, const int& hidden, + const int& block_m, const int& block_n, const int& block_k, + const int& num_dispatch_warps, const int& num_epilogue_warps) { + constexpr int kSmemAlignment = 1024; + + // Dispatch region (same as SM100) + const int smem_expert_count_size = align( + num_experts * static_cast(sizeof(uint32_t)), kSmemAlignment); + const int smem_send_buffers_size = align( + static_cast(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()), + kSmemAlignment); + const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size; + + // C/D output region: max of 
L1 FP8 (single-buffered, BLOCK_N/2 post-SwiGLU) + // and L2 BF16, then 1024-byte aligned (matches kernel's SMEM_CD_SIZE). + const auto num_epilogue_warpgroups = num_epilogue_warps / 4; + const int smem_cd_l1 = num_epilogue_warpgroups * block_m * (block_n / 2); // 1 byte/elem (FP8) + const int smem_cd_l2 = num_epilogue_warpgroups * block_m * block_n * static_cast(sizeof(nv_bfloat16)); + const int smem_cd = align(std::max(smem_cd_l1, smem_cd_l2), kSmemAlignment); + + // SF on SM90: + // * SFA per stage must hold the larger of L1 (BLOCK_M floats, per-128 K) + // and L2 (2 * BLOCK_M floats, per-64 K), aligned to 128 bytes + // * SFB is loaded directly from global by the math warpgroup (block-(128,128) + // weight quantization), so no SMEM is reserved for it. + const int smem_sfa_per_stage = align(2 * block_m * static_cast(sizeof(float)), 128); + const int smem_sfb_per_stage = 0; + + // Per-stage: A tile + B tile + SFA tile + SFB tile + const int smem_per_stage = block_m * block_k + block_n * block_k + + smem_sfa_per_stage + smem_sfb_per_stage; + + // Barriers (8 bytes each): + // * dispatch: num_dispatch_warps + // * GEMM full + empty: 2 * num_stages + // * combine: 2 * num_epilogue_warps + const int smem_barriers_fixed = (num_dispatch_warps + 2 * num_epilogue_warps) * 8; + const int smem_barriers_per_stage = 2 * 8; + + // Fixed total + const int smem_fixed = smem_dispatch_size + smem_cd + smem_barriers_fixed; + + // Select max num_stages + const int num_stages = (smem_capacity - smem_fixed) / + (smem_per_stage + smem_barriers_per_stage); + DG_HOST_ASSERT(num_stages >= 2); + return {num_stages, + smem_fixed + num_stages * (smem_per_stage + smem_barriers_per_stage)}; +} + +static MegaMoESM90Config get_mega_moe_config_sm90( + const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, + const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const int& 
num_padded_sf_pool_tokens) { + const auto [block_m, num_epilogue_threads] = get_block_config_for_mega_moe_sm90( + num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); + const int block_n = 128; + const int block_k = 128; + // NOTES: cluster_size=1 for SM90 in this initial implementation. Cluster=2 + // multicast on A is feasible (each pair of CTAs shares m_block, splits N), + // but the SwiGLU/FP8-quantize epilogue would then need cross-CTA amax + // reduction so that one per-128 SF correctly covers both 64-col halves. + // We defer that optimisation; cluster=1 is correct and self-contained. + const int cluster_size = 1; + const int num_max_pool_tokens = layout::get_num_max_pool_tokens( + num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); + const int swizzle_acts_mode = 128; + const int swizzle_weights_mode = 128; + + const int num_sms = device_runtime->get_num_sms(); + const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe_sm90( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n, num_sms); + + const int num_dispatch_threads = 128; + const int num_non_epilogue_threads = 128; + + const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe_sm90( + SM90ArchSpec::smem_capacity, + num_experts, hidden, + block_m, block_n, block_k, + num_dispatch_threads / 32, num_epilogue_threads / 32); + + const auto config = MegaMoESM90Config { + block_m, block_n, block_k, + cluster_size, + num_max_pool_tokens, num_padded_sf_pool_tokens, + swizzle_acts_mode, swizzle_weights_mode, + num_experts_per_wave, + num_stages, smem_size, + num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads + }; + + if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { + const auto key = fmt::format( + "MegaMoESM90Config(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", + num_ranks, num_experts, hidden, 
intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk); + static std::unordered_set printed; + if (printed.count(key) == 0) { + std::cout << key << ": " << config << std::endl; + printed.insert(key); + } + } + return config; +} + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp new file mode 100644 index 0000000000..cc6d9f0bbf --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_fp8_mega_moe.hpp @@ -0,0 +1,235 @@ +#pragma once + +#include +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "runtime_utils.hpp" + +#include +#include + +#include "../heuristics/mega_moe.hpp" + +namespace deep_gemm { + +// ============================================================================ +// SM90 (Hopper) FP8 MegaMoE host runtime +// ---------------------------------------------------------------------------- +// This is the SM90 counterpart of `SM100FP8FP4MegaMoERuntime`. The kernel +// itself lives in `deep_gemm/impls/sm90_fp8_mega_moe.cuh`: its combine step +// is ported from the SM100 version, while the GEMM path (TMA load + WGMMA + +// epilogue) is Hopper-specific and keeps its accumulators in registers +// instead of tensor memory. +// +// Differences from SM100 path: +// * Activations and weights are both FP8 (e4m3); no FP4. +// * Scale factors (SF) are float (not UE8M0 int + per-32 UTCCP layout): per-128 K +// for L1 acts, per-64 K for L2 acts, and block (128, 128) for weights. +// * No tensor memory: WGMMA accumulators are register-resident. +// * Cluster size is at most 2 (TMA multicast on A); no 2-CTA UMMA. 
+// ============================================================================ + +class SM90FP8MegaMoERuntime final : public LaunchRuntime { +public: + struct Args { + // Templated arguments + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + float activation_clamp; + bool fast_math; + MegaMoESM90Config config; + + // Runtime arguments + void* y; + int* cumulative_local_expert_recv_stats; + int num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + // Tensormaps for activations and weights. Weight scale factors use + // block (128, 128) quantization and are loaded by the math warpgroup + // directly from global memory (no TMA descriptor required). + CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + const float* l1_weights_sf; + CUtensorMap tensor_map_l1_output; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + const float* l2_weights_sf; + + // Launch configs + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_mega_moe_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {}, {}, + {}, {}, + {}, + {} + >); +}}; +)", + args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_experts_per_wave, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.num_max_pool_tokens, + args.config.num_padded_sf_pool_tokens, + args.config.num_stages, + args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads, + args.launch_args.grid_dim.first, args.num_ranks, + to_string(args.activation_clamp), + args.fast_math ? 
"true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.y, + args.cumulative_local_expert_recv_stats, + args.num_tokens, + args.sym_buffer_ptrs, + args.tensor_map_l1_acts, + args.tensor_map_l1_acts_sf, + args.tensor_map_l1_weights, + args.l1_weights_sf, + args.tensor_map_l1_output, + args.tensor_map_l2_acts, + args.tensor_map_l2_acts_sf, + args.tensor_map_l2_weights, + args.l2_weights_sf + )); + } +}; + +static void sm90_fp8_mega_moe( + const torch::Tensor& y, + const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf, + const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf, + const torch::Tensor& l1_weights, const torch::Tensor& l2_weights, + const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf, + const std::optional cumulative_local_expert_recv_stats, + const std::vector& sym_buffer_ptrs, + const int& rank_idx, const int& num_max_tokens_per_rank, + const int& num_experts_per_rank, + const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const float& activation_clamp, + const bool& fast_math +) { + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts = num_experts_per_rank * num_ranks; + const auto num_padded_sf_pool_tokens = static_cast(l1_acts_sf.size(0)); + + // Heuristics + const auto config = get_mega_moe_config_sm90( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, + hidden, intermediate_hidden, num_padded_sf_pool_tokens); + + // Tensormap construction + // Acts/weights: standard 2D TMA descriptors (FP8 K-major). + // Activation SF: per-128 channel float for L1, per-64 for L2 (MN-major, no swizzle). + // Weight SF: block (128, 128) raw float pointer (no TMA descriptor). 
+ constexpr int kGranK = 128; + constexpr int kL2ActsSFGranK = 64; + const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, + hidden, config.num_max_pool_tokens, + config.block_k, config.block_m, + static_cast(l1_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, + config.num_padded_sf_pool_tokens, hidden, + config.block_m, kGranK, + 1, 0); + const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, + hidden, num_experts_per_rank * intermediate_hidden * 2, + config.block_k, config.block_n, + static_cast(l1_weights.stride(-2)), + config.swizzle_weights_mode); + // L1 output (post-SwiGLU FP8): N is halved. The SM90 epilogue writes this + // staging tile to SMEM as plain row-major bytes, so the TMA store descriptor + // must use no shared-memory swizzle. Later L2 TMA loads may still swizzle + // from this row-major global buffer into their own SMEM tile. + // The TMA store is issued *per warpgroup*, each writing a `WG_BLOCK_M` + // (= block_m / num_epilogue_warpgroups) row tile from its own SMEM offset. + // The descriptor outer-box dim therefore must be `WG_BLOCK_M`, not block_m. 
+ const int num_epilogue_warpgroups_h = config.num_epilogue_threads / 128; + const int wg_block_m = config.block_m / num_epilogue_warpgroups_h; + const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_n / 2, wg_block_m, + static_cast(l2_acts.stride(-2)), + 0); + const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_k, config.block_m, + static_cast(l2_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, + config.num_padded_sf_pool_tokens, intermediate_hidden, + config.block_m, kL2ActsSFGranK, + 1, 0); + const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, + intermediate_hidden, num_experts_per_rank * hidden, + config.block_k, config.block_n, + static_cast(l2_weights.stride(-2)), + config.swizzle_weights_mode); + + // Stats can be optional + int* cumulative_local_expert_recv_stats_ptr = nullptr; + if (cumulative_local_expert_recv_stats.has_value()) + cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr(); + + // Launch + const auto num_sms = device_runtime->get_num_sms(); + const SM90FP8MegaMoERuntime::Args args = { + .num_max_tokens_per_rank = num_max_tokens_per_rank, + .hidden = hidden, .intermediate_hidden = intermediate_hidden, + .num_experts = num_experts, .num_topk = num_topk, + .num_ranks = num_ranks, + .activation_clamp = activation_clamp, + .fast_math = fast_math, + .config = config, + .y = y.data_ptr(), + .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, + .num_tokens = num_tokens, + .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx), + .tensor_map_l1_acts = tensor_map_l1_acts, + .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf, + .tensor_map_l1_weights = tensor_map_l1_weights, + .l1_weights_sf = l1_weights_sf.data_ptr(), + .tensor_map_l1_output = 
tensor_map_l1_output, + .tensor_map_l2_acts = tensor_map_l2_acts, + .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf, + .tensor_map_l2_weights = tensor_map_l2_weights, + .l2_weights_sf = l2_weights_sf.data_ptr(), + .launch_args = LaunchArgs(num_sms, config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, + config.smem_size, config.cluster_size) + }; + const auto code = SM90FP8MegaMoERuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_mega_moe", code); + SM90FP8MegaMoERuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index a9542e2f44..66bc81a2c1 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -85,7 +85,9 @@ SymmBuffer, get_symm_buffer_for_mega_moe, transform_weights_for_mega_moe, + transform_weights_for_mega_moe_sm90, fp8_fp4_mega_moe, + fp8_mega_moe, ) # Some utils diff --git a/deep_gemm/include/deep_gemm/common/math.cuh b/deep_gemm/include/deep_gemm/common/math.cuh index 0f0d250481..a93ef04e01 100644 --- a/deep_gemm/include/deep_gemm/common/math.cuh +++ b/deep_gemm/include/deep_gemm/common/math.cuh @@ -62,6 +62,14 @@ CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) { #endif } +CUTLASS_DEVICE float2 mul2(const float2& a, const float2& b) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + return __fmul2_rn(a, b); +#else + return make_float2(__fmul_rn(a.x, b.x), __fmul_rn(a.y, b.y)); +#endif +} + CUTLASS_HOST_DEVICE float fast_rcp(const float& x) { float ret; asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x)); @@ -91,7 +99,7 @@ template CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) { DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0"); const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0}; - const auto scaled = __fmul2_rn(amax, finfo_factor); + const auto scaled = mul2(amax, finfo_factor); const auto exp_x = fast_log2_ceil(scaled.x); 
const auto exp_y = fast_log2_ceil(scaled.y); sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x); diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh new file mode 100644 index 0000000000..d11af9a356 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh @@ -0,0 +1,1285 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define __CLION_IDE__ + +namespace deep_gemm { + +// ============================================================================ +// SM90 (Hopper) FP8 MegaMoE — full implementation +// ---------------------------------------------------------------------------- +// Pipeline (cluster=1, no TMA multicast): +// * Dispatch warps: pull tokens (FP8) and SF (per-128 channel float) from +// remote ranks via NVLink into the local L1 pool. +// * GEMM TMA-load warps (1 for A+SFA, 1 for B+SFB) feed the pipeline stages. +// * Math warpgroups (1 or 2, totalling kNumEpilogueThreads) consume each +// stage with WGMMA, accumulate into registers, then run the epilogue: +// - L1 (Linear1): SwiGLU with gate/up granularity-8 interleaved layout, +// per-row amax over the 64 post-SwiGLU columns of this block, FP8 e4m3 +// quantize, STSM into SMEM, TMA store to local L1 output buffer. +// The per-row SF is written as a *float* into the L2-acts SF buffer at +// per-64 K granularity (one SF per L1 N block), so each block is fully +// self-contained and no cross-CTA amax synchronisation is needed. +// - L2 (Linear2): BF16 cast of the GEMM output, STSM into SMEM, then +// NVLink scatter to remote combine buffers. +// * After all GEMM blocks, the math warps run the COMBINE step (top-k +// reduction in BF16) — ported verbatim from the SM100 kernel. 
+// ============================================================================ + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm90_fp8_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const float* __restrict__ l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const float* __restrict__ l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and 
(__CUDA_ARCH__ >= 900) and (__CUDA_ARCH__ < 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // ===================================================================== + // Template checks + // ===================================================================== + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of GEMM TMA warps (4 warps expected)"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of math/epilogue threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + DG_STATIC_ASSERT(BLOCK_M % 64 == 0, "BLOCK_M must be a multiple of WGMMA::M (64)"); + DG_STATIC_ASSERT(BLOCK_N == 128, "BLOCK_N is fixed to 128 for this initial SM90 path"); + DG_STATIC_ASSERT(BLOCK_K == 128, "BLOCK_K is fixed to 128 (per-128 SF)"); + + // ===================================================================== + // Thread / warp identification + // ===================================================================== + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch all TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + } + + // ===================================================================== + // Workspaces and symmetric buffer slicing (mirror SM100 layout, except SF + // for L2 
activations uses per-64 K granularity) + // ===================================================================== + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + // Per-128 K float SF: 4 bytes per per-128 group => `kHidden / 32` bytes/token (same as SM100 packing) + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + // Per-64 K float SF (SM90 only): 4 bytes per per-64 group => `kIntermediateHidden / 16` bytes/token + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 16); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered input area + const auto input_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxTokensPerRank, workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer(fp8_sf_layout, 1, kNumMaxTokensPerRank, input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer(input_topk_idx_layout, 1, kNumMaxTokensPerRank, input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer(input_topk_weights_layout, 1, kNumMaxTokensPerRank, input_topk_idx_buffer.get_end_ptr()); + + // L1 input area + const auto l1_token_buffer = layout::Buffer(fp8_token_layout, 1, kNumMaxPoolTokens, input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer(fp8_sf_layout, 1, kNumPaddedSFPoolTokens, l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer(l1_topk_weights_layout, 1, kNumMaxPoolTokens, 
l1_sf_buffer.get_end_ptr()); + + // L2 input area + const auto l2_token_buffer = layout::Buffer(fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer(fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, l2_token_buffer.get_end_ptr()); + + // Combine input area + const auto combine_token_buffer = layout::Buffer(bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, l2_sf_buffer.get_end_ptr()); + + // ===================================================================== + // GEMM data types and shape constants + // ===================================================================== + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::float_e4m3_t; + using L1WGMMA = typename mma::sm90::FP8MMASelector::type; // M=64, N=128, K=32 + using L2WGMMA = typename mma::sm90::FP8MMASelector::type; + static_assert(L1WGMMA::M == 64 and L1WGMMA::N == BLOCK_N and L1WGMMA::K == 32, + "Unexpected WGMMA shape"); + + // Cluster=1 -> no multicast, A/B are loaded full-sized + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M; + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; // post-SwiGLU + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); // 128 + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); // 128 + constexpr uint32_t kSwizzleCDMode = 128; + constexpr uint32_t kGranK = 128; // L1 acts SF, weights SF + constexpr uint32_t kL2ActsSFGranK = 64; // L2 acts SF (per-64 K, SM90 only) + + // ===================================================================== + // Shared memory layout + // ===================================================================== + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr 
uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + // SFA per-stage must be sized for the larger of L1 (BLOCK_M floats) and L2 (2*BLOCK_M floats per-64). + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = + math::constexpr_align(2 * BLOCK_M * sizeof(float), 128u); + // Block (128, 128) weight SF: 1 float per (BLOCK_N, BLOCK_K) tile for L2, + // 2 floats (gate/up) for L1. Loaded by math warpgroup directly from global, + // so no SMEM is needed. + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = 0; + + // CD output: max of L1 FP8 (BLOCK_M * (BLOCK_N/2) * 1 byte * num_wg) and + // L2 BF16 (BLOCK_M * BLOCK_N * 2 bytes * num_wg). + constexpr uint32_t SMEM_CD_L1_SIZE = kNumEpilogueWarpgroups * BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t); + constexpr uint32_t SMEM_CD_L2_SIZE = kNumEpilogueWarpgroups * BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = math::constexpr_align( + SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE, kSharedMemoryAlignment); + + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + + // SMEM pointers + auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE); + + // CD output is shared by L1 (FP8) and L2 (BF16); reinterpret-cast as needed. 
+ auto smem_cd_l1 = reinterpret_cast(smem_gemm_base); + auto smem_cd_l2 = reinterpret_cast(smem_gemm_base); + + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + + // Barriers live after SF (SFB is loaded directly from global, no SMEM) + auto barrier_start_ptr = reinterpret_cast( + sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + i; }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + i; }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages + i; }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + i; }); + + // ===================================================================== + // Initialization + // ===================================================================== + if (warp_idx == 0) { + // Clean expert-count shared memory + #pragma unroll + for (uint32_t i = lane_idx; i < kNumExperts; i += 32) + ptx::st_shared(smem_expert_count + i, 0u); + } else if (warp_idx == 1) { + // Init dispatch m-barriers + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if 
(warp_idx == 2) { + // Init GEMM full/empty barriers and combine barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Two producer warps (A+SFA loader, B+SFB loader) each call + // `arrive_and_expect_tx` per stage, so init count must be 2. + full_barriers[i]->init(2); + // Each math warp arrives once per stage release. + empty_barriers[i]->init(kNumEpilogueWarps); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // ===================================================================== + // Scheduler (cluster=1) + // ===================================================================== + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, kNumExpertsPerWave, + kNumSMs, kNumRanks, /*kClusterSize=*/1u>(workspace); + + // Pipeline state shared by TMA loaders and math warpgroups + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM barrier indices (mirroring SM100) + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // Cross-rank NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Register reconfiguration counts (chosen to fit in 64512 reg budget). + // For the 256-epilogue-thread case (block_m=128, 2 math WGs): + // 128*48 + 128*40 + 256*208 = 64512 exactly. 
+ constexpr uint32_t kNumDispatchRegisters = 48; + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + // ===================================================================== + // ROLE 1: DISPATCH WARPS + // Mirrors SM100 dispatch with two changes: + // * SF is per-128 channel float (no UTCCP transpose). We store the + // remote per-token SF directly into the local L1 SF buffer in + // MN-major layout: `local_sf[k_chunk * num_padded_sf_pool_tokens + token_idx]`. + // * The "token_idx_in_expert" → SF token index is now the simple + // per-block linear mapping (no 4×32 transpose). + // ===================================================================== + if (warp_idx < kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + // Count tokens per expert + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + 
// Stake out per-expert SM offsets via global atomic + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source token-topk indices to remote ranks + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx; + }); + + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + false, true); + + // Sync with epilogue warps before pulling tokens + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Token / SF pull 
loop + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + scheduler.fetch_expert_recv_count(); + + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= kNumExpertsPerRank) + break; + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + if (current_expert_idx >= kNumExpertsPerRank) + break; + + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + stored_rank_count[i] = j < kNumRanks ? 
+ static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection (identical to SM100) + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const 
uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA pull token data into SMEM + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Copy SF: per-128 K floats, written linearly (no UTCCP transpose). + constexpr uint32_t kNumSFFloats = kHidden / 128; + DG_STATIC_ASSERT(kNumSFFloats > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const uint32_t sf_pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFFloats, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFFloats) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + if (cute::elect_one_sync()) { + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + 
workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + // Cleanup workspace, overlapping with combine + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + true, false); + + // ===================================================================== + // ROLE 2: GEMM TMA LOAD warps (load A+SFA, B+SFB) + // Warps inside `kNumNonEpilogueThreads` (= 4 warps): warp 0 loads + 
// A + SFA, warp 1 loads B + SFB, warps 2..3 idle. + // ===================================================================== + } else if (warp_idx == kNumDispatchWarps) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + // Wait for the pool to be ready + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + } else { + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // Each L1 N block sets one bit; total bits = L1_SHAPE_N / BLOCK_N. + constexpr uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N; + const uint64_t expected = (kNumL1BlockNs >= 64) + ? 
~0ull : ((1ull << kNumL1BlockNs) - 1ull); + while (ptx::ld_acq_gpu(ptr) != expected); + } + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + const uint32_t m_idx = pool_block_idx * BLOCK_M; + const uint32_t k_idx = k_block_idx * BLOCK_K; + + // TMA load A + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], + k_idx, m_idx, 1); + + // TMA load SFA + if (block_phase == sched::BlockPhase::Linear1) { + // L1 SFA per-128: load (BLOCK_M, 1) at K=k_block_idx + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx, 1); + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + BLOCK_M * sizeof(float)); + } else { + // L2 SFA per-64: descriptor box is (block_mn, 1) (see make_tma_sf_desc), + // so we must issue two single-group TMAs and place them at smem offsets + // 0 and BLOCK_M to match math's load offsets (`+ 0 * BLOCK_M` / `+ 1 * BLOCK_M`). + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], + m_idx, k_block_idx * 2, 1); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], + smem_sfa[stage_idx] + BLOCK_M, + m_idx, k_block_idx * 2 + 1, 1); + full_barriers[stage_idx]->arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + 2 * BLOCK_M * sizeof(float)); + } + } + __syncwarp(); + } + }); + + } else if (warp_idx == kNumDispatchWarps + 1) { + cutlass::arch::warpgroup_reg_dealloc(); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + + const uint32_t shape_n = block_phase == sched::BlockPhase::Linear2 ? 
L2_SHAPE_N : L1_SHAPE_N; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + empty_barriers[stage_idx]->wait(phase ^ 1); + + if (cute::elect_one_sync()) { + const uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + const uint32_t k_idx = k_block_idx * BLOCK_K; + + // TMA load B (weight SF is now loaded directly by math warps from global) + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], + k_idx, n_idx, 1); + + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE); + } + __syncwarp(); + } + }); + + } else if (warp_idx < kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Idle non-epilogue warps (kNumDispatchWarps+2, +3). They must still + // participate in the warpgroup-collective `setmaxnreg.dec.sync.aligned` + // so that the math warpgroup's `warpgroup_reg_alloc` can succeed. + cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // ===================================================================== + // ROLE 3: MATH WARPGROUPS (WGMMA + epilogue + combine) + // ===================================================================== + cutlass::arch::warpgroup_reg_alloc(); + + const uint32_t epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const uint32_t epilogue_wg_idx = epilogue_warp_idx / 4; + const uint32_t epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const uint32_t warp_idx_in_wg = epilogue_warp_idx % 4; + + // WGMMA-output register layout helpers + const uint32_t row_idx = lane_idx / 4; + const uint32_t col_idx = lane_idx % 4; + const uint32_t r_0 = warp_idx_in_wg * 16 + row_idx; + const uint32_t r_1 = r_0 + 8; + + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + DG_STATIC_ASSERT(WG_BLOCK_M == L1WGMMA::M, "Each warpgroup must run exactly one WGMMA per K-block"); + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid 
block M"); + + // Sync with dispatch + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const uint32_t valid_m = scheduler.template get_valid_m(); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + const uint32_t m_idx = pool_block_idx * BLOCK_M; + const uint32_t n_idx = n_block_idx * BLOCK_N; + const uint32_t row_offset_r0 = epilogue_wg_idx * WG_BLOCK_M + r_0; + const uint32_t row_offset_r1 = epilogue_wg_idx * WG_BLOCK_M + r_1; + const bool valid_r0 = row_offset_r0 < valid_m; + const bool valid_r1 = row_offset_r1 < valid_m; + + // ---------------- GEMM ---------------- + using WGMMA = L1WGMMA; + constexpr uint32_t kAccumPerThread = WGMMA::kNumAccum; // 64 for M=64,N=128 + float final_accum[kAccumPerThread] = {}; + float accum[kAccumPerThread]; + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + + // Read SF (must precede warpgroup_arrive) + float scale_a_0_lo, scale_a_1_lo; + float scale_a_0_hi, scale_a_1_hi; // Only used in L2 (per-64 K) + if (block_phase == sched::BlockPhase::Linear1) { + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + epilogue_wg_idx * WGMMA::M + r_0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + epilogue_wg_idx * WGMMA::M + r_1); + } else { + // L2: SFA layout is (K=2, M=BLOCK_M) MN-major; first half SF at offset 0, second at BLOCK_M + scale_a_0_lo = ptx::ld_shared(smem_sfa[stage_idx] + 0 * BLOCK_M + epilogue_wg_idx * WGMMA::M + r_0); + scale_a_1_lo = ptx::ld_shared(smem_sfa[stage_idx] + 0 * BLOCK_M + epilogue_wg_idx * WGMMA::M + r_1); + scale_a_0_hi = ptx::ld_shared(smem_sfa[stage_idx] + 1 * BLOCK_M + epilogue_wg_idx * WGMMA::M + r_0); + scale_a_1_hi = 
ptx::ld_shared(smem_sfa[stage_idx] + 1 * BLOCK_M + epilogue_wg_idx * WGMMA::M + r_1); + } + + // ----- Block (128, 128) weight SF (loaded directly from global) ----- + // L1 weight SF shape: (E, 2*IH/128, H/128), stored contiguously with + // the K-block index innermost (K-major): element (e, n, k) lives at + // `e * kL1SFPerExpert + n * kL1SFKBlocks + k`. The N axis is + // [gate(IH/128), up(IH/128)]; with the gate/up gran-8 interleave on + // the FP8 weight, each BLOCK_N=128 tile covers 64 rows of gate plus + // 64 rows of up taken from the same original 128-row block, so: + // gate_sf_n = n_block_idx / 2 + // up_sf_n = (IH/128) + n_block_idx / 2 + // + // L2 weight SF shape: (E, H/128, IH/128), likewise K-major. One scalar per + // (BLOCK_N, BLOCK_K) tile, broadcast across all WGMMA accumulators. + // + // NOTE: we tried hoisting these LDGs above the barrier wait and/or + // having only lane 0 load + shfl-broadcast. Both regressed on H20 + // by 7-11% across all batch sizes, presumably because (a) Hopper's + // L1 read-only cache already coalesces same-address LDGs from all + // 128 WG threads and (b) hoisting contended with the dispatch + // warps' NVLink LDGs on the MIO unit. Keep the simple parallel + // post-wait load. 
+ constexpr uint32_t kL1SFKBlocks = kHidden / 128; + constexpr uint32_t kL2SFKBlocks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFGateBlks = kIntermediateHidden / 128; + constexpr uint32_t kL1SFPerExpert = (kIntermediateHidden * 2 / 128) * kL1SFKBlocks; + constexpr uint32_t kL2SFPerExpert = (kHidden / 128) * kL2SFKBlocks; + float gate_sf = 0.0f, up_sf = 0.0f, l2_sf = 0.0f; + if (block_phase == sched::BlockPhase::Linear1) { + const uint32_t gate_n = n_block_idx / 2u; + const uint32_t up_n = kL1SFGateBlks + gate_n; + const float* base = l1_weights_sf + local_expert_idx * kL1SFPerExpert + k_block_idx; + gate_sf = __ldg(base + gate_n * kL1SFKBlocks); + up_sf = __ldg(base + up_n * kL1SFKBlocks); + } else { + l2_sf = __ldg(l2_weights_sf + local_expert_idx * kL2SFPerExpert + + n_block_idx * kL2SFKBlocks + k_block_idx); + } + + if (block_phase == sched::BlockPhase::Linear1) { + // Single per-128 K-block WGMMA group + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + epilogue_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + if (lane_idx == 0) + empty_barriers[stage_idx]->arrive(); + + // L1: gate/up alternate at gran=8 along N; each `i` block of 8 + // cols belongs entirely to one of {gate, up}, so .x and .y + // share the same scalar. + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + const float sb = (i & 1u) ? 
up_sf : gate_sf; + final_accum[i*4+0] += scale_a_0_lo * sb * accum[i*4+0]; + final_accum[i*4+1] += scale_a_0_lo * sb * accum[i*4+1]; + final_accum[i*4+2] += scale_a_1_lo * sb * accum[i*4+2]; + final_accum[i*4+3] += scale_a_1_lo * sb * accum[i*4+3]; + } + } else { + // L2: split BLOCK_K=128 into two halves (per-64 SFA), each 2 WGMMAs. + // First half: K=0..63, SFA = scale_a_*_lo + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + epilogue_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + // L2 first half: single scalar `l2_sf` broadcast across N. 
+ #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + final_accum[i*4+0] += scale_a_0_lo * l2_sf * accum[i*4+0]; + final_accum[i*4+1] += scale_a_0_lo * l2_sf * accum[i*4+1]; + final_accum[i*4+2] += scale_a_1_lo * l2_sf * accum[i*4+2]; + final_accum[i*4+3] += scale_a_1_lo * l2_sf * accum[i*4+3]; + } + + // Second half: K=64..127, SFA = scale_a_*_hi + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < (BLOCK_K / 2) / WGMMA::K; ++ k) { + const uint32_t k_off = (BLOCK_K / 2) + k * WGMMA::K; + auto desc_a = mma::sm90::make_smem_desc( + smem_a[stage_idx] + epilogue_wg_idx * WGMMA::M * BLOCK_K + k_off, 1); + auto desc_b = mma::sm90::make_smem_desc( + smem_b[stage_idx] + k_off, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread; ++ i) ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + if (lane_idx == 0) + empty_barriers[stage_idx]->arrive(); + + // L2 second half: same broadcast scalar `l2_sf`. 
+ #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 4; ++ i) { + final_accum[i*4+0] += scale_a_0_hi * l2_sf * accum[i*4+0]; + final_accum[i*4+1] += scale_a_0_hi * l2_sf * accum[i*4+1]; + final_accum[i*4+2] += scale_a_1_hi * l2_sf * accum[i*4+2]; + final_accum[i*4+3] += scale_a_1_hi * l2_sf * accum[i*4+3]; + } + } + } + + // Skip epilogue when block is past valid M (still must release via empty) + if (epilogue_wg_idx * WG_BLOCK_M >= valid_m) { + // Trigger any combine/sync logic minimally + if (block_phase == sched::BlockPhase::Linear1) + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + else + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + return; + } + + if (block_phase == sched::BlockPhase::Linear1) { + // ---------------- L1 EPILOGUE: SwiGLU + FP8 quantize + TMA store ---------------- + // Layout in `final_accum`: + // 16 chunks of 8 N-cols, each chunk = 4 floats per thread = (r0c0, r0c1, r1c0, r1c1). + // Gate chunks: even (0, 2, ..., 14). Up chunks: odd (1, 3, ..., 15). + // Pair `p` ∈ [0, 8): gate chunk = 2p, up chunk = 2p+1. + // + // For each pair we produce 4 post-SwiGLU floats per thread, mapped to + // output cols (p*8 + col_idx*2 + {0,1}) for both r0 and r1. + + constexpr uint32_t kNumPairs = kAccumPerThread / 8; // 8 for BLOCK_N=128 + float swiglu_r0[kNumPairs][2]; + float swiglu_r1[kNumPairs][2]; + + // Per-row amax across all 8 pairs + float amax_r0 = 0.0f, amax_r1 = 0.0f; + + // Compute SwiGLU + per-pair amax + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const uint32_t gate = 2 * p, up = 2 * p + 1; + + // Apply optional clamp on gate / up before SwiGLU + // Match SM100 reference: gate is clamped only on the upper + // side (very-negative gate is fine because SiLU(-inf) -> 0), + // while up is clamped both sides. 
+ auto clamp_gate = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(x, kActivationClamp); + }; + auto clamp_up = [](float& x) { + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) + x = cute::min(cute::max(x, -kActivationClamp), kActivationClamp); + }; + float g_r0_c0 = final_accum[gate*4 + 0]; clamp_gate(g_r0_c0); + float g_r0_c1 = final_accum[gate*4 + 1]; clamp_gate(g_r0_c1); + float g_r1_c0 = final_accum[gate*4 + 2]; clamp_gate(g_r1_c0); + float g_r1_c1 = final_accum[gate*4 + 3]; clamp_gate(g_r1_c1); + float u_r0_c0 = final_accum[up*4 + 0]; clamp_up(u_r0_c0); + float u_r0_c1 = final_accum[up*4 + 1]; clamp_up(u_r0_c1); + float u_r1_c0 = final_accum[up*4 + 2]; clamp_up(u_r1_c0); + float u_r1_c1 = final_accum[up*4 + 3]; clamp_up(u_r1_c1); + + // SiLU: x * sigmoid(x) = x / (1 + exp(-x)) + auto silu = [](float x) -> float { + const float e = kFastMath ? __expf(-x) : expf(-x); + const float sig = kFastMath ? math::fast_rcp(1.0f + e) : 1.0f / (1.0f + e); + return x * sig; + }; + + if (valid_r0) { + swiglu_r0[p][0] = silu(g_r0_c0) * u_r0_c0; + swiglu_r0[p][1] = silu(g_r0_c1) * u_r0_c1; + amax_r0 = cute::max(amax_r0, cute::max(cute::abs(swiglu_r0[p][0]), cute::abs(swiglu_r0[p][1]))); + } else { + swiglu_r0[p][0] = 0.0f; + swiglu_r0[p][1] = 0.0f; + } + if (valid_r1) { + swiglu_r1[p][0] = silu(g_r1_c0) * u_r1_c0; + swiglu_r1[p][1] = silu(g_r1_c1) * u_r1_c1; + amax_r1 = cute::max(amax_r1, cute::max(cute::abs(swiglu_r1[p][0]), cute::abs(swiglu_r1[p][1]))); + } else { + swiglu_r1[p][0] = 0.0f; + swiglu_r1[p][1] = 0.0f; + } + } + + // Apply token weight: SwiGLU * topk_weight (single load per row) + float weight_r0 = valid_r0 ? *l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r0) + .get_base_ptr() : 0.0f; + float weight_r1 = valid_r1 ? 
*l1_topk_weights_buffer + .get_data_buffer(m_idx + row_offset_r1) + .get_base_ptr() : 0.0f; + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + swiglu_r0[p][0] *= weight_r0; + swiglu_r0[p][1] *= weight_r0; + swiglu_r1[p][0] *= weight_r1; + swiglu_r1[p][1] *= weight_r1; + } + amax_r0 *= cute::abs(weight_r0); + amax_r1 *= cute::abs(weight_r1); + + // Reduce amax across the 4 col-lanes that share the same row. + // In WGMMA m64n128k32 output, the 4 lanes (`lane_idx & 3` differs, + // `lane_idx >> 2` same) hold all N positions for the same r_0/r_1, + // so we need an INTRA-group reduction (`xor 1, xor 2`), which is + // `warp_reduce<4, false>`. Using `<4, true>` would instead merge + // amax across 8 different rows -- giving wrong per-row SF. + amax_r0 = math::warp_reduce<4, false>(amax_r0, math::ReduceMax()); + amax_r1 = math::warp_reduce<4, false>(amax_r1, math::ReduceMax()); + + // Compute SF and inverse SF for each row + float sf_r0, sf_inv_r0; + float sf_r1, sf_inv_r1; + { + float2 amax_pair = {amax_r0, amax_r1}; + float2 sf_pair, sf_inv_pair; + math::get_e4m3_sf_and_sf_inv(amax_pair, sf_pair, sf_inv_pair); + sf_r0 = sf_pair.x; sf_inv_r0 = sf_inv_pair.x; + sf_r1 = sf_pair.y; sf_inv_r1 = sf_inv_pair.y; + } + + // Quantize and write to smem_cd_l1 (row-major, no swizzle). + // The L1-output TMA store descriptor is built with swizzle_mode = 0 + // to match this plain row-major SMEM staging tile. 
+ // + // Per pair `p`, each thread holds 4 FP8 values to write at: + // (row r_0, cols p*8 + col_idx*2 + {0,1}) -> packed as fp8x2 (2 bytes) + // (row r_1, cols p*8 + col_idx*2 + {0,1}) -> packed as fp8x2 (2 bytes) + auto* smem_cd_l1_wg = smem_cd_l1 + epilogue_wg_idx * WG_BLOCK_M * L1_OUT_BLOCK_N; + #pragma unroll + for (uint32_t p = 0; p < kNumPairs; ++ p) { + const float v00 = swiglu_r0[p][0] * sf_inv_r0; + const float v01 = swiglu_r0[p][1] * sf_inv_r0; + const float v10 = swiglu_r1[p][0] * sf_inv_r1; + const float v11 = swiglu_r1[p][1] * sf_inv_r1; + + const __nv_fp8x2_e4m3 r0_pair(make_float2(v00, v01)); + const __nv_fp8x2_e4m3 r1_pair(make_float2(v10, v11)); + + const uint32_t col = p * 8 + col_idx * 2; + auto* p0 = reinterpret_cast( + smem_cd_l1_wg + r_0 * L1_OUT_BLOCK_N + col); + auto* p1 = reinterpret_cast( + smem_cd_l1_wg + r_1 * L1_OUT_BLOCK_N + col); + if (valid_r0) + *p0 = r0_pair.__x; + if (valid_r1) + *p1 = r1_pair.__x; + } + + // Write SF as float at `[token, n_block_idx]` in L2 acts SF buffer (per-64 layout). + // Each row is contributed by lanes col_idx ∈ {0..3}; only col_idx == 0 writes. + if (col_idx == 0) { + auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + // SF buffer is (kNumPaddedSFPoolTokens × kIntermediateHidden/64), MN-major: + // addr[k_idx * num_padded_sf_pool_tokens + token_idx] + const uint32_t token_r0 = pool_block_idx * BLOCK_M + row_offset_r0; + const uint32_t token_r1 = pool_block_idx * BLOCK_M + row_offset_r1; + const uint32_t k_sf_idx = n_block_idx; // one per-64 SF per L1 block + if (valid_r0) + sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r0] = sf_r0; + if (valid_r1) + sf_base_ptr[k_sf_idx * kNumPaddedSFPoolTokens + token_r1] = sf_r1; + } + + // Sync the warpgroup before TMA store + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Issue TMA store of the entire tile. 
Padding rows beyond + // `valid_m` are written with stale/garbage FP8 to the L1-output + // pool buffer, but they are never consumed downstream: the L2 + // GEMM tile loads them, but its NVLink-scatter epilogue is + // gated by `m_idx_in_block >= valid_m`, and stale SF in the + // padding rows can produce NaN accumulators that simply stay + // in registers (only valid rows are converted to BF16 and + // stored into smem). Using TMA for partial tiles is a large + // win for low-batch / decode where every tile is partial. + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + const uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd_l1 + epilogue_wg_idx * WG_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + ptx::tma_store_wait<0>(); + + // Notify L2 that this N block's L1 output (and SF) is ready + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx); + } + __syncwarp(); + } else { + // ---------------- L2 EPILOGUE: BF16 cast + NVLink scatter ---------------- + constexpr uint32_t kNumRowsPerWarp = WG_BLOCK_M / 8; + + // Stage packed BF16 pairs in smem_cd_l2 with plain 32-bit shared stores. + // The tile is addressed row-major (`row * BLOCK_N + col`) with no swizzle, + // matching the linear per-row reads in the scatter loop below. + #pragma unroll + for (uint32_t i = 0; i < kAccumPerThread / 8; ++ i) { + // Each i consumes 8 floats (one 16x256b chunk in SM100 terms). 
+ // For SM90 WGMMA layout, 8 floats per i correspond to 2 chunks of 4 floats: + // final_accum[i*8 + (0..3)] = chunk 2i: (r0c0, r0c1, r1c0, r1c1) + // final_accum[i*8 + (4..7)] = chunk 2i+1: same shape + const uint32_t chunk_lo = 2 * i, chunk_hi = 2 * i + 1; + + // Write to SMEM at appropriate position + // Row r_0 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r0_lo + // Row r_0 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r0_hi + // Row r_1 cols [chunk_lo*8 + col_idx*2, chunk_lo*8 + col_idx*2 + 1] = r1_lo + // Row r_1 cols [chunk_hi*8 + col_idx*2, chunk_hi*8 + col_idx*2 + 1] = r1_hi + auto write_pair = [&](uint32_t row, uint32_t col, uint32_t packed) { + auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * WG_BLOCK_M * BLOCK_N + + row * BLOCK_N + + col; + // BF16 STS: 2 bf16 elements + *reinterpret_cast(smem_ptr) = packed; + }; + if (valid_r0) { + const uint32_t r0_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo*4 + 0], final_accum[chunk_lo*4 + 1]); + const uint32_t r0_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi*4 + 0], final_accum[chunk_hi*4 + 1]); + write_pair(r_0, chunk_lo * 8 + col_idx * 2, r0_lo); + write_pair(r_0, chunk_hi * 8 + col_idx * 2, r0_hi); + } + if (valid_r1) { + const uint32_t r1_lo = math::cast_into_bf16_and_pack( + final_accum[chunk_lo*4 + 2], final_accum[chunk_lo*4 + 3]); + const uint32_t r1_hi = math::cast_into_bf16_and_pack( + final_accum[chunk_hi*4 + 2], final_accum[chunk_hi*4 + 3]); + write_pair(r_1, chunk_lo * 8 + col_idx * 2, r1_lo); + write_pair(r_1, chunk_hi * 8 + col_idx * 2, r1_hi); + } + } + + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Scatter to remote ranks via NVLink (one row per warp-pair) + // Each warpgroup-warp covers 8 unique rows × 2 (r_0 + r_1 doubled by warps) + // Lane group of 16 within a warp → 1 row. 
+ const uint32_t row_in_warp_block = lane_idx / 16; // 0 or 1 + const uint32_t lane_in_row = lane_idx % 16; + const uint32_t cols_per_lane = BLOCK_N / 16; // 8 cols per lane + static_assert(BLOCK_N == 128, "Layout assumes BLOCK_N=128"); + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_wg = warp_idx_in_wg * 16 + j * 2 + row_in_warp_block; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + row_in_wg; + if (m_idx_in_block >= valid_m) break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + // Read 8 BF16s (= 16 bytes = 1 uint4) from smem + auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * WG_BLOCK_M * BLOCK_N + + row_in_wg * BLOCK_N + + lane_in_row * cols_per_lane; + const auto packed = *reinterpret_cast(smem_ptr); + + // Write to remote + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * sizeof(nv_bfloat16) + lane_in_row * sizeof(uint4)); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + }); + + // ---------------- COMBINE ---------------- + // NVLink barrier first: signals remote ranks that this rank's GEMM + // outputs (NVLink scatter targets) are fully written. + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + + // Sync with dispatch (paired with dispatch's pre-cleanup sync) so that + // dispatch may now safely clean workspace state. 
+ ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + constexpr uint32_t kNumChunks = + (kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE + and kHidden <= 32 * kNumMaxRegistersForBuffer) ? 1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr( + smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 
0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + const int stored_topk_slot_idx = lane_idx < kNumTopk ? + static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + bool do_reduce = move_mask_and_load(load_stage_idx); + + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + 
casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only supports sm_90"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh index cdbecccd56..a8bffbb19a 100644 --- a/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -22,6 +22,7 @@ template Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """SM90 (Hopper) variant of `transform_weights_for_mega_moe`. + + SM90 has no TMEM / UTCCP path, so the SF tensors are consumed directly by + WGMMA promote and don't need the 4x32 transpose. With block (128, 128) + weight quantization, weight SFs are read by the math warpgroup directly + from global memory in their natural ``(E, N/128, K/128)`` MN-major layout + and require no transformation. Only L1's gate/up FP8 weight interleave is + preserved. + """ + l1_fp8, l1_sf = l1_weights + # Reuse the gran-8 N interleave on the FP8 weight only; the block SF stays + # in its natural ``(E, 2*IH/128, H/128)`` layout (gate then up along N). 
+ def _interleave_one(t, gran: int = 8) -> torch.Tensor: + g, n, *rest = t.shape + half = n // 2 + gate = t[:, :half].reshape(g, half // gran, gran, *rest) + up = t[:, half:].reshape(g, half // gran, gran, *rest) + return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) + + return (_interleave_one(l1_fp8), l1_sf), l2_weights + + def fp8_fp4_mega_moe(y: torch.Tensor, l1_weights: Tuple[torch.Tensor, torch.Tensor], l2_weights: Tuple[torch.Tensor, torch.Tensor], @@ -126,3 +152,33 @@ def fp8_fp4_mega_moe(y: torch.Tensor, activation, activation_clamp, fast_math ) + + +def fp8_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (128, 128, 128), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + """SM90 (Hopper) MegaMoE entry point. + + Expects FP8 e4m3 weights and block-(128, 128) float scale factors. The + weight SF layout matches the convention used by ``DeepSeekV4FlashFp8`` / + DeepEP, so the same SF tensors can be physically shared between the + DeepEP path and this kernel. + """ + _C.fp8_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) diff --git a/scripts/run_ncu_mega_moe_sm90.sh b/scripts/run_ncu_mega_moe_sm90.sh new file mode 100755 index 0000000000..e8c2b0ef61 --- /dev/null +++ b/scripts/run_ncu_mega_moe_sm90.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +# SM90 (Hopper) variant of run_ncu_mega_moe.sh +# Drives `tests/bench_mega_moe_sm90.py` with NCU, profiling the +# `sm90_fp8_mega_moe_impl` kernel for a single batch size. 
+ +set -e + +num_processes=8 +output_dir=work_sm90 +python_args=() +for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do + arg="${!arg_idx}" + case "$arg" in + --num-processes) + python_args+=("$arg") + if ((arg_idx < $#)); then + ((arg_idx++)) + num_processes="${!arg_idx}" + python_args+=("$num_processes") + fi + ;; + -h|--help) + echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]" + exit 0 + ;; + --num-processes=*) + num_processes="${arg#*=}" + python_args+=("$arg") + ;; + -o|--output) + if ((arg_idx < $#)); then + ((arg_idx++)) + output_dir="${!arg_idx}" + fi + ;; + --output=*) + output_dir="${arg#*=}" + ;; + *) + python_args+=("$arg") + ;; + esac +done + +echo "Python Args: ${python_args[*]}" +echo "Num Processes: $num_processes" +echo "Output Dir: $output_dir" +mkdir -p "$output_dir" + +export DG_JIT_WITH_LINEINFO=1 + +echo "Warm up JIT cache" +python tests/bench_mega_moe_sm90.py --ncu-profile-only "${python_args[@]}" + +sleep 2 + +ncu_args=( + --config-file off + --force-overwrite + --kernel-name sm90_fp8_mega_moe_impl + --import-source yes + --replay-mode application + --section SpeedOfLight + --section LaunchStats + --section SchedulerStats + --section WarpStateStats + --section MemoryWorkloadAnalysis + --section InstructionStats + --launch-skip 0 + --launch-count 1 + --clock-control none + --kill yes + --app-replay-buffer memory +) + +echo "Run Job" + +for ((i = 0; i < num_processes; ++i)); do + ncu ${ncu_args[@]} -o "${output_dir%/}/mega-moe-sm90.$i" \ + python tests/bench_mega_moe_sm90.py \ + --local-rank-idx=$i \ + --ncu-profile-only \ + "${python_args[@]}" & +done + +echo "Waiting" +wait +echo "Done" diff --git a/tests/bench_mega_moe_sm90.py b/tests/bench_mega_moe_sm90.py new file mode 100644 index 0000000000..9a6b17c0d3 --- /dev/null +++ b/tests/bench_mega_moe_sm90.py @@ -0,0 +1,213 @@ +"""SM90 (Hopper) MegaMoE benchmark / NCU-profile harness. 
+ +Mirrors ``tests/test_mega_moe.py``'s ``--ncu-profile-only`` / +``--local-rank-idx`` interface so the same ``scripts/run_ncu_mega_moe.sh`` +pattern can drive it for SM90. + +In normal (non-NCU) mode it sweeps a list of ``num_tokens`` values (default: +1, 2, 4, 8, 16, 32) and reports per-call kernel time via the same +``bench_kineto`` helper used by the SM100 perf test, plus a rough TFLOPS / +HBM GB/s figure useful for tracking optimisation deltas. +""" + +import argparse +import os +import random +import sys +import torch +import torch.distributed as dist +from typing import Tuple + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto, calc_diff, get_arch_major + + +def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) + sf = amax / 448.0 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +def _run_one_config(args, num_tokens, num_max_tokens_per_rank, + hidden, intermediate_hidden, + num_experts, num_topk, num_ranks, rank_idx, group, + activation_clamp, fast_math, + print_perf=True): + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + + # Symmetric buffer (one per config: cheaper to recreate than to keep max-size) + buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + ) + + # Inputs (bf16, then quantised) + x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, 
device='cuda') + l1_bf = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + l2_bf = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') + topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + if args.masked_ratio > 0: + rand_mask = torch.rand_like(topk_idx, dtype=torch.float) + topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) + topk_w.masked_fill_(topk_idx < 0, 0) + + x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, + use_packed_ue8m0=False) + l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) + l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf), + ) + + cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') + + # Stage inputs once; bench-loop re-copies them each call (bench helper expects + # an idempotent ``fn``). 
+ def run_fused(): + buffer.x[:num_tokens].copy_(x_fp8) + buffer.x_sf[:num_tokens].copy_(x_sf) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_w) + y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + deep_gemm.fp8_mega_moe( + y, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=activation_clamp, + fast_math=fast_math, + ) + return y + + if args.ncu_profile_only: + dist_print(f'[NCU] tokens={num_tokens} hidden={hidden} ih={intermediate_hidden}', + once_in_node=True) + run_fused() + torch.cuda.synchronize() + dist.barrier() + buffer.destroy() + return + + # Warm up + benchmark + run_fused() + dist.barrier() + t_fused = bench_kineto(run_fused, 'sm90_fp8_mega_moe', + barrier=lambda: dist.barrier(), + num_tests=args.num_tests, + suppress_kineto_output=True) + + # Count tokens that landed on this rank for stats + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | + (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1 + num_recv_tokens = (gathered_topk_idx != -1).sum().item() + + safe_div = lambda a, b: float('nan') if b == 0 else a / b + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) + num_touched_experts = max(0, torch.unique(gathered_topk_idx.flatten()).numel() - 1) + # FP8 weights = 1 byte, FP8 acts = 1 byte, BF16 output = 2 bytes + num_hbm_bytes = ( + num_touched_experts * intermediate_hidden * 2 * hidden + # L1 weights + num_touched_experts * hidden * intermediate_hidden + # L2 weights + num_recv_tokens * hidden + # L1 acts read + num_recv_tokens * intermediate_hidden + # L1 out write + num_recv_tokens * intermediate_hidden + # L2 acts read + num_recv_tokens * hidden * 2 # L2 out write + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + + if 
print_perf: + dist_print( + f' tokens={num_tokens:4d} recv={num_recv_tokens:5d} experts={num_touched_experts:4d} ' + f'{t_fused * 1e6:7.1f} us {tflops:6.1f} TFLOPS {hbm_gbs:6.0f} GB/s (rank{rank_idx})', + once_in_node=True, + ) + + dist.barrier() + buffer.destroy() + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + if get_arch_major() != 9: + dist_print(f'[SKIP] requires SM90, got SM{get_arch_major()}0', once_in_node=True) + dist.destroy_process_group() + return + + if args.batches is None: + batches = [1, 2, 4, 8, 16, 32] + else: + batches = args.batches + + dist_print( + f'SM90 MegaMoE bench: ranks={num_ranks} hidden={args.hidden} ' + f'ih={args.intermediate_hidden} experts={args.num_experts} topk={args.num_topk} ' + f'masked_ratio={args.masked_ratio} fast_math={bool(args.fast_math)}', + once_in_node=True, + ) + + # In NCU mode we run only one batch (the first one in `batches`) so that + # ncu's `--launch-count 1` is unambiguous. 
+ if args.ncu_profile_only: + batches = batches[:1] + + num_max_tokens_per_rank = max(batches) + for num_tokens in batches: + _run_one_config( + args, num_tokens, num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + num_ranks, rank_idx, group, + activation_clamp=args.activation_clamp, + fast_math=bool(args.fast_math), + ) + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='SM90 MegaMoE benchmark') + + parser.add_argument('--ncu-profile-only', action='store_true') + parser.add_argument('--num-processes', type=int, default=8) + parser.add_argument('--local-rank-idx', type=int, default=None) + + parser.add_argument('--batches', type=int, nargs='+', default=None, + help='List of num_tokens to sweep (default: 1 2 4 8 16 32)') + parser.add_argument('--hidden', type=int, default=7168) + parser.add_argument('--intermediate-hidden', type=int, default=2048) + parser.add_argument('--num-experts', type=int, default=256) + parser.add_argument('--num-topk', type=int, default=8) + parser.add_argument('--activation-clamp', type=float, default=10.0) + parser.add_argument('--masked-ratio', type=float, default=0.0) + parser.add_argument('--fast-math', type=int, default=1) + parser.add_argument('--num-tests', type=int, default=20) + + args = parser.parse_args() + + if args.local_rank_idx is not None: + test(args.local_rank_idx, args.num_processes, args) + else: + np = args.num_processes + torch.multiprocessing.spawn(test, args=(np, args), nprocs=np) diff --git a/tests/test_mega_moe.py b/tests/test_mega_moe.py index e74b65e5d1..83e8d622f7 100644 --- a/tests/test_mega_moe.py +++ b/tests/test_mega_moe.py @@ -151,7 +151,7 @@ def run_fused(): num_topk=num_topk, use_fp8_dispatch=True, explicitly_destroy=True, allow_multiple_reduction=False, - gpu_timeout_secs=10, cpu_timeout_secs=30 + num_gpu_timeout_secs=10, num_cpu_timeout_secs=30 ) if is_legacy_loaded else None 
def run_baseline(): diff --git a/tests/test_mega_moe_hopper.py b/tests/test_mega_moe_hopper.py new file mode 100644 index 0000000000..648e04ecf0 --- /dev/null +++ b/tests/test_mega_moe_hopper.py @@ -0,0 +1,895 @@ +""" +H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。 + +结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8: + * fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`), + 使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。 + * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine, + 使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation + per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA + 同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照, + 不是 bitwise apples-to-apples correctness oracle。 + * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us / + reduction us / `t_baseline / t_fused` legacy 比。 +""" + +import argparse +import math +import os +import random +import torch +import torch.distributed as dist +import triton +import triton.language as tl +from typing import Tuple + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto, get_arch_major + +try: + import deep_ep as _deep_ep + _deep_ep_import_error = None +except Exception as ex: + _deep_ep = None + _deep_ep_import_error = ex + + +# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名, +# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段 +SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl" + + +# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准 +FP8_E4M3_MAX = 448.0 +# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例, +# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。 +_FP8_E4M3_MAX_TL = tl.constexpr(448.0) +L1_ACT_SF_GRAN = 128 +FUSED_L2_ACT_SF_GRAN = 64 +BASELINE_L2_ACT_SF_GRAN = 128 +WEIGHT_SF_GRAN_MN = 128 +WEIGHT_SF_GRAN_K = 128 + + +# 
============================================================================ +# 模块 1:Triton SwiGLU + FP8 量化内核 +# ---------------------------------------------------------------------------- +# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按 +# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则, +# 避免再额外引入 exact-FP32-scale 差异。 +# 输入 x : (M, 2*H) bf16,内层是 [gate_part | up_part] +# 输入 topk_w : (M,) fp32,可选 +# 输出 y : (M, H) fp8_e4m3fn +# 输出 y_sf : (M, H/BLOCK_K) fp32 行主序 +# ============================================================================ + + +@triton.jit +def _swiglu_apply_weight_to_fp8_kernel( + x_ptr, + topk_w_ptr, + y_ptr, + y_sf_ptr, + M, + H, # 运行时形状 + stride_xm, + stride_xn, # x: (M, 2H) 的 stride + stride_ym, + stride_yn, # y: (M, H) 的 stride + stride_sfm, + stride_sfk, # y_sf: (M, H/BLOCK_K) 的 stride + clamp_value, # 当 HAS_CLAMP=False 时这个参数无意义 + HAS_TOPK: tl.constexpr, + HAS_CLAMP: tl.constexpr, + USE_UE8M0_SCALE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, # = num_per_channels +): + # 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列) + pid_m = tl.program_id(0) + pid_k = tl.program_id(1) + + # 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + # 当前 K-block 内的列索引(在 H 维度,不是 2H) + offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + # ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))---- + # 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的 + gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn + up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn + gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32) + up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32) + + # ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)---- + if HAS_CLAMP: + gate = tl.minimum(gate, clamp_value) + up = tl.minimum(tl.maximum(up, 
-clamp_value), clamp_value) + + # ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)---- + y = gate * tl.sigmoid(gate) * up + + # ---- 4) 可选 MoE 权重缩放(per-token 标量)---- + if HAS_TOPK: + w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0) + y = y * w[:, None] + + # ---- 5) 当前 K-block 内每行 absmax → scale ---- + amax = tl.max(tl.abs(y), axis=1) # (BLOCK_M,) + sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30) + if USE_UE8M0_SCALE: + # 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv: + # scale = 2 ** ceil(log2(amax / 448)). + sf = tl.exp2(tl.ceil(tl.log2(sf))) + + # ---- 6) 量化为 FP8 e4m3fn ---- + y_fp8 = (y / sf[:, None]).to(tl.float8e4nv) + + # ---- 7) 写回 y 和 sf ---- + y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn + tl.store(y_ptrs, y_fp8, mask=mask_m[:, None]) + + sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk + tl.store(sf_ptrs, sf, mask=mask_m) + + +def swiglu_apply_weight_to_fp8_triton( + x: torch.Tensor, + topk_weights: torch.Tensor | None, + clamp_value: float | None = None, + num_per_channels: int = BASELINE_L2_ACT_SF_GRAN, + use_ue8m0_scale: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """SwiGLU + FP8 量化。语义等价于 PyTorch reference: + gate, up = x[:, :H], x[:, H:] + y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w + y_sf = y.view(M, H/np, np).abs().amax(-1) / 448 + if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf) + y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8) + """ + assert x.is_cuda and x.dtype == torch.bfloat16 + assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位" + M, two_H = x.shape + H = two_H // 2 + assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍" + + y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device) + y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device) + + # BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调 + BLOCK_M = 16 + grid = (triton.cdiv(M, BLOCK_M), H // 
num_per_channels) + + # HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位 + topk_ptr = topk_weights if topk_weights is not None else x + + _swiglu_apply_weight_to_fp8_kernel[grid]( + x, + topk_ptr, + y, + y_sf, + M, + H, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + y_sf.stride(0), + y_sf.stride(1), + float(clamp_value) if clamp_value is not None else 0.0, + HAS_TOPK=topk_weights is not None, + HAS_CLAMP=clamp_value is not None, + USE_UE8M0_SCALE=use_ue8m0_scale, + BLOCK_M=BLOCK_M, + BLOCK_K=num_per_channels, + ) + return y, y_sf + + +# ============================================================================ +# 模块 2:grouped weight 的 (128, 128) FP8 块量化 +# ---------------------------------------------------------------------------- +# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定: +# 每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。 +# 与 SM100 FP4 路径的差异: +# * 不需要 deep_gemm.transform_sf_into_required_layout +# * SF 是 FP32,不是 UE8M0 packed +# ============================================================================ + + +def _quantize_grouped_fp8_block_128_128( + w: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。""" + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数" + + # 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + + # 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块 + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) # (G, N/128, K/128) + sf = amax / FP8_E4M3_MAX + + # 量化:每个元素除以所属子块的 sf 后转 FP8 + # sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +# ============================================================================ +# 模块 3:尝试导入 deep_ep(用于 dispatch / combine) +# 
============================================================================ + + +def _import_deep_ep(): + if _deep_ep is None: + dist_print(f"Failed to import deep_ep: {_deep_ep_import_error}", once_in_node=True) + return None + return _deep_ep + + +class _DeepEPHandle: + def __init__(self, raw_handle, psum_num_recv_tokens_per_expert: torch.Tensor): + self.raw_handle = raw_handle + self.psum_num_recv_tokens_per_expert = psum_num_recv_tokens_per_expert + + +class _DeepEPBufferCompat: + """Compatibility shim for newer DeepEP versions that expose Buffer, not ElasticBuffer.""" + + def __init__(self, deep_ep, group, num_nvl_bytes: int): + self.buffer = deep_ep.Buffer( + group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=0, + explicitly_destroy=True, + ) + + def dispatch( + self, + x, + *, + topk_idx, + topk_weights, + num_experts: int, + expert_alignment: int, + **_, + ): + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = ( + self.buffer.get_dispatch_layout(topk_idx, num_experts) + ) + recv_x, _, recv_topk_weights, num_recv_tokens_per_expert, raw_handle, event = self.buffer.dispatch( + x, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + topk_idx=topk_idx, + topk_weights=topk_weights, + expert_alignment=expert_alignment, + ) + psum = torch.tensor( + num_recv_tokens_per_expert, dtype=torch.int, device=topk_idx.device + ).cumsum(dim=0, dtype=torch.int) + return recv_x, None, recv_topk_weights, _DeepEPHandle(raw_handle, psum), event + + def combine(self, x, *, handle): + raw_handle = handle.raw_handle if isinstance(handle, _DeepEPHandle) else handle + return self.buffer.combine(x, handle=raw_handle) + + def barrier(self, use_comm_stream: bool = False): + torch.cuda.synchronize() + dist.barrier() + + def destroy(self): + self.buffer.destroy() + + +def _make_deep_ep_buffer(deep_ep, group, 
num_max_tokens_per_rank, hidden, num_topk, sym_buffer_bytes): + if hasattr(deep_ep, "ElasticBuffer"): + return deep_ep.ElasticBuffer( + group, + num_max_tokens_per_rank=num_max_tokens_per_rank, + hidden=hidden, + num_topk=num_topk, + use_fp8_dispatch=True, + explicitly_destroy=True, + allow_multiple_reduction=False, + ) + nvl_alignment = 2 * 1024 * 1024 + num_nvl_bytes = ((int(sym_buffer_bytes) + nvl_alignment - 1) // nvl_alignment) * nvl_alignment + return _DeepEPBufferCompat(deep_ep, group, num_nvl_bytes=num_nvl_bytes) + + +# ============================================================================ +# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖) +# ============================================================================ + + +def _bench_cuda_events( + fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0 +) -> float: + """返回 fn 的中位数耗时(秒)。""" + for _ in range(num_warmup): + fn() + torch.cuda.synchronize() + times_ms = [] + for _ in range(num_repeat): + # L2 flush,避免重复访问命中 cache 让测时偏低 + if l2_flush_gb > 0: + free_bytes, _ = torch.cuda.mem_get_info() + flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5)) + if flush_bytes >= 4: + torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_() + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + fn() + e.record() + e.synchronize() + times_ms.append(s.elapsed_time(e)) + times_ms.sort() + return times_ms[len(times_ms) // 2] / 1e3 + + +# ============================================================================ +# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline +# ============================================================================ + + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + # 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + if get_arch_major() != 9: + dist_print( + 
f"[SKIP] test_mega_moe_hopper requires SM90; got SM{get_arch_major()}0", + once_in_node=True, + ) + dist.destroy_process_group() + return + + # 形状参数(与 test_mega_moe.py 同名同义) + num_max_tokens_per_rank = args.num_max_tokens_per_rank + num_tokens = ( + max( + 0, + args.num_max_tokens_per_rank + - random.randint(0, args.num_max_removed_tokens), + ) + if args.num_tokens == 0 + else args.num_tokens + ) + hidden, intermediate_hidden = args.hidden, args.intermediate_hidden + num_experts, num_topk = args.num_experts, args.num_topk + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + assert num_experts % num_ranks == 0, ( + f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除" + ) + + # SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe): + # * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF) + # * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列) + assert hidden % 128 == 0 + assert intermediate_hidden % 128 == 0 + assert intermediate_hidden // 64 <= 64, ( + f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}" + ) + + # ---- 创建 BF16 输入:token 与两层 weight ---- + # x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维 + x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + # L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起) + l1_weights_bf16 = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, + device="cuda", + ) + # L2 weight: 每个 expert 把 intermediate_hidden → hidden + l2_weights_bf16 = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, + device="cuda", + ) + + # 路由:scores → topk_idx (M, K) + topk_weights (M, K) + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda") + topk_weights, topk_idx = torch.topk( + scores, num_topk, dim=-1, largest=True, sorted=False + ) + if args.masked_ratio > 0: + rand_mask = 
torch.rand_like(topk_idx, dtype=torch.float) + topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) + topk_weights.masked_fill_(topk_idx < 0, 0) + + # 累计接收统计:fused 与 baseline 各持一份避免相互覆盖 + cum_stats_fused = torch.zeros( + (num_experts_per_rank,), dtype=torch.int, device="cuda" + ) + cum_stats_baseline = cum_stats_fused.clone() + + # ---- BF16 → FP8 量化 ---- + # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序) + # 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF + x_fp8 = per_token_cast_to_fp8( + x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False + ) + + # weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF) + # baseline(DeepEP grouped GEMM)直接用这两个未变换的元组 + l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16) + l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16) + + # fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变 + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + l1_weights, l2_weights + ) + + # SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致) + clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None + run_baseline_enabled = args.run_baseline or bool(args.check_output_diff) + + # ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)---- + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + + # ---- 分配 fused 的 SymmBuffer 与输出 buffer ---- + sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + ) + y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + + def run_fused(): + # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时 + # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入 + sym_buffer.x[:num_tokens].copy_(x_fp8[0]) + 
sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1]) + sym_buffer.topk_idx[:num_tokens].copy_(topk_idx) + sym_buffer.topk_weights[:num_tokens].copy_(topk_weights) + + deep_gemm.fp8_mega_moe( + y_fused, + transformed_l1, + transformed_l2, + sym_buffer, + cumulative_local_expert_recv_stats=cum_stats_fused, + recipe=(128, 128, 128), + activation="swiglu", + activation_clamp=clamp_arg, + fast_math=bool(args.fast_math), + ) + return y_fused + + # ---- 打印 config ---- + dist_print("Config (H200 fused mega-MoE):", once_in_node=True) + dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True) + dist_print( + f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True + ) + dist_print( + f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})", + once_in_node=True, + ) + dist_print(f" > Masked ratio: {args.masked_ratio}", once_in_node=True) + dist_print( + f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, " + f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 " + f"(SM90 grouped GEMM constraint)", + once_in_node=True, + ) + dist_print( + f" > Baseline: {'enabled' if run_baseline_enabled else 'disabled'}", + once_in_node=True, + ) + dist_print( + f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True + ) + dist_print(once_in_node=True) + + # 与社区版 test_mega_moe.py 对齐:NCU 模式只跑 fused kernel,避免 baseline 噪声。 + if args.ncu_profile_only: + dist_print("Run fused SM90 mega-MoE kernel:", once_in_node=True) + y = run_fused() + torch.cuda.synchronize() + assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16 + dist_print(" > Done, exiting", once_in_node=True) + dist.barrier() + sym_buffer.destroy() + dist.destroy_process_group() + return + + # ---- 分配 DeepEP buffer(baseline 用)---- + deep_ep = _import_deep_ep() if run_baseline_enabled else None + ep_buffer = None + if deep_ep is not None: + ep_buffer = _make_deep_ep_buffer( + deep_ep, + group, + num_max_tokens_per_rank, + hidden, + 
num_topk, + sym_buffer.buffer.nbytes, + ) + + # ---------------------------------------------------------------- + # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine + # 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换** + # 的版本(baseline grouped GEMM 不需要 gate/up interleave) + # ---------------------------------------------------------------- + def run_baseline(): + recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch( + x_fp8, + topk_idx=topk_idx, + topk_weights=topk_weights, + cumulative_local_expert_recv_stats=cum_stats_baseline, + num_experts=num_experts, + expert_alignment=alignment, + do_cpu_sync=False, + do_handle_copy=False, + do_expand=True, + use_tma_aligned_col_major_sf=False, # SM90: row-major float SF + ) + n = recv_x[0].size(0) + + # L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接) + l1_y = torch.empty( + (n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda" + ) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + recv_x, + l1_weights, + l1_y, + handle.psum_num_recv_tokens_per_expert, + use_psum_layout=True, + disable_ue8m0_cast=True, + ) + + # Triton SwiGLU + FP8 量化(含 topk 权重乘法) + # 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K; + # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline + # 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。 + l1_y = swiglu_apply_weight_to_fp8_triton( + x=l1_y, + topk_weights=recv_topk_weights, + clamp_value=clamp_arg, + num_per_channels=BASELINE_L2_ACT_SF_GRAN, + use_ue8m0_scale=True, + ) + + # L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16 + l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda") + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + l1_y, + l2_weights, + l2_y, + handle.psum_num_recv_tokens_per_expert, + use_psum_layout=True, + disable_ue8m0_cast=True, + ) + + # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank + return ep_buffer.combine(l2_y, handle=handle)[0] + + # ---- 跑一次确保不报错(fused + 可选 baseline)---- + y = run_fused() + assert 
y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, ( + f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}" + ) + if ep_buffer is not None: + out_b = run_baseline() + assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, ( + f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}" + ) + if args.check_output_diff: + diff = (y.float() - out_b.float()).abs() + denom = out_b.float().abs().mean().clamp_min(1e-12) + dist_print( + "Output diff (fused vs legacy-per128 baseline):", once_in_node=True + ) + dist_print( + f" > max_abs={diff.max().item():.6e}, " + f"mean_abs={diff.mean().item():.6e}, " + f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}", + once_in_node=True, + ) + dist_print(once_in_node=True) + + # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ---- + # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目 + # 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。 + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + gathered_topk_idx[ + (gathered_topk_idx < rank_idx * num_experts_per_rank) + | (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank) + ] = -1 + local_expert_ids = gathered_topk_idx[gathered_topk_idx != -1] + num_recv_tokens = int(local_expert_ids.numel()) + num_touched_experts = int(torch.unique(local_expert_ids).numel()) + + # ---- benchmark ---- + # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead) + t_fused = bench_kineto( + run_fused, + SM90_KERNEL_NAME, + num_tests=args.num_bench_tests, + barrier=lambda: ep_buffer.barrier(use_comm_stream=False) + if ep_buffer is not None + else dist.barrier(), + trace_path=( + f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json" + if args.dump_profile_traces + else None + ), + ) + # baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events) + t_baseline = ( + _bench_cuda_events( + run_baseline, + num_warmup=args.num_warmup, + num_repeat=args.num_repeat, + l2_flush_gb=args.l2_flush_gb, + ) + if 
ep_buffer is not None + else 0.0 + ) + + def safe_div(a, b): + return float("nan") if b == 0 else a / b + + # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens + tflops = safe_div( + 2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused + ) + + # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同) + l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden + l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden + l1_weight_sf_bytes = ( + num_touched_experts + * (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN) + * (hidden // WEIGHT_SF_GRAN_K) + * 4 + ) + l2_weight_sf_bytes = ( + num_touched_experts + * (hidden // WEIGHT_SF_GRAN_MN) + * (intermediate_hidden // WEIGHT_SF_GRAN_K) + * 4 + ) + l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4 + l2_act_sf_bytes = ( + num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4 + ) + num_hbm_bytes = ( + l1_weight_bytes + + l2_weight_bytes # weights (FP8) + + l1_weight_sf_bytes + + l2_weight_sf_bytes # weight SF (FP32) + + num_recv_tokens * hidden + + l1_input_sf_bytes # L1 输入读 (FP8 + SF) + + num_recv_tokens * intermediate_hidden + + l2_act_sf_bytes # L1 输出写 (FP8 + SF) + + num_recv_tokens * intermediate_hidden + + l2_act_sf_bytes # L2 输入读 (FP8 + SF) + + num_recv_tokens * hidden * 2 # L2 输出写 (BF16) + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + + # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16 + num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2) + nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused) + + # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s) + t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12 + + # overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐 + approx_factor = t_fused / max(t_fused - t_reduction, 1e-12) + + # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline + tflops_baseline = safe_div( + 2 * num_recv_tokens * (hidden * 
intermediate_hidden * 3) / 1e12, t_baseline + ) + hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline) + nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline) + + def fmt_perf_line( + name: str, + t: float, + compute_tflops: float, + hbm_gbs_: float, + nvlink_gbs_: float, + reduction_us: float | None = None, + speedup: float | None = None, + ) -> str: + reduction = f"{reduction_us:13.1f}" if reduction_us is not None else f"{'-':>13}" + speedup_text = ( + f"{speedup:6.2f}x {'fused faster' if speedup > 1 else 'baseline faster'}" + if speedup is not None else + f"{'-':>21}" + ) + return ( + f" > {name:<10} {rank_idx:2d}/{num_ranks:<2d} " + f"{num_recv_tokens:12d} " + f"{num_touched_experts:14d} | " + f"{compute_tflops:15.0f} " + f"{hbm_gbs_:9.0f} " + f"{nvlink_gbs_:9.0f} " + f"{t * 1e6:9.0f} " + f"{reduction} " + f"{speedup_text}" + ) + + dist_print("Performance:", once_in_node=True) + dist_print( + " > kind EP recv_tokens active_experts | " + "compute(TFLOPS) HBM(GB/s) NVL(GB/s) time(us) reduction(us) speedup", + once_in_node=True, + ) + dist_print( + fmt_perf_line( + "[fused]", + t_fused, + tflops * approx_factor, + hbm_gbs * approx_factor, + nvlink_gbs * approx_factor, + reduction_us=t_reduction * 1e6, + ) + ) + if ep_buffer is not None: + speedup = safe_div(t_baseline, t_fused) + dist_print( + fmt_perf_line( + "[baseline]", + t_baseline, + tflops_baseline, + hbm_gbs_baseline, + nvlink_gbs_baseline, + speedup=speedup, + ) + ) + else: + reason = ( + "disabled; pass --run-baseline or --check-output-diff to compare" + if not run_baseline_enabled + else "deep_ep unavailable" + ) + dist_print(f" > [baseline] ({reason})", once_in_node=True) + + # ---- 清理 ---- + dist.barrier() + sym_buffer.destroy() + if ep_buffer is not None: + ep_buffer.destroy() + dist.destroy_process_group() + + +# ============================================================================ +# 模块 6:argparse + spawn +# 
# Command-line entry point of the benchmark script: parses the options below
# and runs `test` either in-process (one externally launched rank) or through
# `torch.multiprocessing.spawn` (one process per GPU).
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline"
    )

    # Resources
    parser.add_argument(
        "--ncu-profile-only",
        action="store_true",
        help="只运行一次 fused SM90 kernel,便于 NCU/Nsight 采样",
    )
    parser.add_argument(
        "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)"
    )
    parser.add_argument(
        "--local-rank-idx",
        type=int,
        default=None,
        help="单进程模式的 local rank;用于外部 launcher/NCU 分别启动每个 rank",
    )

    # Model shape
    # Note: the SM90 fused kernel requires intermediate_hidden <= 4096.
    parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192)
    parser.add_argument(
        "--num-tokens",
        type=int,
        default=0,
        help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank",
    )
    parser.add_argument(
        "--num-max-removed-tokens",
        type=int,
        default=0,
        help="num-tokens 为 0 时,每个 rank 随机移除的最大 token 数",
    )
    parser.add_argument("--hidden", type=int, default=7168)
    parser.add_argument(
        "--intermediate-hidden",
        type=int,
        default=3072,
        help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)",
    )
    parser.add_argument(
        "--activation-clamp",
        type=float,
        default=10.0,
        help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭",
    )
    parser.add_argument("--num-experts", type=int, default=384)
    parser.add_argument("--num-topk", type=int, default=6)
    parser.add_argument(
        "--masked-ratio",
        type=float,
        default=0.0,
        help="随机 mask 掉部分 topk expert selection,用于验证稀疏路由边界",
    )
    parser.add_argument(
        "--fast-math",
        type=int,
        default=1,
        help="fused 内 SwiGLU 是否启用 fast-math(0/1)",
    )

    # Timing
    parser.add_argument(
        "--num-bench-tests",
        type=int,
        default=30,
        help="bench_kineto 抓 fused 时的迭代数",
    )
    parser.add_argument(
        "--num-warmup", type=int, default=5, help="baseline cuda events warmup"
    )
    parser.add_argument(
        "--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代"
    )
    parser.add_argument(
        "--l2-flush-gb",
        type=float,
        default=8.0,
        help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭",
    )
    parser.add_argument(
        "--run-baseline",
        action="store_true",
        help="启用 DeepEP+grouped-FP8 legacy baseline;默认关闭以避免 full-size 默认配置触发 baseline kernel 非法访问",
    )
    parser.add_argument(
        "--check-output-diff",
        type=int,
        default=0,
        help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)",
    )
    parser.add_argument(
        "--dump-profile-traces",
        type=str,
        default="",
        help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)",
    )

    args = parser.parse_args()

    # Create the trace directory up front so every rank can write into it.
    if args.dump_profile_traces:
        os.makedirs(args.dump_profile_traces, exist_ok=True)

    if args.local_rank_idx is not None:
        # Single-process mode: an external launcher sets
        # MASTER_ADDR/PORT/WORLD_SIZE/RANK separately for each rank.
        test(args.local_rank_idx, args.num_processes, args)
    else:
        # Multi-process launch: one process per GPU; test() builds the NCCL
        # group internally via init_dist.
        torch.multiprocessing.spawn(
            test, args=(args.num_processes, args), nprocs=args.num_processes
        )
+ L5 Stress : ``--num-correctness-tests`` repeated random configs. + +Notes +----- +* The reference is a pure PyTorch BF16/FP32 simulation of the fused path + (dequantize -> matmul -> SwiGLU + clamp + per-row quantize -> matmul -> + cross-rank scatter -> BF16 reduce). It is *not* bitwise-identical to + the kernel; correctness is checked with ``calc_diff < 0.07``. +* Because every scenario allocates its own symmetric memory buffer we + re-`init_dist`/`destroy` once per process at the outer level only, + and re-create ``SymmBuffer`` per scenario. +* Skips itself when the device is not SM90. +""" + +import argparse +import math +import os +import random +import sys +import torch +import torch.distributed as dist +from typing import Tuple, List, Dict, Any + +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import calc_diff, get_arch_major + + +# ---------------------------------------------------------------------------- +# Quantization helpers +# ---------------------------------------------------------------------------- + +def _quantize_grouped_fp8_block_128_128(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Block (128, 128) FP8 quantization along (N, K). + + Args + ---- + w : (G, N, K) bf16, with N % 128 == 0 and K % 128 == 0 + + Returns + ------- + fp8 : (G, N, K) torch.float8_e4m3fn + sf : (G, N // 128, K // 128) torch.float32, MN-major in the (N, K) + plane (i.e. K is the inner contiguous dim, matching the kernel's + ``stride_k = 1`` expectation and the DeepEP convention). 
+ """ + g, n, k = w.shape + assert n % 128 == 0 and k % 128 == 0 + w_view = w.view(g, n // 128, 128, k // 128, 128).float() + amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4) # (G, N/128, K/128) + sf = amax / 448.0 + w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn) + return w_fp8.view(g, n, k).contiguous(), sf.contiguous() + + +def _dequant_block_128_128(w_fp8: torch.Tensor, sf: torch.Tensor) -> torch.Tensor: + """Inverse of `_quantize_grouped_fp8_block_128_128`. Returns fp32.""" + *prefix, n, k = w_fp8.shape + assert n % 128 == 0 and k % 128 == 0 + w_view = w_fp8.float().view(*prefix, n // 128, 128, k // 128, 128) + return (w_view * sf.unsqueeze(-1).unsqueeze(-3)).view(*prefix, n, k) + + +def _dequant_per_token_per_128_k(x_fp8: torch.Tensor, sf: torch.Tensor) -> torch.Tensor: + """For (M, K) fp8 with (M, K // 128) float SF (per-token, K-major).""" + m, k = x_fp8.shape + assert k % 128 == 0 + w_view = x_fp8.float().view(m, k // 128, 128) + return (w_view * sf.unsqueeze(-1)).view(m, k) + + +# ---------------------------------------------------------------------------- +# PyTorch reference +# ---------------------------------------------------------------------------- + +def _swiglu_fp32(gate_up: torch.Tensor, clamp: float) -> torch.Tensor: + """SwiGLU with one-sided gate clamp and two-sided up clamp. + + Matches the fused kernel: ``silu(min(gate, c)) * clamp(up, -c, c)``. 
+ """ + n2 = gate_up.size(-1) + half = n2 // 2 + gate, up = gate_up[..., :half], gate_up[..., half:] + if math.isfinite(clamp): + gate = gate.clamp(max=clamp) + up = up.clamp(min=-clamp, max=clamp) + return torch.nn.functional.silu(gate) * up + + +def _reference_fused( + x_fp8_local: torch.Tensor, x_sf_local: torch.Tensor, + topk_idx_local: torch.Tensor, topk_weights_local: torch.Tensor, + l1_w_fp8: torch.Tensor, l1_w_sf: torch.Tensor, + l2_w_fp8: torch.Tensor, l2_w_sf: torch.Tensor, + rank_idx: int, num_ranks: int, group: dist.ProcessGroup, + num_experts: int, num_topk: int, + hidden: int, intermediate_hidden: int, + activation_clamp: float, +) -> torch.Tensor: + """Reference: returns (num_tokens, hidden) bf16 result for *this* rank. + + All-gathers the global tokens / topk decisions / per-rank weights, then + for each global token routes through its topk experts, applies the + L1+SwiGLU+L2 path, and reduces over topk on the source rank. + """ + num_experts_per_rank = num_experts // num_ranks + + # --- gather global token data -------------------------------------------------- + x_fp8_g = uneven_all_gather(x_fp8_local, group=group) # (Mg, H) + x_sf_g = uneven_all_gather(x_sf_local, group=group) # (Mg, H/128) + topk_idx_g = uneven_all_gather(topk_idx_local, group=group) # (Mg, K) + topk_w_g = uneven_all_gather(topk_weights_local, group=group) # (Mg, K) + mg = x_fp8_g.size(0) + + # rank-id lookup for each gathered token (for combine routing) + rank_offsets = [0] + sizes = [torch.tensor([0], device='cuda')] # placeholder + # mimic uneven_all_gather to compute per-rank token counts + local_size = torch.tensor([x_fp8_local.size(0)], device='cuda', dtype=torch.long) + sizes_t = torch.empty(num_ranks, dtype=torch.long, device='cuda') + dist.all_gather_into_tensor(sizes_t, local_size, group=group) + sizes_list = sizes_t.tolist() + src_rank_of = torch.empty(mg, dtype=torch.long, device='cuda') + cur = 0 + for r, s in enumerate(sizes_list): + src_rank_of[cur:cur + s] = r + 
cur += s + assert cur == mg + + # --- gather all-rank weights -------------------------------------------------- + # l1_w_fp8: (E_pr, 2*IH, H), l1_w_sf: (E_pr, 2*IH, H/128) + l1_w_g = [torch.empty_like(l1_w_fp8) for _ in range(num_ranks)] + l1_sf_g = [torch.empty_like(l1_w_sf) for _ in range(num_ranks)] + l2_w_g = [torch.empty_like(l2_w_fp8) for _ in range(num_ranks)] + l2_sf_g = [torch.empty_like(l2_w_sf) for _ in range(num_ranks)] + dist.all_gather(l1_w_g, l1_w_fp8, group=group) + dist.all_gather(l1_sf_g, l1_w_sf, group=group) + dist.all_gather(l2_w_g, l2_w_fp8, group=group) + dist.all_gather(l2_sf_g, l2_w_sf, group=group) + l1_w_all = torch.stack(l1_w_g, dim=0) # (R, E_pr, 2*IH, H) + l1_sf_all = torch.stack(l1_sf_g, dim=0) + l2_w_all = torch.stack(l2_w_g, dim=0) + l2_sf_all = torch.stack(l2_sf_g, dim=0) + + # --- per-token / per-topk compute -------------------------------------------------- + # The combine slot tensor: (Mg, K, H) bf16 — each src rank will reduce over K. + combine_buf = torch.zeros(mg, num_topk, hidden, dtype=torch.float32, device='cuda') + + # Precompute dequantized x in fp32 + x_fp32 = _dequant_per_token_per_128_k(x_fp8_g, x_sf_g) # (Mg, H) + + # Iterate (cheap; reference is for small test configs only) + # Token-chunked to keep gathered (S, 2*IH, H) dequant tensors below GPU memory. 
+ _CHUNK = 256 + for k in range(num_topk): + # Skip masked + mask = topk_idx_g[:, k] >= 0 + if not mask.any(): + continue + sel_idx_full = mask.nonzero(as_tuple=False).squeeze(-1) # (S,) + for c0 in range(0, sel_idx_full.numel(), _CHUNK): + sel_idx = sel_idx_full[c0:c0 + _CHUNK] + eids = topk_idx_g[sel_idx, k] # (S,) + weights = topk_w_g[sel_idx, k] # (S,) + x_sel = x_fp32[sel_idx] # (S, H) + + dst_rank = (eids // num_experts_per_rank).long() + dst_local = (eids % num_experts_per_rank).long() + + # L1 GEMM (per-token): y = x @ W^T shape (S, 2*IH) + l1_w_sel = _dequant_block_128_128( + l1_w_all[dst_rank, dst_local], # (S, 2*IH, H) + l1_sf_all[dst_rank, dst_local], + ) + l1_y = torch.einsum('sk,snk->sn', x_sel, l1_w_sel) # (S, 2*IH) + del l1_w_sel + + # SwiGLU + clamp + multiply by topk weight + l1_y = _swiglu_fp32(l1_y, activation_clamp) * weights.unsqueeze(-1) # (S, IH) + + # Per-row, per-64-col FP8 quantize -> dequantize + s_, ih = l1_y.shape + assert ih == intermediate_hidden and ih % 64 == 0 + l1_view = l1_y.view(s_, ih // 64, 64) + amax = l1_view.abs().amax(dim=-1).clamp(1e-4) # (S, IH/64) + sf2 = amax / 448.0 + l1_q = (l1_view / sf2.unsqueeze(-1)).to(torch.float8_e4m3fn).float() + l2_in = (l1_q * sf2.unsqueeze(-1)).view(s_, ih) # (S, IH) fp32 + + # L2 GEMM + l2_w_sel = _dequant_block_128_128( + l2_w_all[dst_rank, dst_local], # (S, H, IH) + l2_sf_all[dst_rank, dst_local], + ) + l2_y = torch.einsum('sn,smn->sm', l2_in, l2_w_sel) # (S, H) + del l2_w_sel + + # Scatter to combine buffer (cast to bf16 then back to mimic kernel storage) + combine_buf[sel_idx, k] = l2_y.to(torch.bfloat16).float() + + # Sum over K -> (Mg, H), keep only this rank's slice + y_full_bf16 = combine_buf.to(torch.bfloat16).sum(dim=1).to(torch.bfloat16) # (Mg, H) + start = sum(sizes_list[:rank_idx]) + end = start + sizes_list[rank_idx] + return y_full_bf16[start:end].contiguous() + + +# ---------------------------------------------------------------------------- +# Single-scenario runner +# 
---------------------------------------------------------------------------- + +def _run_scenario( + name: str, + cfg: Dict[str, Any], + rank_idx: int, num_ranks: int, group: dist.ProcessGroup, + diff_tol: float, +): + num_max = cfg['num_max_tokens_per_rank'] + num_tokens = cfg.get('num_tokens', num_max) + hidden = cfg['hidden'] + intermediate_hidden = cfg['intermediate_hidden'] + num_experts = cfg['num_experts'] + num_topk = cfg['num_topk'] + masked_ratio = cfg.get('masked_ratio', 0.0) + activation_clamp = cfg.get('activation_clamp', 10.0) + fast_math = cfg.get('fast_math', True) + + assert num_experts % num_ranks == 0, f'{name}: experts {num_experts} not divisible by ranks {num_ranks}' + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max + assert hidden % 128 == 0 and intermediate_hidden % 128 == 0 + + verbose = bool(int(os.environ.get('DG_TEST_VERBOSE', '0'))) + def _trace(stage: str): + if verbose: + print(f'[rank{rank_idx}] {name} :: {stage}', flush=True) + + _trace('begin') + torch.manual_seed(rank_idx * 1000 + abs(hash(name)) % 1000) + random.seed(rank_idx * 1000 + abs(hash(name)) % 1000) + + # ---- Inputs (bf16) ------------------------------------------------------- + x_bf = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + l1_bf = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + l2_bf = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), + dtype=torch.bfloat16, device='cuda') * 0.05 + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') + topk_w, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + if masked_ratio > 0: + rand_mask = torch.rand_like(topk_idx, dtype=torch.float) + topk_idx.masked_fill_(rand_mask < masked_ratio, -1) + topk_w.masked_fill_(topk_idx < 0, 0) + + # Quantize x to FP8 with per-128 K float SF (SM90 format) + # Quantize x to FP8 with per-128 K 
float SF (SM90 format) + x_fp8, x_sf = per_token_cast_to_fp8(x_bf, use_ue8m0=False, gran_k=128, + use_packed_ue8m0=False) + # Quantize weights with block (128, 128) — matches DeepSeekV4FlashFp8 / DeepEP. + l1_w_fp8, l1_w_sf = _quantize_grouped_fp8_block_128_128(l1_bf) + l2_w_fp8, l2_w_sf = _quantize_grouped_fp8_block_128_128(l2_bf) + + # SM90 weight transform (gate/up interleave only). With block (128, 128) + # SF, the SF tensor is consumed by the kernel as-is — no MN-major TMA + # transform and no SF-side gate/up interleave is needed. + _trace('weight_transform') + transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90( + (l1_w_fp8, l1_w_sf), (l2_w_fp8, l2_w_sf) + ) + + # ---- Allocate symm buffer ----------------------------------------------- + _trace('alloc_symm_buffer') + buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max, num_topk, + hidden, intermediate_hidden, + ) + cum_stats = torch.zeros(num_experts_per_rank, dtype=torch.int, device='cuda') + + # ---- Run fused ----------------------------------------------------------- + _trace('copy_inputs') + buffer.x[:num_tokens].copy_(x_fp8) + buffer.x_sf[:num_tokens].copy_(x_sf) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_w) + + y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + _trace('launch_fused (may JIT-compile, can take minutes)') + deep_gemm.fp8_mega_moe( + y_fused, transformed_l1, transformed_l2, buffer, + cumulative_local_expert_recv_stats=cum_stats, + recipe=(128, 128, 128), + activation='swiglu', + activation_clamp=activation_clamp if math.isfinite(activation_clamp) else None, + fast_math=fast_math, + ) + _trace('sync_fused') + torch.cuda.synchronize() + _trace('fused_done') + + # ---- Reference & check --------------------------------------------------- + # Use the FP8 weights and their block-(128, 128) SF directly — the dequant + # helper expects this MN/K-block SF 
layout, and the original (gate||up) row + # ordering is what `_swiglu_fp32` splits with ``[..., :IH], [..., IH:]``. + _trace('reference') + y_ref = _reference_fused( + x_fp8, x_sf, topk_idx, topk_w, + l1_w_fp8, l1_w_sf, l2_w_fp8, l2_w_sf, + rank_idx, num_ranks, group, + num_experts, num_topk, + hidden, intermediate_hidden, + activation_clamp, + ) + + diff = calc_diff(y_fused, y_ref) + ok = diff < diff_tol + dist_print(f' [{name:<32}] diff={diff:.4f} ' + f'(tol={diff_tol:.2f}) {"OK" if ok else "FAIL"}', + once_in_node=True) + assert ok, f'{name}: diff={diff} >= tol={diff_tol}' + + # Verify cum_stats has been incremented (i.e. dispatch ran) + if num_tokens > 0 and masked_ratio < 1.0: + assert cum_stats.sum().item() >= 0 # non-negative; can be 0 if nothing routed here + + buffer.destroy() + dist.barrier() + + +# ---------------------------------------------------------------------------- +# Scenario tables +# ---------------------------------------------------------------------------- + +# A single tiny config used as a smoke test. +_SMOKE = dict( + num_max_tokens_per_rank=64, num_tokens=64, + hidden=512, intermediate_hidden=512, + num_experts=8, num_topk=2, +) + + +def _layer1_smoke() -> List[Tuple[str, Dict[str, Any]]]: + return [('L1.smoke', dict(_SMOKE))] + + +def _layer2_heuristic_branches(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]: + """Vary tokens / (num_experts * num_topk / num_ranks) so each + ``get_block_config_for_mega_moe_sm90`` band fires at least once. + + The heuristic decides on ``avg_tokens_per_expert``; we approximate by + setting ``num_max_tokens_per_rank`` and ``num_topk`` while keeping + ``num_experts`` fixed. The bands are at 64.5 / 96.5 / 192.5. 
+ """ + base = dict(hidden=1024, intermediate_hidden=1024, + num_experts=8 * num_ranks, num_topk=2) + out: List[Tuple[str, Dict[str, Any]]] = [] + # tokens-per-rank settings chosen to hit (small / mid / large) bands + for tokens, label in [(64, 'small'), (256, 'midA'), (512, 'midB'), (2048, 'large')]: + cfg = dict(base) + cfg.update(num_max_tokens_per_rank=tokens, num_tokens=tokens) + out.append((f'L2.heur.{label}.t{tokens}', cfg)) + return out + + +def _layer3_shape_sweep(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]: + out: List[Tuple[str, Dict[str, Any]]] = [] + base_experts = 8 * num_ranks + for hidden in (512, 2048): + for ih in (512, 2048): + for topk in (1, 2, 4): + if topk > base_experts: + continue + cfg = dict(num_max_tokens_per_rank=128, num_tokens=128, + hidden=hidden, intermediate_hidden=ih, + num_experts=base_experts, num_topk=topk) + out.append((f'L3.h{hidden}_ih{ih}_k{topk}', cfg)) + return out + + +def _layer4_edges(num_ranks: int) -> List[Tuple[str, Dict[str, Any]]]: + base = dict(num_max_tokens_per_rank=128, + hidden=512, intermediate_hidden=512, + num_experts=8 * num_ranks, num_topk=2) + out = [] + # Masked ratios + for mr in (0.0, 0.3, 0.7): + cfg = dict(base); cfg.update(num_tokens=128, masked_ratio=mr) + out.append((f'L4.mask{mr:.1f}', cfg)) + # All masked + cfg = dict(base); cfg.update(num_tokens=128, masked_ratio=1.0) + out.append(('L4.mask_all', cfg)) + # Activation clamp variations (finite vs inf) + for c in (1.0, 10.0, math.inf): + cfg = dict(base); cfg.update(num_tokens=128, activation_clamp=c) + out.append((f'L4.clamp{c}', cfg)) + # fast_math toggle + for fm in (True, False): + cfg = dict(base); cfg.update(num_tokens=128, fast_math=fm) + out.append((f'L4.fm{int(fm)}', cfg)) + # num_tokens boundaries + cfg = dict(base); cfg.update(num_tokens=0) + out.append(('L4.tokens0', cfg)) + cfg = dict(base); cfg.update(num_tokens=base['num_max_tokens_per_rank']) + out.append(('L4.tokens_max', cfg)) + return out + + +def 
_layer5_stress(num_ranks: int, num_tests: int) -> List[Tuple[str, Dict[str, Any]]]: + """Random configs under simple constraints.""" + rng = random.Random(0xC0FFEE) + out = [] + for i in range(num_tests): + hidden = rng.choice([512, 1024, 2048]) + ih = rng.choice([512, 1024, 2048]) + topk = rng.choice([1, 2, 4]) + tokens = rng.choice([32, 64, 128, 256, 512]) + masked = rng.choice([0.0, 0.0, 0.3, 0.5]) + clamp = rng.choice([1.0, 10.0, math.inf]) + fm = rng.choice([True, False]) + cfg = dict(num_max_tokens_per_rank=tokens, num_tokens=tokens, + hidden=hidden, intermediate_hidden=ih, + num_experts=8 * num_ranks, num_topk=topk, + masked_ratio=masked, activation_clamp=clamp, fast_math=fm) + out.append((f'L5.rand{i:03d}', cfg)) + return out + + +# ---------------------------------------------------------------------------- +# Entry point +# ---------------------------------------------------------------------------- + +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + + # Skip on non-SM90 + if get_arch_major() != 9: + dist_print(f'[SKIP] test_mega_moe_sm90 requires SM90; got SM{get_arch_major()}0', + once_in_node=True) + dist.destroy_process_group() + return + + diff_tol = args.diff_tol + layers: List[Tuple[str, Dict[str, Any]]] = [] + + if 1 in args.layers: + layers += _layer1_smoke() + if 2 in args.layers: + layers += _layer2_heuristic_branches(num_ranks) + if 3 in args.layers: + layers += _layer3_shape_sweep(num_ranks) + if 4 in args.layers: + layers += _layer4_edges(num_ranks) + if 5 in args.layers: + layers += _layer5_stress(num_ranks, args.num_correctness_tests or 8) + + if args.filter: + layers = [(n, c) for n, c in layers if args.filter in n] + + dist_print(f'SM90 MegaMoE test plan: {len(layers)} scenarios across ' + f'layers {sorted(args.layers)} on {num_ranks} ranks', + once_in_node=True) + + failures: List[str] = [] + for name, cfg in layers: + try: + 
_run_scenario(name, cfg, rank_idx, num_ranks, group, diff_tol) + except AssertionError as ex: + dist_print(f' [{name}] FAIL: {ex}', once_in_node=True) + failures.append(name) + if args.fail_fast: + break + + dist_print('', once_in_node=True) + if failures: + dist_print(f'FAILED {len(failures)}/{len(layers)} scenarios: {failures}', + once_in_node=True) + else: + dist_print(f'PASSED all {len(layers)} scenarios', once_in_node=True) + + dist.barrier() + dist.destroy_process_group() + if failures: + sys.exit(1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Layered SM90 MegaMoE tests') + parser.add_argument('--num-processes', type=int, default=2, + help='Number of ranks to spawn (default: 2)') + parser.add_argument('--layers', type=int, nargs='+', default=[1, 2, 3, 4], + help='Which layers to run (1..5). Default: 1 2 3 4. ' + 'Layer 5 requires --num-correctness-tests.') + parser.add_argument('--num-correctness-tests', type=int, default=None, + help='Layer 5 stress test count') + parser.add_argument('--filter', type=str, default='', + help='Substring filter on scenario names') + parser.add_argument('--diff-tol', type=float, default=0.07, + help='calc_diff tolerance (default: 0.07)') + parser.add_argument('--fail-fast', action='store_true', + help='Stop on first failing scenario') + args = parser.parse_args() + + np = args.num_processes + torch.multiprocessing.spawn(test, args=(np, args), nprocs=np)