Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions csrc/tvm_ffi_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,17 @@ void dg_m_grouped_bf16_gemm_nt_contiguous(TensorView a, TensorView b, TensorView
);
}

// TVM-FFI entry point for the M-grouped BF16 GEMM (NN layout, contiguous).
// Each `TensorView` argument is passed through `convert_to_torch_tensor`, and the
// results are handed to `gemm::m_grouped_bf16_gemm_nn_contiguous` together with
// `compiled_dims` and `use_psum_layout`, which are forwarded unchanged.
void dg_m_grouped_bf16_gemm_nn_contiguous(TensorView a, TensorView b, TensorView d,
                                          TensorView grouped_layout,
                                          std::string compiled_dims,
                                          bool use_psum_layout) {
    const auto lhs = convert_to_torch_tensor(a);
    const auto rhs = convert_to_torch_tensor(b);
    const auto out = convert_to_torch_tensor(d);
    const auto layout = convert_to_torch_tensor(grouped_layout);

    gemm::m_grouped_bf16_gemm_nn_contiguous(lhs, rhs, out, layout,
                                            compiled_dims, use_psum_layout);
}

void dg_m_grouped_bf16_gemm_nt_masked(TensorView a, TensorView b, TensorView d,
TensorView masked_m,
int64_t expected_m,
Expand All @@ -410,6 +421,22 @@ void dg_m_grouped_bf16_gemm_nt_masked(TensorView a, TensorView b, TensorView d,
);
}

// TVM-FFI entry point for the K-grouped BF16 GEMM (TN layout, contiguous).
// Adapts the FFI argument types before calling `gemm::k_grouped_bf16_gemm_tn_contiguous`:
//   - `ks` is copied element by element into a `std::vector<int>`; every entry is
//     narrowed with `static_cast<int>` (no range check is performed here)
//   - `c` becomes a `std::optional<torch::Tensor>` that stays empty when no value was passed
//   - every `TensorView` is passed through `convert_to_torch_tensor`
void dg_k_grouped_bf16_gemm_tn_contiguous(TensorView a, TensorView b, TensorView d,
                                          Array<int64_t> ks, TensorView ks_tensor,
                                          Optional<TensorView> c,
                                          std::string compiled_dims) {
    // Per-group K sizes, converted to the element type the GEMM call is given.
    std::vector<int> group_ks;
    group_ks.reserve(ks.size());
    for (const auto k : ks)
        group_ks.push_back(static_cast<int>(k));

    // Optional `c` tensor: only converted when the caller supplied one.
    std::optional<torch::Tensor> maybe_c;
    if (c.has_value())
        maybe_c.emplace(convert_to_torch_tensor(c.value()));

    gemm::k_grouped_bf16_gemm_tn_contiguous(
        convert_to_torch_tensor(a), convert_to_torch_tensor(b),
        convert_to_torch_tensor(d), group_ks, convert_to_torch_tensor(ks_tensor),
        maybe_c, compiled_dims
    );
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_gemm_nt, dg_fp8_fp4_gemm_nt);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_gemm_nn, dg_fp8_fp4_gemm_nn);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_gemm_tn, dg_fp8_fp4_gemm_tn);
Expand All @@ -422,7 +449,9 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_nn, dg_bf16_gemm_nn);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tn, dg_bf16_gemm_tn);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tt, dg_bf16_gemm_tt);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(m_grouped_bf16_gemm_nt_contiguous, dg_m_grouped_bf16_gemm_nt_contiguous);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(m_grouped_bf16_gemm_nn_contiguous, dg_m_grouped_bf16_gemm_nn_contiguous);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(m_grouped_bf16_gemm_nt_masked, dg_m_grouped_bf16_gemm_nt_masked);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(k_grouped_bf16_gemm_tn_contiguous, dg_k_grouped_bf16_gemm_tn_contiguous);

// Einsum
void dg_einsum(std::string expr, TensorView a, TensorView b, TensorView d,
Expand Down
5 changes: 3 additions & 2 deletions deep_gemm/mega/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def __init__(self, group: dist.ProcessGroup,
self.group.barrier()
torch.cuda.synchronize()

# Create input buffer views
# Create input buffer views (as torch tensors, not tvm-ffi tensors).
(self.x, self.x_sf,
self.topk_idx, self.topk_weights,
self.l1_acts, self.l1_acts_sf,
self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)
self.l2_acts, self.l2_acts_sf) = map(
Comment thread
Fridge003 marked this conversation as resolved.
torch.from_dlpack, slice_input_buffers(self.buffer))

def destroy(self):
self.handle = None
Expand Down
6 changes: 6 additions & 0 deletions sgl_deep_gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,15 @@ def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m, recipe=None,
# Python wrapper: forwards all seven arguments positionally to
# `_C.m_grouped_bf16_gemm_nt_contiguous`. Defaults: compiled_dims='nk',
# use_psum_layout=False, expected_m_for_psum_layout=None. Returns nothing.
def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, compiled_dims='nk', use_psum_layout=False, expected_m_for_psum_layout=None):
_C.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, compiled_dims, use_psum_layout, expected_m_for_psum_layout)

# Python wrapper: forwards all six arguments positionally to
# `_C.m_grouped_bf16_gemm_nn_contiguous`. Defaults: compiled_dims='nk',
# use_psum_layout=False. Returns nothing.
def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout, compiled_dims='nk', use_psum_layout=False):
_C.m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout, compiled_dims, use_psum_layout)

