diff --git a/csrc/tvm_ffi_api.cpp b/csrc/tvm_ffi_api.cpp index 78649118fb..abd46179cd 100644 --- a/csrc/tvm_ffi_api.cpp +++ b/csrc/tvm_ffi_api.cpp @@ -399,6 +399,17 @@ void dg_m_grouped_bf16_gemm_nt_contiguous(TensorView a, TensorView b, TensorView ); } +void dg_m_grouped_bf16_gemm_nn_contiguous(TensorView a, TensorView b, TensorView d, + TensorView grouped_layout, + std::string compiled_dims, + bool use_psum_layout) { + gemm::m_grouped_bf16_gemm_nn_contiguous( + convert_to_torch_tensor(a), convert_to_torch_tensor(b), + convert_to_torch_tensor(d), convert_to_torch_tensor(grouped_layout), + compiled_dims, use_psum_layout + ); +} + void dg_m_grouped_bf16_gemm_nt_masked(TensorView a, TensorView b, TensorView d, TensorView masked_m, int64_t expected_m, @@ -410,6 +421,22 @@ void dg_m_grouped_bf16_gemm_nt_masked(TensorView a, TensorView b, TensorView d, ); } +void dg_k_grouped_bf16_gemm_tn_contiguous(TensorView a, TensorView b, TensorView d, + Array ks, TensorView ks_tensor, + Optional c, + std::string compiled_dims) { + std::vector ks_val; + ks_val.reserve(ks.size()); + for (Array::iterator it = ks.begin(); it != ks.end(); ++it) + ks_val.push_back(static_cast(*it)); + auto c_opt = c.has_value()? std::optional(convert_to_torch_tensor(c.value())) : std::nullopt; + gemm::k_grouped_bf16_gemm_tn_contiguous( + convert_to_torch_tensor(a), convert_to_torch_tensor(b), + convert_to_torch_tensor(d), ks_val, convert_to_torch_tensor(ks_tensor), + c_opt, compiled_dims + ); +} + TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_gemm_nt, dg_fp8_fp4_gemm_nt); TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_gemm_nn, dg_fp8_fp4_gemm_nn); TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp8_fp4_gemm_tn, dg_fp8_fp4_gemm_tn); @@ -422,7 +449,9 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_nn, dg_bf16_gemm_nn); TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tn, dg_bf16_gemm_tn); TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tt, dg_bf16_gemm_tt); TVM_FFI_DLL_EXPORT_TYPED_FUNC(m_grouped_bf16_gemm_nt_contiguous, dg_m_grouped_bf16_gemm_nt_contiguous); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(m_grouped_bf16_gemm_nn_contiguous, dg_m_grouped_bf16_gemm_nn_contiguous); TVM_FFI_DLL_EXPORT_TYPED_FUNC(m_grouped_bf16_gemm_nt_masked, dg_m_grouped_bf16_gemm_nt_masked); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(k_grouped_bf16_gemm_tn_contiguous, dg_k_grouped_bf16_gemm_tn_contiguous); // Einsum void dg_einsum(std::string expr, TensorView a, TensorView b, TensorView d, diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 77db53cd92..2b58232d7e 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -41,11 +41,12 @@ def __init__(self, group: dist.ProcessGroup, self.group.barrier() torch.cuda.synchronize() - # Create input buffer views + # Create input buffer views (as torch tensors, not tvm-ffi tensors). (self.x, self.x_sf, self.topk_idx, self.topk_weights, self.l1_acts, self.l1_acts_sf, - self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer) + self.l2_acts, self.l2_acts_sf) = map( + torch.from_dlpack, slice_input_buffers(self.buffer)) def destroy(self): self.handle = None diff --git a/sgl_deep_gemm/__init__.py b/sgl_deep_gemm/__init__.py index 50750112f7..12cbf9ffea 100644 --- a/sgl_deep_gemm/__init__.py +++ b/sgl_deep_gemm/__init__.py @@ -255,9 +255,15 @@ def m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m, recipe=None, def m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, compiled_dims='nk', use_psum_layout=False, expected_m_for_psum_layout=None): _C.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, compiled_dims, use_psum_layout, expected_m_for_psum_layout) + def m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout, compiled_dims='nk', use_psum_layout=False): + _C.m_grouped_bf16_gemm_nn_contiguous(a, b, d, grouped_layout, compiled_dims, use_psum_layout) + def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims='nk'): _C.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims) + def k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=None, compiled_dims='mn'): + _C.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c, compiled_dims) + bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked except AttributeError: diff --git a/tests/test_mega_moe_l1_fp4_accuracy.py b/tests/test_mega_moe_l1_fp4_accuracy.py index 1d013ff240..1cd77fba6f 100644 --- a/tests/test_mega_moe_l1_fp4_accuracy.py +++ b/tests/test_mega_moe_l1_fp4_accuracy.py @@ -241,12 +241,6 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): activation_clamp = args.activation_clamp assert num_tokens <= num_max_tokens_per_rank - buffer = deep_gemm.get_symm_buffer_for_mega_moe( - group, num_experts, - num_max_tokens_per_rank, num_topk, - hidden, intermediate_hidden - ) - # Inputs (BF16) + topk routing x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') l1_weights_bf16 = torch.randn( @@ -262,6 +256,7 @@ def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): # FP8 / FP4 quantizations needed by the kernel x_fp8 = per_token_cast_to_fp8(x_bf16, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + x_fp4 = per_token_cast_to_fp4(x_bf16, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) def cast_grouped_weights_to_fp4(bf16_weights): num_groups, n, k = bf16_weights.shape @@ -277,9 +272,9 @@ def cast_grouped_weights_to_fp4(bf16_weights): transformed_l1_weights, transformed_l2_weights = \ deep_gemm.transform_weights_for_mega_moe(l1_weights_fp4, l2_weights_fp4) - def run_once(): - buffer.x[:num_tokens].copy_(x_fp8[0]) - buffer.x_sf[:num_tokens].copy_(x_fp8[1]) + def run_once(buffer, x_src): + buffer.x[:num_tokens].copy_(x_src[0]) + buffer.x_sf[:num_tokens].copy_(x_src[1]) buffer.topk_idx[:num_tokens].copy_(topk_idx) buffer.topk_weights[:num_tokens].copy_(topk_weights) cumulative_local_expert_recv_stats.zero_() @@ -294,27 +289,31 @@ def run_once(): ) return y, cumulative_local_expert_recv_stats.clone() + # Buffer layout depends on DG_USE_FP4_ACTS; set the env before allocating. + def make_buffer(use_fp4_acts): + os.environ['DG_USE_FP4_ACTS'] = '1' if use_fp4_acts else '0' + os.environ['DG_COMM_KERNEL_DEBUG'] = '0' # don't zero buffer between calls + return deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden + ) + # ---- BF16 reference for L1 SwiGLU output (per token×topk) ---- bf16_ref = _bf16_reference_l1( x_bf16, l1_weights_bf16, topk_idx, topk_weights, activation_clamp) # bf16_ref: (num_tokens, num_topk, intermediate_hidden) — only nonzero # where topk_idx[t, k] is in this rank's expert range. - # ---- Run FP8 path ---- - os.environ['DG_USE_FP4_ACTS'] = '0' - os.environ['DG_COMM_KERNEL_DEBUG'] = '0' # don't zero buffer between calls - # First run is a warmup. Stream A0.2 verified FP8-vs-FP8 across two - # consecutive runs gives a perfect (rel-MAE = 0) `y` match — the kernel - # IS deterministic at the `y` level, so any nonzero FP4-vs-FP8 `y` - # delta is a real numerical disagreement, not slot-permutation noise. - _ = run_once() + # ---- Run FP8 path (first call warms up) ---- + buffer_fp8 = make_buffer(use_fp4_acts=False) + _ = run_once(buffer_fp8, x_fp8) torch.cuda.synchronize() - y_fp8, recv_stats_fp8 = run_once() + y_fp8, recv_stats_fp8 = run_once(buffer_fp8, x_fp8) torch.cuda.synchronize() - y_fp8_a = y_fp8 # keep as alias so the FP8-vs-FP8 baseline below works - # Snapshot l2_acts and l2_acts_sf before they get overwritten by next call. - l2_acts_fp8 = buffer.l2_acts.clone() - l2_acts_sf_fp8 = buffer.l2_acts_sf.clone() + y_fp8_a = y_fp8 + l2_acts_fp8 = buffer_fp8.l2_acts.clone() + l2_acts_sf_fp8 = buffer_fp8.l2_acts_sf.clone() recv_fp8_list = recv_stats_fp8.cpu().tolist() # `recv_stats` is per-expert cumulative — the last element is the running # total of tokens routed to this rank's experts (since dispatcher @@ -322,11 +321,11 @@ def run_once(): # take the last value as the slot count. total_local_fp8 = int(recv_fp8_list[-1]) if recv_fp8_list else 0 - # ---- Run FP4 path ---- - os.environ['DG_USE_FP4_ACTS'] = '1' - _ = run_once() + # ---- Run FP4 path (separate buffer, laid out for packed E2M1) ---- + buffer = make_buffer(use_fp4_acts=True) + _ = run_once(buffer, x_fp4) torch.cuda.synchronize() - y_fp4, recv_stats_fp4 = run_once() + y_fp4, recv_stats_fp4 = run_once(buffer, x_fp4) torch.cuda.synchronize() l2_acts_fp4 = buffer.l2_acts.clone() l2_acts_sf_fp4 = buffer.l2_acts_sf.clone()