[Runtime][WebGPU] Add prerequisites for Gemma 4 E2B WebGPU execution #19766
MakotoUwu wants to merge 1 commit into
Unblocks the Gemma 4 E2B text-only path in mlc-llm by adding the TVM runtime, Relax frontend, WebGPU target, and WebAssembly runtime support that the model exercises during WebGPU prefill/decode.

PagedKVCache hybrid dispatch (src/runtime/vm/paged_kv_cache.cc)
---------------------------------------------------------------

* Hoist `ReserveAppendLengthInSeq` above the aux-data loop so page metadata reflects the current prefill when a request spans multiple blocks in a single call. Previously the aux-data loop read block page counts before the blocks were reserved, producing empty `page_indptr` entries for the first call of a newly-created sequence whose length exceeded `page_size`.
* Route `AttnKind::kMHASliding` through the MHA dispatch arm in `SelfAttention()` and `CrossAttention()`. Without this, sliding layers in Gemma 4 fall through to the MLA path and return zero-initialised output for their attention sub-graph.

WebGPU target kind (src/target/target_kind.cc)
----------------------------------------------

* Register `max_shared_memory_per_block = 32768` on the `webgpu` target kind. Without this attribute, Dlight's shared-memory analysis falls back to the generic 48 KB default and can generate decode kernels that exceed Chrome/Dawn's 32 KB workgroup-storage budget. Chrome currently exposes 32768 and the WebGPU spec mandates at least 16384, so 32768 is a conservative default.

Relax nn.llm RoPE support (python/tvm/relax/frontend/nn/llm/)
-------------------------------------------------------------

* Add a `freq_dim_base` parameter to `rope_freq_gptj` so callers can decouple the frequency-base dimension from the rotated range. Gemma 4 full-attention layers use partial rotary embeddings with `head_dim` as the frequency base and a smaller rotated dimension.
* Mark the generated `fused_rope` and `fused_rope_longrope_scaling` prim_funcs as private. Gemma 4 builds separate RoPE factories for sliding and partial-rotary full-attention layers, and private nested prim_funcs avoid duplicate module-scope global symbols.
* Promote the `apply_rope` prim_func parameter from `T.int32` to `T.int64` to match the existing caller convention that passes an int64 immediate.

TIRX device-module grouping (python/tvm/tirx/build.py)
------------------------------------------------------

* Group device functions by `target.kind.name` instead of the stringified target object in `split_host_device_mods`. This avoids treating target attributes as separate backend kinds while still preserving a canonical target per backend kind.

WebAssembly runtime (web/)
--------------------------

* Reorder FFI includes in `web/emcc/wasm_runtime.cc` so `tvm_ffi::*` static initialisers run before `runtime::*` initialisers.
* Extend `ArrayDecodeStorage` with a fall-through for payloads tagged `f32-to-bf16` whose byte length matches native float32, allowing native-f32 shards with that tag to decode correctly.
* Add chunked tensor loading in `web/src/runtime.ts` for records whose `nbytes` exceed the per-call transfer budget. Also unpack `kTVMFFIShape` results so chunked loading can call `tensorCreateView` with explicit shape tuples.
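To make the `freq_dim_base` item in the RoPE section concrete, here is a minimal plain-Python sketch of what decoupling the frequency base from the rotated range could look like. It is not the TVM implementation: `rope_inv_freq`, `rope_angles`, and `rotate_pair` are hypothetical helper names, and the formula assumes the common `theta ** (-2 * i / dim)` form of rotary frequencies.

```python
import math


def rope_inv_freq(pair_index: int, theta: float, freq_dim_base: int) -> float:
    """Inverse frequency of one rotated pair: theta ** (-2 * i / freq_dim_base)."""
    return theta ** (-2.0 * pair_index / freq_dim_base)


def rope_angles(position: int, rotary_dim: int, theta: float, freq_dim_base=None):
    """Rotation angles for the first `rotary_dim` channels of one head.

    Hypothetical helper: when `freq_dim_base` is None the frequency base is the
    rotated range itself; passing e.g. head_dim keeps the frequency spacing of
    a full-width rotation while rotating fewer channels.
    """
    if rotary_dim % 2 != 0:
        raise ValueError("rotary_dim must be even")
    base = rotary_dim if freq_dim_base is None else freq_dim_base
    return [position * rope_inv_freq(i, theta, base) for i in range(rotary_dim // 2)]


def rotate_pair(x0: float, x1: float, angle: float):
    """Rotate one interleaved (even, odd) channel pair by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c
```

With `rotary_dim=4`, the second pair uses `theta ** -0.5` by default but `theta ** -0.25` when `freq_dim_base=8`, which is the kind of difference a partial-rotary layer with a larger `head_dim` would see under these assumptions.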
Code Review
This pull request introduces several enhancements and bug fixes across TVM, including: updating "HasNonzeroAttr" to support "IntImm" and "bool" types; extending GPT-J RoPE frequency calculations to support partial rotary embeddings via a new "freq_dim_base" parameter; adding error handling in GPU matmul scheduling to skip functions without a root block; grouping device functions by target kind name during build; adding "max_shared_memory_per_block" to WebGPU target attributes; unconditionally reserving pages in the paged KV cache to fix intra-prefill shared-KV cross-attention; and implementing chunked copying for large tensors in the Web runtime to prevent memory issues. There are no review comments, so I have no feedback to provide.
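As an illustration of the chunked copying idea mentioned in this summary, the sketch below splits a large record into byte ranges that each fit a per-call budget. It is a language-neutral model in Python rather than the `web/src/runtime.ts` code, and `chunk_ranges`, `copy_in_chunks`, and the `budget` parameter are hypothetical names.

```python
def chunk_ranges(nbytes: int, budget: int):
    """Return (start, end) byte ranges covering [0, nbytes), each at most `budget` long."""
    if budget <= 0:
        raise ValueError("budget must be positive")
    return [(start, min(start + budget, nbytes)) for start in range(0, nbytes, budget)]


def copy_in_chunks(src: bytes, dst: bytearray, budget: int) -> int:
    """Copy `src` into `dst` one chunk at a time; return the number of chunks used.

    Each slice assignment stands in for one bounded transfer call.
    """
    if len(dst) < len(src):
        raise ValueError("destination buffer is too small")
    ranges = chunk_ranges(len(src), budget)
    for start, end in ranges:
        dst[start:end] = src[start:end]
    return len(ranges)
```

The last range is simply shorter than the budget, so no padding is needed, and a zero-length record yields no transfers at all.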
tqchen left a comment
Thanks for the PR. It would be good to split out the WebGPU runtime-only part, which would be easier to review; the compiler side has some unintended issues.
    const DictAttrsNode* node = get();
    auto it = node->dict.find(attr_key);
    if (it == node->dict.end()) {
      return false;
we are moving away from IntImm attr
Thanks, this makes sense. I split the Web/WebGPU runtime-only portion into a smaller draft PR here: #19771. I will keep the compiler-side pieces from this draft separate and revisit them independently, especially the attr/RoPE/TIRx changes that caused the unintended CI issues.
Summary
This PR adds a small set of TVM runtime, Relax frontend, TIRx, DLight, WebGPU target, and WebAssembly runtime prerequisites that are exercised by Gemma 4 E2B text-only WebGPU execution through downstream MLC/WebLLM.
The intent is not to add Gemma 4 model support to Apache TVM directly. The downstream model implementation remains in MLC-LLM. This PR only contains lower-layer behavior needed by that downstream path.
Changes
- `DictAttrs::HasNonzeroAttr` now accepts `IntImm` and bool attributes, in addition to native integer attributes. This fixes TIRx attrs such as `tirx.noalias` when represented as `IntImm`.
- `max_shared_memory_per_block = 32768` as the default WebGPU target option in `src/backend/webgpu/codegen/target_kind.cc`.
- `freq_dim_base` support for GPT-J-style RoPE frequency generation, so callers can decouple rotated dimensions from the frequency-base dimension.
- […] `f32-to-bf16`, and add chunked tensor loading support for large records.
- […] `IntImm` attr case and the non-root helper PrimFunc DLight matmul case.

Validation
Local validation on macOS after rebasing on current `apache/main`:

Result: passed. The build produced nonfatal warnings, mostly `-Woverloaded-virtual` warnings around TIRx visitor overloads, plus LLVM deprecation warnings.

Focused tests:
Result:
Branch shape and hygiene:
Downstream WebLLM smoke evidence from the Apache-prep wasm:
- `70d7295dc91b622b79ceeada2c64b4c20787832631c04e3714de95db04515dfc`
- `apache-pr-validate:ok`
- `Hi`: load `11.0s`, generation `0.7s`
- […] `4.3s`, generation `0.2s`, output `Paris`
- […] `4.0s`, generation `2.4s`

Non-goals