# Python wrapper: forwards all six arguments positionally to
# `_C.m_grouped_bf16_gemm_nt_masked`. Default: compiled_dims='nk'. Returns nothing.
def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims='nk'):
_C.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims)

# Python wrapper: forwards all seven arguments positionally to
# `_C.k_grouped_bf16_gemm_tn_contiguous`. Defaults: c=None, compiled_dims='mn'.
# Returns nothing.
def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None, compiled_dims='mn'):
_C.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c, compiled_dims)

bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked

except AttributeError:
Expand Down
51 changes: 25 additions & 26 deletions tests/test_mega_moe_l1_fp4_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,6 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
activation_clamp = args.activation_clamp
assert num_tokens <= num_max_tokens_per_rank

buffer = deep_gemm.get_symm_buffer_for_mega_moe(
group, num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden
)

# Inputs (BF16) + topk routing
x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
l1_weights_bf16 = torch.randn(
Expand All @@ -262,6 +256,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):

# FP8 / FP4 quantizations needed by the kernel
x_fp8 = per_token_cast_to_fp8(x_bf16, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
x_fp4 = per_token_cast_to_fp4(x_bf16, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)

def cast_grouped_weights_to_fp4(bf16_weights):
num_groups, n, k = bf16_weights.shape
Expand All @@ -277,9 +272,9 @@ def cast_grouped_weights_to_fp4(bf16_weights):
transformed_l1_weights, transformed_l2_weights = \
deep_gemm.transform_weights_for_mega_moe(l1_weights_fp4, l2_weights_fp4)

def run_once():
buffer.x[:num_tokens].copy_(x_fp8[0])
buffer.x_sf[:num_tokens].copy_(x_fp8[1])
def run_once(buffer, x_src):
buffer.x[:num_tokens].copy_(x_src[0])
buffer.x_sf[:num_tokens].copy_(x_src[1])
buffer.topk_idx[:num_tokens].copy_(topk_idx)
buffer.topk_weights[:num_tokens].copy_(topk_weights)
cumulative_local_expert_recv_stats.zero_()
Expand All @@ -294,39 +289,43 @@ def run_once():
)
return y, cumulative_local_expert_recv_stats.clone()

# Buffer layout depends on DG_USE_FP4_ACTS; set the env before allocating.
def make_buffer(use_fp4_acts):
os.environ['DG_USE_FP4_ACTS'] = '1' if use_fp4_acts else '0'
os.environ['DG_COMM_KERNEL_DEBUG'] = '0' # don't zero buffer between calls
return deep_gemm.get_symm_buffer_for_mega_moe(
group, num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden
)

# ---- BF16 reference for L1 SwiGLU output (per token×topk) ----
bf16_ref = _bf16_reference_l1(
x_bf16, l1_weights_bf16, topk_idx, topk_weights, activation_clamp)
# bf16_ref: (num_tokens, num_topk, intermediate_hidden) — only nonzero
# where topk_idx[t, k] is in this rank's expert range.

# ---- Run FP8 path ----
os.environ['DG_USE_FP4_ACTS'] = '0'
os.environ['DG_COMM_KERNEL_DEBUG'] = '0' # don't zero buffer between calls
# First run is a warmup. Stream A0.2 verified FP8-vs-FP8 across two
# consecutive runs gives a perfect (rel-MAE = 0) `y` match — the kernel
# IS deterministic at the `y` level, so any nonzero FP4-vs-FP8 `y`
# delta is a real numerical disagreement, not slot-permutation noise.
_ = run_once()
# ---- Run FP8 path (first call warms up) ----
buffer_fp8 = make_buffer(use_fp4_acts=False)
_ = run_once(buffer_fp8, x_fp8)
torch.cuda.synchronize()
y_fp8, recv_stats_fp8 = run_once()
y_fp8, recv_stats_fp8 = run_once(buffer_fp8, x_fp8)
torch.cuda.synchronize()
y_fp8_a = y_fp8 # keep as alias so the FP8-vs-FP8 baseline below works
# Snapshot l2_acts and l2_acts_sf before they get overwritten by next call.
l2_acts_fp8 = buffer.l2_acts.clone()
l2_acts_sf_fp8 = buffer.l2_acts_sf.clone()
y_fp8_a = y_fp8
l2_acts_fp8 = buffer_fp8.l2_acts.clone()
l2_acts_sf_fp8 = buffer_fp8.l2_acts_sf.clone()
recv_fp8_list = recv_stats_fp8.cpu().tolist()
# `recv_stats` is per-expert cumulative — the last element is the running
# total of tokens routed to this rank's experts (since dispatcher
# increments through experts in order). For our single-rank harness we
# take the last value as the slot count.
total_local_fp8 = int(recv_fp8_list[-1]) if recv_fp8_list else 0

# ---- Run FP4 path ----
os.environ['DG_USE_FP4_ACTS'] = '1'
_ = run_once()
# ---- Run FP4 path (separate buffer, laid out for packed E2M1) ----
buffer = make_buffer(use_fp4_acts=True)
_ = run_once(buffer, x_fp4)
torch.cuda.synchronize()
y_fp4, recv_stats_fp4 = run_once()
y_fp4, recv_stats_fp4 = run_once(buffer, x_fp4)
torch.cuda.synchronize()
l2_acts_fp4 = buffer.l2_acts.clone()
l2_acts_sf_fp4 = buffer.l2_acts_sf.clone()
Expand Down