[Runtime][WebGPU] Add prerequisites for Gemma 4 E2B WebGPU execution #19766

Draft
MakotoUwu wants to merge 1 commit into apache:main from MakotoUwu:ole-34/apache-tvm-gemma4-e2b-prereqs

@MakotoUwu

Summary

This PR adds a small set of TVM runtime, Relax frontend, TIRx, DLight, WebGPU target, and WebAssembly runtime prerequisites that are exercised by Gemma 4 E2B text-only WebGPU execution through downstream MLC/WebLLM.

The intent is not to add Gemma 4 model support to Apache TVM directly. The downstream model implementation remains in MLC-LLM. This PR only contains lower-layer behavior needed by that downstream path.

Changes

  • Make DictAttrs::HasNonzeroAttr accept IntImm and bool attributes, in addition to native integer attributes. This fixes TIRx attrs such as tirx.noalias when represented as IntImm.
  • Let DLight GPU matmul schedule rules return non-applicable for helper PrimFuncs without a root block, instead of raising during default schedule application.
  • Add max_shared_memory_per_block = 32768 as the default WebGPU target option in src/backend/webgpu/codegen/target_kind.cc.
  • Add freq_dim_base support for GPT-J-style RoPE frequency generation, so callers can decouple rotated dimensions from the frequency-base dimension.
  • Mark generated RoPE PrimFuncs as private to avoid duplicate module-scope symbols when downstream code creates multiple RoPE factories.
  • Group TIRx device functions by target kind name instead of full target string to avoid splitting a single backend into multiple device modules when attributes differ.
  • Fix PagedKVCache metadata reservation ordering and route sliding MHA attention through the MHA path.
  • Make the wasm runtime tolerate native-f32 payloads tagged as f32-to-bf16, and add chunked tensor loading support for large records.
  • Add focused tests for the IntImm attr case and the non-root helper PrimFunc DLight matmul case.

Validation

Local validation on macOS after rebasing on current apache/main:

cmake --build build --parallel 8

Result: passed. The build produced nonfatal warnings, mostly -Woverloaded-virtual warnings around TIRx visitor overloads, plus LLVM deprecation warnings.

Focused tests:

TVM_LIBRARY_PATH=build/lib python -m pytest \
  tests/python/ir/test_ir_attrs.py::test_dict_attrs_has_nonzero_attr_accepts_int_imm -q

TVM_LIBRARY_PATH=build/lib python -m pytest \
  tests/python/s_tir/dlight/test_gpu_matmul.py::test_matmul_rule_skips_non_root_block_helper_func -q

Result:

1 passed in 0.02s
1 passed in 0.03s

Branch shape and hygiene:

git rev-list --left-right --count apache/main...HEAD
0 1

git diff --check apache/main...HEAD
# no output

Downstream WebLLM smoke evidence from the Apache-prep wasm:

  • Wasm SHA256: 70d7295dc91b622b79ceeada2c64b4c20787832631c04e3714de95db04515dfc
  • Browser result title: apache-pr-validate:ok
  • Prompt checks:
    • Hi: load 11.0s, generation 0.7s
    • France capital: load 4.3s, generation 0.2s, output Paris
    • Haiku: load 4.0s, generation 2.4s

Non-goals

  • This PR does not add Gemma 4 model registration to TVM.
  • This PR does not add MLC-LLM model code.
  • This PR does not add WebLLM model-list entries.
  • This PR does not claim multimodal Gemma 4 support.
  • Full MLC vs Transformers.js benchmarking is intentionally deferred to downstream artifact/model-card work.

Unblocks the Gemma 4 E2B text-only path in mlc-llm by adding the
TVM runtime, Relax frontend, WebGPU target, and WebAssembly runtime
support that the model exercises during WebGPU prefill/decode.

PagedKVCache hybrid dispatch (src/runtime/vm/paged_kv_cache.cc)
---------------------------------------------------------------
* Hoist `ReserveAppendLengthInSeq` above the aux-data loop so page
  metadata reflects the current prefill when a request spans multiple
  blocks in a single call. Previously the aux-data loop read block
  page counts before the blocks were reserved, producing empty
  `page_indptr` entries for the first call of a newly-created
  sequence whose length exceeded `page_size`.
* Route `AttnKind::kMHASliding` through the MHA dispatch arm in
  `SelfAttention()` and `CrossAttention()`. Without this, sliding
  layers in Gemma 4 fall through to the MLA path and return
  zero-initialised output for their attention sub-graph.
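
The ordering issue in the first bullet can be illustrated with a toy model. This is a minimal Python sketch, not the C++ in `paged_kv_cache.cc`: `ToyBlock`, `reserve_append_length`, `build_page_indptr`, and the page size of 16 are all hypothetical names and values chosen for illustration. It only shows why building page metadata before reserving pages yields an empty `page_indptr` entry on the first prefill of a new sequence.

```python
from dataclasses import dataclass, field

PAGE_SIZE = 16  # assumed page size, for illustration only


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a KV-cache block: page ids plus a token count."""
    page_ids: list = field(default_factory=list)
    seq_length: int = 0


def reserve_append_length(block, append_length, free_pages):
    """Allocate enough pages for `append_length` additional tokens."""
    total = block.seq_length + append_length
    pages_needed = -(-total // PAGE_SIZE)  # ceiling division
    while len(block.page_ids) < pages_needed:
        block.page_ids.append(free_pages.pop())
    block.seq_length = total


def build_page_indptr(blocks):
    """CSR-style prefix sums of per-block page counts."""
    indptr = [0]
    for block in blocks:
        indptr.append(indptr[-1] + len(block.page_ids))
    return indptr


def prefill(append_length, reserve_first):
    """Run one prefill on a fresh sequence with either ordering."""
    block = ToyBlock()
    free_pages = list(range(100))
    if reserve_first:
        # Ordering after the fix: reserve pages, then read page counts.
        reserve_append_length(block, append_length, free_pages)
        return build_page_indptr([block])
    # Ordering before the fix: page counts are read while the block is empty.
    indptr = build_page_indptr([block])
    reserve_append_length(block, append_length, free_pages)
    return indptr


print(prefill(40, reserve_first=False))  # [0, 0]
print(prefill(40, reserve_first=True))   # [0, 3]
```

With a 40-token prefill and 16-token pages, the old ordering reports zero pages for the block, while the new ordering reports the three pages the block actually owns.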

WebGPU target kind (src/target/target_kind.cc)
----------------------------------------------
* Register `max_shared_memory_per_block = 32768` on the `webgpu`
  target kind. Without this attribute, DLight's shared-memory
  analysis falls back to the generic 48 KB default and can generate
  decode kernels that exceed Chrome/Dawn's 32 KB workgroup-storage
  budget. Chrome currently exposes 32768 and the WebGPU spec mandates
  at least 16384, so 32768 is a conservative default.
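
The budget arithmetic behind this change can be checked by hand. The sketch below is only an illustration under stated assumptions: the tile shapes are hypothetical and do not come from any kernel in this PR, and `shared_bytes` and `fits` are helper names invented here. The three budget constants are the figures quoted in the paragraph above.

```python
# Budgets quoted in the description above, in bytes.
WEBGPU_SPEC_MINIMUM = 16384
WEBGPU_TARGET_DEFAULT = 32768
GENERIC_FALLBACK = 48 * 1024


def shared_bytes(tiles, dtype_bytes):
    """Total shared-memory footprint of a list of (rows, cols) tiles."""
    return sum(rows * cols for rows, cols in tiles) * dtype_bytes


def fits(tiles, dtype_bytes, budget):
    """True when the tiles fit inside the given shared-memory budget."""
    return shared_bytes(tiles, dtype_bytes) <= budget


# Hypothetical float32 tiling sized against the 48 KB fallback.
large_tiles = [(96, 64), (96, 64)]
# Hypothetical smaller tiling sized against the 32 KB WebGPU default.
small_tiles = [(64, 64), (64, 64)]

print(shared_bytes(large_tiles, 4))                 # 49152
print(fits(large_tiles, 4, GENERIC_FALLBACK))       # True
print(fits(large_tiles, 4, WEBGPU_TARGET_DEFAULT))  # False
print(fits(small_tiles, 4, WEBGPU_TARGET_DEFAULT))  # True
```

A schedule that is legal under the 48 KB assumption can therefore overshoot the 32 KB budget by half, which is the failure mode the registered attribute is meant to prevent.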

Relax nn.llm RoPE support (python/tvm/relax/frontend/nn/llm/)
-------------------------------------------------------------
* Add a `freq_dim_base` parameter to `rope_freq_gptj` so callers can
  decouple the frequency-base dimension from the rotated range.
  Gemma 4 full-attention layers use partial rotary embeddings with
  `head_dim` as the frequency base and a smaller rotated dimension.
* Mark the generated `fused_rope` and `fused_rope_longrope_scaling`
  prim_funcs as private. Gemma 4 builds separate RoPE factories for
  sliding and partial-rotary full-attention layers, and private nested
  prim_funcs avoid duplicate module-scope global symbols.
* Promote the `apply_rope` prim_func parameter from `T.int32` to
  `T.int64` to match the existing caller convention that passes an
  int64 immediate.
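
To make the `freq_dim_base` idea concrete, here is a plain-Python sketch of a GPT-J-style frequency computation in which the exponent denominator is decoupled from the rotated range. It is an assumption-laden illustration, not the TIR code in this PR: the function name `rope_freq_gptj_sketch`, its argument order, and the exact formula are hypothetical, and the real `rope_freq_gptj` may differ in detail.

```python
import math


def rope_freq_gptj_sketch(pos, d, rotary_dim, theta, freq_dim_base=None):
    """Return (cos, sin) for position `pos` and feature index `d`.

    Hypothetical formula: GPT-J-style interleaved pairs share a frequency,
    and `freq_dim_base` (when given) replaces `rotary_dim` as the exponent
    denominator.
    """
    base = rotary_dim if freq_dim_base is None else freq_dim_base
    pair_index = (2 * (d // 2)) % rotary_dim  # features 2k and 2k+1 share a frequency
    angle = pos / (theta ** (pair_index / base))
    return math.cos(angle), math.sin(angle)


# Partial rotary example: rotate 64 features, but derive frequencies from a
# hypothetical head_dim of 256.
default = rope_freq_gptj_sketch(pos=3, d=10, rotary_dim=64, theta=10000.0)
decoupled = rope_freq_gptj_sketch(pos=3, d=10, rotary_dim=64, theta=10000.0,
                                  freq_dim_base=256)
print(default != decoupled)  # True
```

Leaving `freq_dim_base` unset reproduces the coupled behaviour in this sketch, so existing callers would be unaffected under that assumption.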

TIRx device-module grouping (python/tvm/tirx/build.py)
------------------------------------------------------
* Group device functions by `target.kind.name` instead of the
  stringified target object in `split_host_device_mods`. This avoids
  treating target attributes as separate backend kinds while still
  preserving a canonical target per backend kind.
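
The effect of the grouping key can be sketched with a toy target type. Everything here is hypothetical: `ToyTarget`, `group_functions`, the attribute values, and the choice of the first-seen target as the canonical one are assumptions for illustration, not the logic in `split_host_device_mods`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTarget:
    """Hypothetical stand-in for a target: a kind name plus attributes."""
    kind_name: str
    attrs: tuple = ()

    def __str__(self):
        attr_text = " ".join(f"-{key}={value}" for key, value in self.attrs)
        return f"{self.kind_name} {attr_text}".strip()


def group_functions(funcs, key):
    """Group (name, target) pairs by key(target).

    Each group keeps the first target seen as its canonical target
    (an assumption made for this sketch).
    """
    groups = {}
    for name, target in funcs:
        _canonical, names = groups.setdefault(key(target), (target, []))
        names.append(name)
    return groups


funcs = [
    ("matmul", ToyTarget("webgpu", (("max_num_threads", 256),))),
    ("rope", ToyTarget("webgpu", (("max_num_threads", 128),))),
]

by_string = group_functions(funcs, key=str)
by_kind = group_functions(funcs, key=lambda target: target.kind_name)

print(len(by_string))        # 2
print(len(by_kind))          # 1
print(by_kind["webgpu"][1])  # ['matmul', 'rope']
```

Keying on the full string splits one backend into two device modules as soon as any attribute differs, whereas keying on the kind name keeps them together.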

WebAssembly runtime (web/)
--------------------------
* Reorder FFI includes in `web/emcc/wasm_runtime.cc` so `tvm_ffi::*`
  static initialisers run before `runtime::*` initialisers.
* Extend `ArrayDecodeStorage` with a fall-through for payloads tagged
  `f32-to-bf16` whose byte length matches native float32, allowing
  native-f32 shards with that tag to decode correctly.
* Add chunked tensor loading in `web/src/runtime.ts` for records whose
  `nbytes` exceeds the per-call transfer budget. Also unpack
  `kTVMFFIShape` results so chunked loading can call `tensorCreateView`
  with explicit shape tuples.
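
The last two bullets can be sketched in a few lines. The actual change lives in C++ and TypeScript; the Python below is only a language-neutral illustration, and `classify_payload`, `chunk_ranges`, and the budget value are hypothetical. The only facts relied on are that bf16 uses 2 bytes per element and float32 uses 4.

```python
def classify_payload(tag, byte_length, num_elements):
    """Decide how to decode a payload from its tag and byte length."""
    if tag == "f32-to-bf16":
        if byte_length == 2 * num_elements:
            return "bf16"        # genuinely compressed payload
        if byte_length == 4 * num_elements:
            return "native-f32"  # fall-through: tagged bf16, stored as float32
        raise ValueError("byte length matches neither bf16 nor float32")
    return "raw"


def chunk_ranges(nbytes, budget):
    """Split [0, nbytes) into (offset, length) pieces of at most `budget` bytes."""
    if budget <= 0:
        raise ValueError("budget must be positive")
    return [(offset, min(budget, nbytes - offset))
            for offset in range(0, nbytes, budget)]


print(classify_payload("f32-to-bf16", byte_length=4096, num_elements=1024))  # native-f32
print(chunk_ranges(10, 4))  # [(0, 4), (4, 4), (8, 2)]
```

Each `(offset, length)` pair corresponds to one bounded copy into a view of the destination tensor, so no single transfer exceeds the budget.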

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces several enhancements and bug fixes across TVM, including: updating "HasNonzeroAttr" to support "IntImm" and "bool" types; extending GPT-J RoPE frequency calculations to support partial rotary embeddings via a new "freq_dim_base" parameter; adding error handling in GPU matmul scheduling to skip functions without a root block; grouping device functions by target kind name during build; adding "max_shared_memory_per_block" to WebGPU target attributes; unconditionally reserving pages in the paged KV cache to fix intra-prefill shared-KV cross-attention; and implementing chunked copying for large tensors in the Web runtime to prevent memory issues. There are no review comments, so I have no feedback to provide.

@tqchen tqchen left a comment

Member

Thanks for the PR. It would be good to split out the WebGPU runtime-only part into its own PR, which would be easier to review; the compiler side has some unintended issues.

Comment thread include/tvm/ir/attrs.h
const DictAttrsNode* node = get();
auto it = node->dict.find(attr_key);
if (it == node->dict.end()) {
return false;

Member

We are moving away from IntImm attrs.

@MakotoUwu

Author

Thanks, this makes sense. I split the Web/WebGPU runtime-only portion into a smaller draft PR here: #19771.

I will keep the compiler-side pieces from this draft separate and revisit them independently, especially the attr/RoPE/TIRx changes that caused the unintended CI issues.

