fix the issue of incorrect results in test_swiglu_backward_and_per_token_cast_kernel.py#5
Conversation
…ken_cast_kernel.py
There was a problem hiding this comment.
Code Review
This pull request updates the get_fused_mapping_kernel to transition from a 64-lane to a 32-lane warp configuration, modifying alignment, lane masks, and intrinsic calls. Additionally, the thread count in the swiglu_backward_and_per_token_cast_kernel is increased from 64 to 256. Review feedback highlights that the hardcoded alignment and warp-level logic should be made more dynamic or synchronized via constants to prevent potential mismatches and improve portability across different GPU architectures.
| @T.macro | ||
| def divide_task(length: int, num_tasks: int, task_id: int, start: T.Ref, end: T.Ref): | ||
| length_per_task = align(T.ceildiv(length, num_tasks), 64) | ||
| length_per_task = align(T.ceildiv(length, num_tasks), 32) |
There was a problem hiding this comment.
The alignment value in the divide_task macro is hardcoded to 32. This must be kept in sync with the warp_size used in the kernel (line 36) to ensure that each task partitioned by the macro starts on a warp boundary. This alignment is critical for the correctness of warp-level primitives like __match_any_sync used later in the kernel. Consider using a shared constant or passing the warp size as an argument to the macro to avoid duplication and potential mismatch.
| lane_mask = T.uint32(1 << lane_idx) + T.uint32(1 << lane_idx) - 1 | ||
| lane_mask_rev = ~lane_mask | ||
| for i in T.serial(start + lane_idx, aligned_end, warp_size): | ||
| T.assume(0 <= i) | ||
| expert_idx = T.Select(i < numel, T.int32(topk_idx_1d[i]), -1) | ||
| mask = T.call_extern(T.uint64, '__match_any_sync', tilelang.tvm.tir.const(0xFFFFFFFFFFFFFFFF, T.uint64), expert_idx) | ||
| mask = T.call_extern(T.uint32, '__match_any_sync', tilelang.tvm.tir.const(0xFFFFFFFF, T.uint32), expert_idx) |
There was a problem hiding this comment.
The warp-level logic (lane masks and __match_any_sync) is now hardcoded for 32-lane warps. While this correctly fixes the issue on NVIDIA hardware, it reduces the portability of the kernel to architectures with different warp sizes (e.g., 64-lane), which the repository seems to support in other modules (like common.py). It is recommended to derive the warp size and associated masks dynamically based on the target device configuration.
No description provided.