Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 135 additions & 6 deletions csrc/apis/mega.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp"
#include "../jit_kernels/impls/sm90_fp8_mega_moe.hpp"

namespace deep_gemm::mega {

Expand All @@ -23,6 +24,15 @@ get_symm_buffer_size_for_mega_moe(
const bool& use_fp8_dispatch, const std::string& activation) {
DG_HOST_ASSERT(num_experts % num_ranks == 0);

// Architecture-dependent SF dtype for the user-facing tensor view:
// * SM100: per-32 UE8M0 packed 4-into-int (`torch::kInt`).
// * SM90 : per-128 channel float (`torch::kFloat32`).
// Both use the same number of bytes per token (hidden / 32), so the symmetric
// buffer layout is shared; only the slice view dtype changes.
const auto arch_major = device_runtime->get_arch_major();
const bool is_sm90 = arch_major == 9;
const auto sf_dtype = is_sm90 ? torch::kFloat32 : torch::kInt;

// Workspace bytes
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);

Expand All @@ -31,7 +41,16 @@ get_symm_buffer_size_for_mega_moe(
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto fp8_sf_layout = layout::Data(hidden / 32);
const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
// L2 acts SF granularity differs by arch:
// * SM100 packs 4 UE8M0 bytes per int along K, so each token uses
// `intermediate_hidden / 32` bytes (per-32 K).
// * SM90 stores per-64 K floats so that each L1 epilogue block (which
// produces 64 post-SwiGLU columns) can write its own SF independently
// without cross-CTA amax synchronisation; bytes per token become
// `intermediate_hidden / 64 * sizeof(float) = intermediate_hidden / 16`.
const int fp8_intermediate_sf_bytes_per_token =
is_sm90 ? (intermediate_hidden / 16) : (intermediate_hidden / 32);
const auto fp8_intermediate_sf_layout = layout::Data(fp8_intermediate_sf_bytes_per_token);
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
Expand Down Expand Up @@ -86,10 +105,14 @@ get_symm_buffer_size_for_mega_moe(

// Check SF buffer requirements
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);
// SM100 packs 4 UE8M0 bytes per int along K, so the padded SF token count
// must be divisible by 4. SM90 stores unpacked float SFs and has no such constraint.
if (not is_sm90)
DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);

// Slice function: creates `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
// NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
// Dtype is per-arch (see `sf_dtype` above): float on SM90, int (packed UE8M0) on SM100.
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
Expand All @@ -98,7 +121,7 @@ get_symm_buffer_size_for_mega_moe(
auto x_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
{num_max_tokens_per_rank, hidden / 128},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
torch::TensorOptions().dtype(sf_dtype).device(buffer.device()));
auto topk_idx = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
{num_max_tokens_per_rank, num_topk},
Expand All @@ -115,16 +138,16 @@ get_symm_buffer_size_for_mega_moe(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
{num_max_padded_sf_pool_tokens, hidden / 128},
{1, num_max_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
torch::TensorOptions().dtype(sf_dtype).device(buffer.device()));
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto l2_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
{num_max_padded_sf_pool_tokens, intermediate_hidden / 128},
{num_max_padded_sf_pool_tokens, is_sm90 ? intermediate_hidden / 64 : intermediate_hidden / 128},
{1, num_max_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
torch::TensorOptions().dtype(sf_dtype).device(buffer.device()));
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
};
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
Expand Down Expand Up @@ -224,11 +247,117 @@ static void fp8_fp4_mega_moe(
sym_buffer.zero_();
}

// SM90 (Hopper) FP8 MegaMoE entry point.
//
// Mirrors `fp8_fp4_mega_moe` but expects FP8 (e4m3) weights whose float scale
// factors are validated at (128, 128) granularity. Top-level routing (which
// entry to call) is the caller's responsibility (see `deep_gemm/mega/__init__.py`).
//
// Parameters:
//   * `y`: output tensor; `y.size(0)` is taken as the number of local tokens.
//   * `l1_weights_tuple` / `l2_weights_tuple`: `(weights, weights_sf)` pairs.
//     `l1_weights` is read as `[experts_per_rank, 2 * intermediate_hidden, hidden]`
//     and `l2_weights` as `[experts_per_rank, hidden, intermediate_hidden]`.
//   * `cumulative_local_expert_recv_stats`: optional contiguous int tensor with
//     one element per local expert.
//   * `sym_buffer`: symmetric buffer, sliced into the activation / SF views via
//     `get_symm_buffer_size_for_mega_moe`; must be at least the required size.
//   * `sym_buffer_ptrs`: one entry per rank; its size defines `num_ranks`.
//   * `recipe`: must be `(128, 128, 128)`.
//   * `activation`: must be `"swiglu"`.
//   * `activation_clamp_opt`: non-negative clamp; defaults to +infinity.
//
// Every precondition is enforced with `DG_HOST_ASSERT`.
static void fp8_mega_moe(
    const torch::Tensor& y,
    const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
    const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
    const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
    const torch::Tensor& sym_buffer,
    const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
    const int& num_max_tokens_per_rank,
    const int& num_experts, const int& num_topk,
    const std::tuple<int, int, int>& recipe,
    const std::string& activation,
    const std::optional<float>& activation_clamp_opt,
    const bool& fast_math
) {
    const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
    const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;

    // Architecture check
    const auto arch_major = device_runtime->get_arch_major();
    DG_HOST_ASSERT(arch_major == 9);

    // Config checks: SM90 uses block (128, 128) float SF for weights and
    // per-token float SF for activations. The activation SF views come from
    // `get_symm_buffer_size_for_mega_moe`, which on SM90 uses `hidden / 128`
    // columns for L1 inputs and `intermediate_hidden / 64` columns for L2 inputs.
    const auto num_tokens = static_cast<int>(y.size(0));
    const auto [rm, rn, rk] = recipe;
    DG_HOST_ASSERT(rm == 128 and rn == 128 and rk == 128);
    DG_HOST_ASSERT(activation == "swiglu");

    // Activation checks: a missing clamp means "no clamping" (+inf)
    const auto activation_clamp =
        activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
    DG_HOST_ASSERT(activation_clamp >= 0);

    // Tensor checks: SM90 weights must be FP8 e4m3, K-major
    DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
    DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
    DG_HOST_ASSERT(l1_weights.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(l2_weights.scalar_type() == torch::kFloat8_e4m3fn);
    const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = get_shape<3>(l1_weights);
    const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = get_shape<3>(l2_weights);
    DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank);
    DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
    DG_HOST_ASSERT(hidden == hidden_);
    DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
    DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());

    // Shape constraints required by the SM90 kernel:
    //  * Hidden dims must be multiples of 128 (per-128 SF + scheduler integer-tiling).
    //  * `l2_arrival_mask` is uint64, with one bit per L1-output N-block of size 64 in the
    //    intermediate dim, so `kNumL1BlockNs = intermediate_hidden / 64` must be <= 64.
    DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
    DG_HOST_ASSERT(intermediate_hidden / 64 <= 64);

    // Check weight SF layout (block (128, 128) float, MN-major; not TMA-loaded
    // so no TMA-stride alignment is required, but we do require contiguity in
    // the K-direction within each expert).
    constexpr int kGranMN = 128, kGranK = 128;
    check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
                    num_experts_per_rank, false, true, torch::kFloat);
    check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
                    num_experts_per_rank, false, true, torch::kFloat);

    // Check stats counter
    if (cumulative_local_expert_recv_stats.has_value()) {
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
        DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
    }

    // Check rank / expert counts *before* sizing the buffer: the sizing routine
    // evaluates `num_experts % num_ranks`, which must never see `num_ranks == 0`
    // (empty `sym_buffer_ptrs`) or an expert count inconsistent with the weights.
    const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
    DG_HOST_ASSERT(num_ranks > 0);
    DG_HOST_ASSERT(num_experts == num_experts_per_rank * num_ranks);

    // Check buffer bytes
    const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe(
        num_ranks, num_experts,
        num_max_tokens_per_rank, num_topk,
        hidden, intermediate_hidden,
        true, activation);
    DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));

    // Already registered tensors (only the L1/L2 activation views are passed on)
    const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);

    sm90_fp8_mega_moe(y,
                      l1_acts, l1_acts_sf,
                      l2_acts, l2_acts_sf,
                      l1_weights, l2_weights,
                      l1_weights_sf, l2_weights_sf,
                      cumulative_local_expert_recv_stats,
                      sym_buffer_ptrs,
                      rank_idx, num_max_tokens_per_rank,
                      num_experts_per_rank,
                      num_tokens, num_topk,
                      hidden, intermediate_hidden,
                      activation_clamp, fast_math);

    // Debug aid: wipe the whole symmetric buffer after the call when
    // `DG_COMM_KERNEL_DEBUG` is set to a non-zero value.
    if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
        sym_buffer.zero_();
}

// Exposes the MegaMoE host entry points on the pybind11 module `m`.
// Nothing is registered unless `DG_TENSORMAP_COMPATIBLE` evaluates to true.
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
    // `module_::def` returns the module itself, so the registrations are chained.
    m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe)
     .def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe)
     .def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe)
     .def("fp8_mega_moe", &fp8_mega_moe);
#endif
}

Expand Down
6 changes: 6 additions & 0 deletions csrc/jit/compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class Compiler {
flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage";
if (get_env("DG_JIT_WITH_LINEINFO", 0))
flags += " -Xcompiler -rdynamic -lineinfo";
// NOTES: `--device-debug` (-G) emits full device DWARF so that cuda-gdb
// can resolve `__device__` global variables / line numbers in JIT
// kernels. It DISABLES device-side optimization and will tank perf, so
// it is gated behind an explicit env var.
if (get_env("DG_JIT_WITH_DEVICE_DEBUG", 0))
flags += " --device-debug";
}

virtual ~Compiler() = default;
Expand Down
7 changes: 5 additions & 2 deletions csrc/jit/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ using KernelHandle = CUfunction;
using LaunchConfigHandle = CUlaunchConfig;
using LaunchAttrHandle = CUlaunchAttribute;

// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4
#if CUDA_VERSION >= 12040
// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4.
// Define `DG_JIT_FORCE_LEGACY_LOAD` to force the older `cuModuleLoad` path
// (useful when building against a newer CUDA SDK but running with an older
// driver that lacks the `cuLibrary*` symbols).
#if CUDA_VERSION >= 12040 && !defined(DG_JIT_FORCE_LEGACY_LOAD)
#define DG_JIT_USE_LIBRARY_ENUM_KERNELS
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels);
Expand Down
Loading