From 02b8b9fa32cd91666768b2758a91cceb500fc9b3 Mon Sep 17 00:00:00 2001 From: MakotoUwu Date: Thu, 16 Apr 2026 00:22:39 +0200 Subject: [PATCH] [Runtime] Add Gemma 4 E2B prerequisites Unblocks the Gemma 4 E2B text-only path in mlc-llm by adding the TVM runtime, Relax frontend, WebGPU target, and WebAssembly runtime support that the model exercises during WebGPU prefill/decode. PagedKVCache hybrid dispatch (src/runtime/vm/paged_kv_cache.cc) --------------------------------------------------------------- * Hoist `ReserveAppendLengthInSeq` above the aux-data loop so page metadata reflects the current prefill when a request spans multiple blocks in a single call. Previously the aux-data loop read block page counts before the blocks were reserved, producing empty `page_indptr` entries for the first call of a newly-created sequence whose length exceeded `page_size`. * Route `AttnKind::kMHASliding` through the MHA dispatch arm in `SelfAttention()` and `CrossAttention()`. Without this, sliding layers in Gemma 4 fall through to the MLA path and return zero-initialised output for their attention sub-graph. WebGPU target kind (src/backend/webgpu/codegen/target_kind.cc) -------------------------------------------------------------- * Register `max_shared_memory_per_block = 32768` on the `webgpu` target kind. Without this attribute, Dlight's shared-memory analysis falls back to the generic 48 KB default and can generate decode kernels that exceed Chrome/Dawn's 32 KB workgroup-storage budget. Chrome currently exposes 32768, so this default stays within Chrome's budget; note that the WebGPU spec only mandates 16384, so 32768 is conservative relative to the 48 KB fallback but not relative to the spec minimum. Relax nn.llm RoPE support (python/tvm/relax/frontend/nn/llm/) ------------------------------------------------------------- * Add a `freq_dim_base` parameter to `rope_freq_gptj` so callers can decouple the frequency-base dimension from the rotated range. Gemma 4 full-attention layers use partial rotary embeddings with `head_dim` as the frequency base and a smaller rotated dimension. 
* Mark the generated `fused_rope` and `fused_rope_longrope_scaling` prim_funcs as private. Gemma 4 builds separate RoPE factories for sliding and partial-rotary full-attention layers, and private nested prim_funcs avoid duplicate module-scope global symbols. * Promote the `apply_rope` prim_func parameter from `T.int32` to `T.int64` to match the existing caller convention that passes an int64 immediate. TIRX device-module grouping (python/tvm/tirx/build.py) ------------------------------------------------------ * Group device functions by `target.kind.name` instead of the stringified target object in `split_host_device_mods`. This avoids splitting functions whose targets differ only in attributes into separate device modules; the first target encountered for each kind is kept as the canonical target for that kind. WebAssembly runtime (web/) -------------------------- * Reorder FFI includes in `web/emcc/wasm_runtime.cc` so `tvm_ffi::*` static initialisers run before `runtime::*` initialisers. * Extend `ArrayDecodeStorage` with a fall-through for payloads tagged `f32-to-bf16` whose byte length matches native float32, allowing native-f32 shards with that tag to decode correctly. * Add chunked tensor loading in `web/src/runtime.ts` for records whose `nbytes` exceeds the per-call transfer budget (128 MiB). Also unpack `kTVMFFIShape` results so chunked loading can call `tensorCreateView` with explicit shape tuples. 
--- include/tvm/ir/attrs.h | 17 ++- .../frontend/nn/llm/position_embedding.py | 27 ++++- python/tvm/s_tir/dlight/gpu/matmul.py | 20 +++- python/tvm/tirx/build.py | 25 +++-- src/backend/webgpu/codegen/target_kind.cc | 1 + src/runtime/vm/paged_kv_cache.cc | 40 ++++--- tests/python/ir/test_ir_attrs.py | 15 +++ tests/python/s_tir/dlight/test_gpu_matmul.py | 11 ++ web/emcc/wasm_runtime.cc | 57 ++++++---- web/src/runtime.ts | 103 +++++++++++++++++- 10 files changed, 253 insertions(+), 63 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 96eec4616b4d..45912ca3103f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -199,7 +199,22 @@ class DictAttrs : public Attrs { * \endcode */ bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0).value_or(0) != 0; + const DictAttrsNode* node = get(); + auto it = node->dict.find(attr_key); + if (it == node->dict.end()) { + return false; + } + const ffi::Any& value = (*it).second; + if (auto opt_int = value.try_cast()) { + return opt_int.value() != 0; + } + if (auto opt_imm = value.try_cast()) { + return opt_imm.value()->value != 0; + } + if (auto opt_bool = value.try_cast()) { + return opt_bool.value(); + } + return false; } // Inline-expand TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE here, minus diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index e42cb55f4821..8e14a3d8ad18 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -67,9 +67,23 @@ def rope_freq_default(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtyp return cos_freq, sin_freq, {freq_var: freq} -def rope_freq_gptj(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str): - """Compute the inverse frequency of RoPE for gptj RoPE scaling.""" - freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(d_range, "float32")) 
+def rope_freq_gptj( + s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str, + freq_dim_base: int = 0, +): + """Compute the inverse frequency of RoPE for gptj RoPE scaling. + + Parameters + ---------- + freq_dim_base : int + If > 0, use this as the denominator in the frequency exponent instead + of d_range. This supports partial rotary embeddings where the frequency + base dimension (head_dim) differs from the number of rotated dimensions + (rotary_dim). E.g., Gemma 4 full-attention layers have head_dim=512 + but only rotate 128 dims (partial_rotary_factor=0.25). + """ + denom = freq_dim_base if freq_dim_base > 0 else d_range + freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(denom, "float32")) freq_var = tirx.Var("freq", "float32") cos_freq = tirx.cos(freq_var).astype(dtype) sin_freq = tirx.sin(freq_var).astype(dtype) @@ -262,6 +276,9 @@ def switch_rope_freq_func(rope_scaling: dict[str, Any]) -> Callable: if "rope_type" not in rope_scaling: return rope_freq_default if rope_scaling["rope_type"] == "gptj": + freq_dim_base = rope_scaling.get("freq_dim_base", 0) + if freq_dim_base > 0: + return partial(rope_freq_gptj, freq_dim_base=freq_dim_base) return rope_freq_gptj if rope_scaling["rope_type"] == "llama3": return partial( @@ -522,7 +539,7 @@ def _rope( # pylint: disable=too-many-arguments expr = tirx.Let(var, value, expr) return expr - @T.prim_func(s_tir=True) + @T.prim_func(private=True, s_tir=True) def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, var_position_map: T.handle, @@ -564,7 +581,7 @@ def fused_rope( # pylint: disable=too-many-locals else: v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] - @T.prim_func(s_tir=True) + @T.prim_func(private=True, s_tir=True) def fused_rope_longrope_scaling( # pylint: disable=too-many-locals var_qkv: T.handle, var_position_map: T.handle, diff --git a/python/tvm/s_tir/dlight/gpu/matmul.py b/python/tvm/s_tir/dlight/gpu/matmul.py index ef9392cc2f3c..acb37a43392e 100644 --- 
a/python/tvm/s_tir/dlight/gpu/matmul.py +++ b/python/tvm/s_tir/dlight/gpu/matmul.py @@ -359,7 +359,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) - root_block = get_root_block(sch) + try: + root_block = get_root_block(sch) + except ValueError: + return None blocks = sch.get_child_blocks(root_block) reduction_blocks = get_reduction_blocks(sch, blocks) @@ -500,7 +503,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) - root_block = get_root_block(sch) + try: + root_block = get_root_block(sch) + except ValueError: + return None blocks = sch.get_child_blocks(root_block) if "dlight.do_not_tensorize" in func.attrs.keys(): @@ -721,7 +727,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) - root_block = get_root_block(sch) + try: + root_block = get_root_block(sch) + except ValueError: + return None blocks = sch.get_child_blocks(root_block) if "dlight.do_not_tensorize" in func.attrs.keys(): @@ -972,7 +981,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring return None sch = s_tir.Schedule(func) config = self.get_configs(target) - root_block = get_root_block(sch) + try: + root_block = get_root_block(sch) + except ValueError: + return None blocks = sch.get_child_blocks(root_block) reduction_blocks = get_reduction_blocks(sch, blocks) diff --git a/python/tvm/tirx/build.py b/python/tvm/tirx/build.py index 10ec096bca79..46a1bfb7c9b3 100644 --- a/python/tvm/tirx/build.py +++ b/python/tvm/tirx/build.py @@ -104,18 +104,25 @@ def is_host_func(f): host_mod = tvm.tirx.transform.Filter(is_host_func)(mod) device_mod = tvm.tirx.transform.Filter(lambda f: not 
is_host_func(f))(mod) - # TODO(syfeng): Here we use str as key since target hash is not correct - target_str2target = {} - device_func_dict = {} + # Group device functions by target kind name (e.g. "webgpu", "cuda") rather + # than the full target string. Different TIR passes may attach slightly + # different target objects (e.g. with or without max_num_threads) to + # functions that should all end up in the same device module. Using the + # full str(target) as key splits them into separate modules, causing the + # later module to shadow the earlier one at runtime. + kind2target: dict[str, "Target"] = {} + kind2funcs: dict[str, dict] = {} device_mod_dict: dict[Target, IRModule] = {} for gv, func in device_mod.functions.items(): target = func.attrs.get("target", None) - target_str = str(target) if target is not None else "" - target_str2target[target_str] = target # This might be overridden by the last one - device_func_dict.setdefault(target_str, dict()).update({gv: func}) - for target_str in target_str2target.keys(): - target = target_str2target[target_str] - device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs) + kind = target.kind.name if target is not None else "" + # Keep the first target encountered for each kind as the canonical one + if kind not in kind2target: + kind2target[kind] = target + kind2funcs.setdefault(kind, dict()).update({gv: func}) + for kind in kind2target: + target = kind2target[kind] + device_mod_dict[target] = tvm.IRModule(kind2funcs[kind], attrs=device_mod.attrs) return host_mod, device_mod_dict diff --git a/src/backend/webgpu/codegen/target_kind.cc b/src/backend/webgpu/codegen/target_kind.cc index 0447bce49927..f95ae4b21d4d 100644 --- a/src/backend/webgpu/codegen/target_kind.cc +++ b/src/backend/webgpu/codegen/target_kind.cc @@ -58,6 +58,7 @@ void RegisterTargetKind() { .add_attr_option("max_num_threads", refl::DefaultValue(256)) .add_attr_option("supports_subgroups", refl::DefaultValue(false)) 
.add_attr_option("thread_warp_size", refl::DefaultValue(1)) + .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(32768)) .set_target_canonicalizer(UpdateWebGPUAttrs) .set_default_keys({"webgpu", "gpu"}); } diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index e5c4576e01c1..7978619f92a9 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -966,13 +966,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true); } - if (append_before_attn_) { - // Right now we use different kernels when depth is 1 or not 1. - // For the case where maximum depth is 1, we create the auxiliary - // data structure with regard to the page table after appending. - for (int i = 0; i < cur_batch_size_; ++i) { - ReserveAppendLengthInSeq(sequences[i], append_lengths[i]); - } + // Reserve pages BEFORE the aux-data loop unconditionally. The aux-data + // loop below reads `block.page_ids.size()` to populate page_indptr / + // page_indices / length_info, so pages must already be reserved in this + // call's blocks for the metadata to reflect the current prefill state. + // + // Previously the reserve was conditional on `append_before_attn_`: the + // `=true` branch reserved before the loop (correct), but the `=false` + // branch reserved after the loop, producing zero-page metadata for the + // first prefill into an empty cache. That broke models that perform + // intra-prefill shared-KV cross-attention (e.g. Gemma 4 layers 15-34 + // reading the K/V written by layers 13/14 inside the same prefill call): + // MHACrossAttnInternal saw `page_indices->shape[0] == 0` and skipped + // the entire computation, leaving the model-supplied `o_data` as + // uninitialised memory. 
+ // + // The K/V-append timing is unchanged: the actual append (via + // `f_transpose_append_mha`) is still controlled by `append_before_attn_` + // at attention time, not by when page slots are reserved here. + for (int i = 0; i < cur_batch_size_; ++i) { + ReserveAppendLengthInSeq(sequences[i], append_lengths[i]); } for (int d = 0; d < num_depths_; ++d) { @@ -1114,15 +1127,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - if (!append_before_attn_) { - // Right now we use different kernels when depth is 1 or not 1. - // For the case where maximum depth is not 1, we create the auxiliary - // data structure with regard to the page table before appending. - for (int i = 0; i < cur_batch_size_; ++i) { - ReserveAppendLengthInSeq(sequences[i], append_lengths[i]); - } - } - // Map each the token position in the input batch to the position // in the global KV cache. The mapping is used in when appending k/v values. kv_len_arr_host_.clear(); @@ -1425,7 +1429,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // The auxiliary data structure on device must have been synchronized. TVM_FFI_ICHECK(!dirty_aux_data_device_); - if (attn_kind == AttnKind::kMHA) { + if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) { MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale); } else { MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale); @@ -1464,7 +1468,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // The auxiliary data structure on device must have been synchronized. 
TVM_FFI_ICHECK(!dirty_aux_data_device_); - if (attn_kind == AttnKind::kMHA) { + if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) { MHACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale, /*is_first_kernel=*/true); } else { diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py index 25480f726577..13fa30bf9332 100644 --- a/tests/python/ir/test_ir_attrs.py +++ b/tests/python/ir/test_ir_attrs.py @@ -19,6 +19,7 @@ import tvm_ffi import tvm +from tvm import tirx def test_dict_attrs(): @@ -57,6 +58,20 @@ def test_assert_structural_equal_reports_mismatch(): assert "and rhs at" in message +def test_dict_attrs_has_nonzero_attr_accepts_int_imm(): + arg = tirx.Var("arg", "handle") + func = tirx.PrimFunc([arg], tirx.Evaluate(0)).with_attr( + { + "global_symbol": "int_imm_noalias", + "tirx.noalias": tirx.IntImm("int32", 1), + } + ) + + tvm.compile(tvm.IRModule({"main": func}), target="c") + + if __name__ == "__main__": test_dict_attrs() test_attrs_equal() + test_assert_structural_equal_reports_mismatch() + test_dict_attrs_has_nonzero_attr_accepts_int_imm() diff --git a/tests/python/s_tir/dlight/test_gpu_matmul.py b/tests/python/s_tir/dlight/test_gpu_matmul.py index af23258e0191..18327f4d6e91 100644 --- a/tests/python/s_tir/dlight/test_gpu_matmul.py +++ b/tests/python/s_tir/dlight/test_gpu_matmul.py @@ -18,6 +18,7 @@ # ruff: noqa: E501 import tvm import tvm.testing +from tvm import tirx from tvm.s_tir import dlight as dl from tvm.script import tirx as T from tvm.target import Target @@ -309,6 +310,16 @@ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer(( tvm.ir.assert_structural_equal(mod["main"], expected) +def test_matmul_rule_skips_non_root_block_helper_func(): + func = tirx.PrimFunc([], tirx.Evaluate(0)).with_attr("target", Target("webgpu")) + mod = tvm.IRModule({"main": func}) + + with Target("webgpu"): + scheduled = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) + + 
tvm.ir.assert_structural_equal(scheduled, mod) + + def test_skip_gemv(): # fmt: off @T.prim_func(private=True, s_tir=True) diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 9d3d46f18cb4..8574a86fb46d 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -32,6 +32,25 @@ #include #include +// FFI core must come before runtime .cc includes in this single translation +// unit. Otherwise, static initialisation can resolve ffi globals before +// object.cc registers them, leading to crashes at module init time. +#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc" +#include "3rdparty/tvm-ffi/src/ffi/container.cc" +#include "3rdparty/tvm-ffi/src/ffi/dtype.cc" +#include "3rdparty/tvm-ffi/src/ffi/error.cc" +#include "3rdparty/tvm-ffi/src/ffi/function.cc" +#include "3rdparty/tvm-ffi/src/ffi/object.cc" +#include "3rdparty/tvm-ffi/src/ffi/tensor.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc" +#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc" + #include "src/runtime/cpu_device_api.cc" #include "src/runtime/device_api.cc" #include "src/runtime/extra/contrib/sort/sort.cc" @@ -46,22 +65,6 @@ #include "src/runtime/tensor.cc" #include "src/runtime/timer.cc" #include "src/runtime/workspace_pool.cc" -// relax setup -#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc" -#include "3rdparty/tvm-ffi/src/ffi/container.cc" -#include "3rdparty/tvm-ffi/src/ffi/dtype.cc" -#include "3rdparty/tvm-ffi/src/ffi/error.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc" 
-#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc" -#include "3rdparty/tvm-ffi/src/ffi/function.cc" -#include "3rdparty/tvm-ffi/src/ffi/object.cc" -#include "3rdparty/tvm-ffi/src/ffi/tensor.cc" -#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc" #include "src/runtime/memory/memory_manager.cc" #include "src/runtime/vm/attn_backend.cc" #include "src/runtime/vm/builtin.cc" @@ -130,20 +133,28 @@ void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::strin const char* byte_data = bytes->data; const size_t byte_size = bytes->size; if (format == "f32-to-bf16" && dtype == "float32") { - const uint16_t* bf16 = reinterpret_cast(byte_data); - uint32_t* data = static_cast(cpu_arr->data); TVM_FFI_ICHECK(cpu_arr.IsContiguous()); size_t size = 1; for (int i = 0; i < cpu_arr->ndim; ++i) { size *= cpu_arr->shape[i]; } - TVM_FFI_ICHECK_EQ(size, byte_size / 2); - for (size_t i = 0; i < size; ++i) { - data[i] = static_cast(bf16[i]) << 16; + // The "f32-to-bf16" format encodes a float32 tensor as packed bf16 (2 + // bytes per element). When the byte_size matches that expectation, expand + // back to f32. If the byte_size matches the native float32 width + // (4 bytes per element), the payload is already raw float32 — fall through + // to the generic byte copy. This makes the loader tolerant of weight + // shards produced by older / alternate quantisation pipelines that retain + // the "f32-to-bf16" tag without performing the bf16 truncation. 
+ if (size == byte_size / 2) { + const uint16_t* bf16 = reinterpret_cast(byte_data); + uint32_t* data = static_cast(cpu_arr->data); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(bf16[i]) << 16; + } + return; } - } else { - cpu_arr.CopyFromBytes(byte_data, byte_size); } + cpu_arr.CopyFromBytes(byte_data, byte_size); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 078a0c7df21f..eb236eb0be04 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -1323,6 +1323,7 @@ export class Instance implements Disposable { artifactCache: ArtifactCacheTemplate, signal?: AbortSignal, ) { + const maxChunkBytes = 128 * 1024 * 1024; const perf = compact.getPerformance(); const tstart = perf.now(); let totalBytes = 0; @@ -1421,9 +1422,53 @@ export class Instance implements Disposable { this.empty(rec.shape, rec.dtype, this.cpu()) ) }); - const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); + const shardBytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); + const recSource = + rec.byteOffset === 0 && rec.nbytes === shardBytes.byteLength + ? 
shardBytes + : shardBytes.subarray(rec.byteOffset, rec.byteOffset + rec.nbytes); + const canChunkRecord = + rec.nbytes > maxChunkBytes && + rec.shape.length >= 1 && + Number.isInteger(rec.shape[0]) && + rec.shape[0] > 0 && + rec.nbytes % rec.shape[0] === 0; + const copyRecordToTensor = (targetTensor: Tensor, sourceBytes: Uint8Array) => { + if (!canChunkRecord) { + this.ctx.arrayDecodeStorage(targetTensor, sourceBytes, rec.format, rec.dtype); + return; + } + const outerDim = rec.shape[0]; + const chunkStrideBytes = rec.nbytes / outerDim; + const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / chunkStrideBytes)); + for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) { + const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset); + const chunkByteOffset = outerOffset * chunkStrideBytes; + const chunkBytes = outerCount * chunkStrideBytes; + const chunkShape = rec.shape.slice(); + chunkShape[0] = outerCount; + // Wrap in withNewScope so TVM intermediate objects (shape tuple) + // are disposed after each chunk, but detach the view we need. + const chunkView = this.withNewScope(() => { + return this.detachFromCurrentScope( + this.ctx.tensorCreateView( + targetTensor, + this.ctx.makeShapeTuple(...chunkShape.map((value) => new Scalar(value, "int"))), + rec.dtype, + new Scalar(chunkByteOffset, "int"), + ) + ); + }); + const chunkSource = sourceBytes.subarray(chunkByteOffset, chunkByteOffset + chunkBytes); + try { + this.ctx.arrayDecodeStorage(chunkView, chunkSource, rec.format, rec.dtype); + } finally { + chunkView.dispose(); + } + } + }; // first sync copy to cpu. 
- this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype); + copyRecordToTensor(cpu_arr, recSource); // then async stream into GPU if needed if (device.deviceType === DeviceStrToEnum.cpu) { this.tensorCacheUpdate(rec.name, cpu_arr, false); @@ -1435,7 +1480,40 @@ export class Instance implements Disposable { this.empty(rec.shape, rec.dtype, device) ) }); - gpu_arr.copyFrom(cpu_arr); + if (!canChunkRecord) { + gpu_arr.copyFrom(cpu_arr); + } else { + const outerDim = rec.shape[0]; + const chunkStrideBytes = rec.nbytes / outerDim; + const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / chunkStrideBytes)); + for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) { + const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset); + const chunkByteOffset = outerOffset * chunkStrideBytes; + const chunkShape = rec.shape.slice(); + chunkShape[0] = outerCount; + // Use withNewScope so the shape tuple is auto-disposed, + // and detach the views we need for manual lifetime control. 
+ const [cpuView, gpuView] = this.withNewScope(() => { + const chunkShapeTuple = this.ctx.makeShapeTuple( + ...chunkShape.map((value) => new Scalar(value, "int")), + ); + return [ + this.detachFromCurrentScope( + this.ctx.tensorCreateView(cpu_arr, chunkShapeTuple, rec.dtype, new Scalar(chunkByteOffset, "int")) + ), + this.detachFromCurrentScope( + this.ctx.tensorCreateView(gpu_arr, chunkShapeTuple, rec.dtype, new Scalar(chunkByteOffset, "int")) + ), + ]; + }); + try { + gpuView.copyFrom(cpuView); + } finally { + cpuView.dispose(); + gpuView.dispose(); + } + } + } await device.sync(); this.tensorCacheUpdate(rec.name, gpu_arr, false); cpu_arr.dispose(); @@ -2258,6 +2336,25 @@ export class Instance implements Disposable { case TypeIndex.kTVMFFIOpaquePtr: { return this.memory.loadPointer(valuePtr); } + case TypeIndex.kTVMFFIShape: { + const shapeObjPtr = this.memory.loadPointer(valuePtr); + if (callbackArg) { + const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader; + const shapeDataPtr = this.memory.loadPointer(shapeCellPtr); + const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr()); + const result = new Array(shapeLen); + for (let i = 0; i < shapeLen; ++i) { + result[i] = this.memory.loadI64(shapeDataPtr + i * SizeOf.I64); + } + this.lib.checkCall( + (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(shapeObjPtr) + ); + return result; + } + return this.ctx.attachToCurrentScope( + new TVMObject(shapeObjPtr, this.lib, this.ctx) + ); + } case TypeIndex.kTVMFFITensor: { return this.ctx.attachToCurrentScope( new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false)