cuda: generalize router-select for arbitrary expert count (fixes Pro on CUDA, #427) #435

Draft
newjordan wants to merge 1 commit into antirez:main from newjordan:cuda-router-select-any-n-expert

Conversation

@newjordan

Problem

On the CUDA backend, DeepSeek-V4 PRO (DeepSeek-V4-Pro-IQ2XXS-...-Instruct-imatrix.gguf, 384 routed experts) crashes on the very first prompt:

ds4: CUDA loading model tensors into device cache: 2.04 GiB
ds4: gpu layer 0 ffn batch encode failed
ds4: prompt processing failed: cuda prefill failed

This is issue #427, reported on a DGX Spark GB10 with every --ssd-streaming / ctx / prefill-chunk combination. Flash runs fine; PRO never starts.

Root cause

The CUDA routed-expert top-k select is hardcoded to the Flash configuration. Both dispatchers — ds4_gpu_router_select_batch_tensor (prefill) and ds4_gpu_router_select_tensor (decode) — open with:

if (n_expert != 256u || n_expert_used != 6u || fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0;

PRO has n_expert = 384 and expert_weight_scale = 2.5, so the guard returns 0 without raising a CUDA error; metal_graph_encode_layer_ffn_batch then reports only the generic ffn batch encode failed message. The three select kernels also bake in 256 (logits + t*256, __shared__ float sprob[256], the warp-topk layout of 32 lanes × 8 experts, and the <<<n_tokens, 256>>> launch) as well as the 1.5 scale.
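
To make the silent rejection concrete, here is a minimal host-side sketch that wraps the guard condition quoted above in a helper. The helper name `flash_only_guard_accepts` is hypothetical and exists only for illustration; the condition itself is copied from the dispatchers.

```c
#include <math.h>

/* Hypothetical helper (not a function in ds4_cuda.cu): replays the guard
 * condition quoted above. Returns 1 when the dispatcher would proceed and
 * 0 when it would bail out with a plain `return 0`. */
static int flash_only_guard_accepts(unsigned n_expert,
                                    unsigned n_expert_used,
                                    float expert_weight_scale) {
    if (n_expert != 256u || n_expert_used != 6u ||
        fabsf(expert_weight_scale - 1.5f) > 1.0e-6f) return 0;
    return 1;
}
```

With the Flash values (256, 6, 1.5) the helper returns 1; with the PRO values (384, 6, 2.5) it returns 0, which is the path that surfaces only as the generic encode failure.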

Fix

Generalize the router-select to any n_expert that is a non-zero multiple of 32 up to 512 (still requires n_expert_used == 6), using the model's routed weight scale:

  • thread n_expert and expert_weight_scale into router_select_kernel, router_select_parallel_kernel, and router_select_warp_topk_kernel
  • replace the literal 256 (strides, loops, bounds, shared scratch, parallel block width) with n_expert, and the literal 1.5 with the scale
  • warp-topk: each of the 32 lanes now owns n_expert/32 experts (8 for 256, 12 for 384); per-lane register arrays and shared scratch are capped at 512 experts
  • the buffer-size checks and the router-bias range use n_expert
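
The accepted shapes described in the list above can be sketched as a host-side check. This is an illustration under stated assumptions, not the code from the diff: the names `router_select_shape_supported` and `experts_per_lane` are hypothetical, and only the rules stated in this PR (non-zero multiple of 32, at most 512, n_expert_used == 6, n_expert/32 experts per lane) are encoded.

```c
/* Hypothetical helper: returns 1 if the generalized router-select, as
 * described in this PR, would accept the shape, and 0 otherwise. */
static int router_select_shape_supported(unsigned n_expert,
                                         unsigned n_expert_used) {
    if (n_expert == 0u) return 0;          /* must be non-zero          */
    if (n_expert % 32u != 0u) return 0;    /* one slice per warp lane   */
    if (n_expert > 512u) return 0;         /* register/shared cap       */
    if (n_expert_used != 6u) return 0;     /* top-k is still fixed at 6 */
    return 1;
}

/* Hypothetical helper: experts owned by each of the 32 warp lanes. */
static unsigned experts_per_lane(unsigned n_expert) {
    return n_expert / 32u;
}
```

Under these rules Flash (256 experts) keeps 8 experts per lane, PRO (384 experts) gets 12, and the 512 cap corresponds to 16 per lane.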

Flash (256/6/1.5) is unchanged by construction: per_lane stays 8 and the arithmetic is identical, so output is bit-for-bit the same.
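
As a rough intuition for why changing the per-lane slice size does not change which experts are selected, the following CPU model compares a lane-partitioned top-6 against a plain global scan. It is a sketch only: the strided ownership (lane l owns experts l, l+32, l+64, ...), the tie-break toward the lower index, and all function names are assumptions made for illustration, and the real kernel's scoring, bias, and weight scaling are not modeled.

```c
#define WARP_LANES  32u
#define TOP_K       6u
#define MAX_EXPERTS 512u

/* Reference: repeatedly pick the highest untaken score over all experts.
 * Ties resolve to the lower expert index (an assumption). */
static void global_topk(const float *score, unsigned n_expert,
                        unsigned out[TOP_K]) {
    unsigned char taken[MAX_EXPERTS] = {0};
    for (unsigned k = 0; k < TOP_K; k++) {
        int best = -1;
        for (unsigned e = 0; e < n_expert; e++) {
            if (!taken[e] && (best < 0 || score[e] > score[best])) best = (int)e;
        }
        out[k] = (unsigned)best;
        taken[best] = 1;
    }
}

/* CPU model of a warp top-k: each lane scans only its own n_expert/32
 * experts, then the lane winners are reduced to one global winner.
 * The outer lane loop stands in for a warp-level reduction. */
static void lane_partitioned_topk(const float *score, unsigned n_expert,
                                  unsigned out[TOP_K]) {
    unsigned per_lane = n_expert / WARP_LANES;
    unsigned char taken[MAX_EXPERTS] = {0};
    for (unsigned k = 0; k < TOP_K; k++) {
        int best = -1;
        for (unsigned lane = 0; lane < WARP_LANES; lane++) {
            int local = -1;  /* this lane's best untaken expert */
            for (unsigned j = 0; j < per_lane; j++) {
                unsigned e = lane + j * WARP_LANES;
                if (!taken[e] && (local < 0 || score[e] > score[local])) local = (int)e;
            }
            if (local < 0) continue;
            if (best < 0 || score[local] > score[best] ||
                (score[local] == score[best] && local < best)) best = local;
        }
        out[k] = (unsigned)best;
        taken[best] = 1;
    }
}
```

Running both functions on the same scores with n_expert set to 256 or 384 should yield the same six indices in the same order, since the partitioning changes only how the scan is divided, not which maximum is found.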

One file, ds4_cuda.cu, +64/−48.

Validation

Machine: DGX Spark GB10 (sm_121), CUDA 13.0 toolkit + driver, --ssd-streaming.

AI usage disclosure

YES — AI assisted with the diagnosis, the diff, and PR preparation; the contributor reviewed the code, the validation, and the submitted content.

Fixes #427.

The CUDA routed-expert top-k select was hardcoded to the Flash config:
n_expert == 256, n_expert_used == 6, routed weight scale == 1.5. Both the
batch (prefill) and single-token (decode) dispatchers rejected anything
else with a silent `return 0`. DeepSeek-V4 PRO routes 384 experts with a
weight scale of 2.5, so prefill failed immediately at the guard, surfacing
only as `cuda prefill failed`.

- Thread n_expert and the model's expert_weight_scale into router_select_kernel,
  router_select_parallel_kernel and router_select_warp_topk_kernel; replace the
  literal 256 (strides, loops, bounds, shared arrays, parallel block width) with
  n_expert and the literal 1.5 with the scale.
- warp-topk: each of the 32 lanes now owns n_expert/32 experts (8 for 256, 12
  for 384); per-lane register arrays and shared scratch capped at 512 experts.
- Both dispatch guards accept any n_expert that is a non-zero multiple of 32 up
  to 512 (still require n_expert_used == 6); the buffer-size checks and the
  router-bias range now use n_expert.

Flash (256/6/1.5) is unchanged: per_lane stays 8 and the arithmetic is identical,
so output is bit-for-bit the same.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@newjordan

Hey antirez, if you have any issues or want a different format, just let me know. Cheers, and thank you so much for the DS4 work!

Development

Successfully merging this pull request may close these issues.

Immediate prefill crash on layer 0 every time. dgx spark on fully up to date DGX OS and firmwares.