Skip to content

fix the issue of incorrect results in test_swiglu_backward_and_per_token_cast_kernel.py#5

Open
ginkito wants to merge 1 commit into
MetaX-MACA:devfrom
ginkito:swiglu_backward_and_per_token_cast_kernel
Open

fix the issue of incorrect results in test_swiglu_backward_and_per_token_cast_kernel.py#5
ginkito wants to merge 1 commit into
MetaX-MACA:devfrom
ginkito:swiglu_backward_and_per_token_cast_kernel

Conversation

@ginkito
Copy link
Copy Markdown

@ginkito ginkito commented May 11, 2026

No description provided.

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the get_fused_mapping_kernel to transition from a 64-lane to a 32-lane warp configuration, modifying alignment, lane masks, and intrinsic calls. Additionally, the thread count in the swiglu_backward_and_per_token_cast_kernel is increased from 64 to 256. Review feedback highlights that the hardcoded alignment and warp-level logic should be made more dynamic or synchronized via constants to prevent potential mismatches and improve portability across different GPU architectures.

@T.macro
def divide_task(length: int, num_tasks: int, task_id: int, start: T.Ref, end: T.Ref):
length_per_task = align(T.ceildiv(length, num_tasks), 64)
length_per_task = align(T.ceildiv(length, num_tasks), 32)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The alignment value in the divide_task macro is hardcoded to 32. This must be kept in sync with the warp_size used in the kernel (line 36) to ensure that each task partitioned by the macro starts on a warp boundary. This alignment is critical for the correctness of warp-level primitives like __match_any_sync used later in the kernel. Consider using a shared constant or passing the warp size as an argument to the macro to avoid duplication and potential mismatch.

Comment on lines +144 to +149
lane_mask = T.uint32(1 << lane_idx) + T.uint32(1 << lane_idx) - 1
lane_mask_rev = ~lane_mask
for i in T.serial(start + lane_idx, aligned_end, warp_size):
T.assume(0 <= i)
expert_idx = T.Select(i < numel, T.int32(topk_idx_1d[i]), -1)
mask = T.call_extern(T.uint64, '__match_any_sync', tilelang.tvm.tir.const(0xFFFFFFFFFFFFFFFF, T.uint64), expert_idx)
mask = T.call_extern(T.uint32, '__match_any_sync', tilelang.tvm.tir.const(0xFFFFFFFF, T.uint32), expert_idx)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The warp-level logic (lane masks and __match_any_sync) is now hardcoded for 32-lane warps. While this correctly fixes the issue on NVIDIA hardware, it reduces the portability of the kernel to architectures with different warp sizes (e.g., 64-lane), which the repository seems to support in other modules (like common.py). It is recommended to derive the warp size and associated masks dynamically based on the target device configuration.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant