Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,22 @@ class DictAttrs : public Attrs {
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<int64_t>(attr_key, 0).value_or(0) != 0;
const DictAttrsNode* node = get();
auto it = node->dict.find(attr_key);
if (it == node->dict.end()) {
return false;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are moving away from IntImm attr

}
const ffi::Any& value = (*it).second;
if (auto opt_int = value.try_cast<int64_t>()) {
return opt_int.value() != 0;
}
if (auto opt_imm = value.try_cast<IntImm>()) {
return opt_imm.value()->value != 0;
}
if (auto opt_bool = value.try_cast<bool>()) {
return opt_bool.value();
}
return false;
}

// Inline-expand TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE here, minus
Expand Down
27 changes: 22 additions & 5 deletions python/tvm/relax/frontend/nn/llm/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,23 @@ def rope_freq_default(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtyp
return cos_freq, sin_freq, {freq_var: freq}


def rope_freq_gptj(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str):
"""Compute the inverse frequency of RoPE for gptj RoPE scaling."""
freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(d_range, "float32"))
def rope_freq_gptj(
s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str,
freq_dim_base: int = 0,
):
"""Compute the inverse frequency of RoPE for gptj RoPE scaling.

Parameters
----------
freq_dim_base : int
If > 0, use this as the denominator in the frequency exponent instead
of d_range. This supports partial rotary embeddings where the frequency
base dimension (head_dim) differs from the number of rotated dimensions
(rotary_dim). E.g., Gemma 4 full-attention layers have head_dim=512
but only rotate 128 dims (partial_rotary_factor=0.25).
"""
denom = freq_dim_base if freq_dim_base > 0 else d_range
freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(denom, "float32"))
freq_var = tirx.Var("freq", "float32")
cos_freq = tirx.cos(freq_var).astype(dtype)
sin_freq = tirx.sin(freq_var).astype(dtype)
Expand Down Expand Up @@ -262,6 +276,9 @@ def switch_rope_freq_func(rope_scaling: dict[str, Any]) -> Callable:
if "rope_type" not in rope_scaling:
return rope_freq_default
if rope_scaling["rope_type"] == "gptj":
freq_dim_base = rope_scaling.get("freq_dim_base", 0)
if freq_dim_base > 0:
return partial(rope_freq_gptj, freq_dim_base=freq_dim_base)
return rope_freq_gptj
if rope_scaling["rope_type"] == "llama3":
return partial(
Expand Down Expand Up @@ -522,7 +539,7 @@ def _rope( # pylint: disable=too-many-arguments
expr = tirx.Let(var, value, expr)
return expr

@T.prim_func(s_tir=True)
@T.prim_func(private=True, s_tir=True)
def fused_rope( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
Expand Down Expand Up @@ -564,7 +581,7 @@ def fused_rope( # pylint: disable=too-many-locals
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

@T.prim_func(s_tir=True)
@T.prim_func(private=True, s_tir=True)
def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
Expand Down
20 changes: 16 additions & 4 deletions python/tvm/s_tir/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target):
return None
sch = s_tir.Schedule(func)
root_block = get_root_block(sch)
try:
root_block = get_root_block(sch)
except ValueError:
return None
blocks = sch.get_child_blocks(root_block)

reduction_blocks = get_reduction_blocks(sch, blocks)
Expand Down Expand Up @@ -500,7 +503,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target):
return None
sch = s_tir.Schedule(func)
root_block = get_root_block(sch)
try:
root_block = get_root_block(sch)
except ValueError:
return None
blocks = sch.get_child_blocks(root_block)

if "dlight.do_not_tensorize" in func.attrs.keys():
Expand Down Expand Up @@ -721,7 +727,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target):
return None
sch = s_tir.Schedule(func)
root_block = get_root_block(sch)
try:
root_block = get_root_block(sch)
except ValueError:
return None
blocks = sch.get_child_blocks(root_block)

if "dlight.do_not_tensorize" in func.attrs.keys():
Expand Down Expand Up @@ -972,7 +981,10 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
return None
sch = s_tir.Schedule(func)
config = self.get_configs(target)
root_block = get_root_block(sch)
try:
root_block = get_root_block(sch)
except ValueError:
return None
blocks = sch.get_child_blocks(root_block)

reduction_blocks = get_reduction_blocks(sch, blocks)
Expand Down
25 changes: 16 additions & 9 deletions python/tvm/tirx/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,25 @@ def is_host_func(f):

host_mod = tvm.tirx.transform.Filter(is_host_func)(mod)
device_mod = tvm.tirx.transform.Filter(lambda f: not is_host_func(f))(mod)
# TODO(syfeng): Here we use str as key since target hash is not correct
target_str2target = {}
device_func_dict = {}
# Group device functions by target kind name (e.g. "webgpu", "cuda") rather
# than the full target string. Different TIR passes may attach slightly
# different target objects (e.g. with or without max_num_threads) to
# functions that should all end up in the same device module. Using the
# full str(target) as key splits them into separate modules, causing the
# later module to shadow the earlier one at runtime.
kind2target: dict[str, "Target"] = {}
kind2funcs: dict[str, dict] = {}
device_mod_dict: dict[Target, IRModule] = {}
for gv, func in device_mod.functions.items():
target = func.attrs.get("target", None)
target_str = str(target) if target is not None else ""
target_str2target[target_str] = target # This might be overridden by the last one
device_func_dict.setdefault(target_str, dict()).update({gv: func})
for target_str in target_str2target.keys():
target = target_str2target[target_str]
device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
kind = target.kind.name if target is not None else ""
# Keep the first target encountered for each kind as the canonical one
if kind not in kind2target:
kind2target[kind] = target
kind2funcs.setdefault(kind, dict()).update({gv: func})
for kind in kind2target:
target = kind2target[kind]
device_mod_dict[target] = tvm.IRModule(kind2funcs[kind], attrs=device_mod.attrs)
return host_mod, device_mod_dict


Expand Down
1 change: 1 addition & 0 deletions src/backend/webgpu/codegen/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ void RegisterTargetKind() {
.add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
.add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
.add_attr_option<int64_t>("max_shared_memory_per_block", refl::DefaultValue(32768))
.set_target_canonicalizer(UpdateWebGPUAttrs)
.set_default_keys({"webgpu", "gpu"});
}
Expand Down
40 changes: 22 additions & 18 deletions src/runtime/vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -966,13 +966,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true);
}

if (append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
// For the case where maximum depth is 1, we create the auxiliary
// data structure with regard to the page table after appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
// Reserve pages BEFORE the aux-data loop unconditionally. The aux-data
// loop below reads `block.page_ids.size()` to populate page_indptr /
// page_indices / length_info, so pages must already be reserved in this
// call's blocks for the metadata to reflect the current prefill state.
//
// Previously the reserve was conditional on `append_before_attn_`: the
// `=true` branch reserved before the loop (correct), but the `=false`
// branch reserved after the loop, producing zero-page metadata for the
// first prefill into an empty cache. That broke models that perform
// intra-prefill shared-KV cross-attention (e.g. Gemma 4 layers 15-34
// reading the K/V written by layers 13/14 inside the same prefill call):
// MHACrossAttnInternal saw `page_indices->shape[0] == 0` and skipped
// the entire computation, leaving the model-supplied `o_data` as
// uninitialised memory.
//
// The K/V-append timing is unchanged: the actual append (via
// `f_transpose_append_mha`) is still controlled by `append_before_attn_`
// at attention time, not by when page slots are reserved here.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}

for (int d = 0; d < num_depths_; ++d) {
Expand Down Expand Up @@ -1114,15 +1127,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

if (!append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
// For the case where maximum depth is not 1, we create the auxiliary
// data structure with regard to the page table before appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
}

// Map each token position in the input batch to its position
// in the global KV cache. The mapping is used when appending k/v values.
kv_len_arr_host_.clear();
Expand Down Expand Up @@ -1425,7 +1429,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);

if (attn_kind == AttnKind::kMHA) {
if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
} else {
MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
Expand Down Expand Up @@ -1464,7 +1468,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);

if (attn_kind == AttnKind::kMHA) {
if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
MHACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale,
/*is_first_kernel=*/true);
} else {
Expand Down
15 changes: 15 additions & 0 deletions tests/python/ir/test_ir_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm_ffi

import tvm
from tvm import tirx


def test_dict_attrs():
Expand Down Expand Up @@ -57,6 +58,20 @@ def test_assert_structural_equal_reports_mismatch():
assert "and rhs at" in message


def test_dict_attrs_has_nonzero_attr_accepts_int_imm():
    """Compile a PrimFunc whose ``tirx.noalias`` attribute is stored as an IntImm.

    The function takes a single handle parameter and has a trivial
    ``Evaluate(0)`` body. The test makes no explicit assertion; it passes
    when ``tvm.compile`` for the ``c`` target finishes without raising.
    """
    attr_overrides = {
        "global_symbol": "int_imm_noalias",
        # Deliberately an IntImm rather than a plain Python int/bool.
        "tirx.noalias": tirx.IntImm("int32", 1),
    }
    handle_param = tirx.Var("arg", "handle")
    prim_func = tirx.PrimFunc([handle_param], tirx.Evaluate(0)).with_attr(attr_overrides)
    module = tvm.IRModule({"main": prim_func})

    tvm.compile(module, target="c")


if __name__ == "__main__":
test_dict_attrs()
test_attrs_equal()
test_assert_structural_equal_reports_mismatch()
test_dict_attrs_has_nonzero_attr_accepts_int_imm()
11 changes: 11 additions & 0 deletions tests/python/s_tir/dlight/test_gpu_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# ruff: noqa: E501
import tvm
import tvm.testing
from tvm import tirx
from tvm.s_tir import dlight as dl
from tvm.script import tirx as T
from tvm.target import Target
Expand Down Expand Up @@ -309,6 +310,16 @@ def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((
tvm.ir.assert_structural_equal(mod["main"], expected)


def test_matmul_rule_skips_non_root_block_helper_func():
    """Apply the dlight GPU Matmul rule to a PrimFunc with a bare ``Evaluate(0)`` body.

    The function takes no parameters and carries a ``webgpu`` target
    attribute. After running ``ApplyDefaultSchedule`` with the Matmul rule
    under a ``webgpu`` target scope, the module must be structurally equal
    to the input, i.e. the rule left the function untouched.
    """
    helper = tirx.PrimFunc([], tirx.Evaluate(0)).with_attr("target", Target("webgpu"))
    original_mod = tvm.IRModule({"main": helper})
    schedule_pass = dl.ApplyDefaultSchedule(dl.gpu.Matmul())

    with Target("webgpu"):
        result_mod = schedule_pass(original_mod)

    tvm.ir.assert_structural_equal(result_mod, original_mod)


def test_skip_gemv():
# fmt: off
@T.prim_func(private=True, s_tir=True)
Expand Down
57 changes: 34 additions & 23 deletions web/emcc/wasm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>

// FFI core must come before runtime .cc includes in this single translation
// unit. Otherwise, static initialisation can resolve ffi globals before
// object.cc registers them, leading to crashes at module init time.
#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc"
#include "3rdparty/tvm-ffi/src/ffi/container.cc"
#include "3rdparty/tvm-ffi/src/ffi/dtype.cc"
#include "3rdparty/tvm-ffi/src/ffi/error.cc"
#include "3rdparty/tvm-ffi/src/ffi/function.cc"
#include "3rdparty/tvm-ffi/src/ffi/object.cc"
#include "3rdparty/tvm-ffi/src/ffi/tensor.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc"
#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc"

#include "src/runtime/cpu_device_api.cc"
#include "src/runtime/device_api.cc"
#include "src/runtime/extra/contrib/sort/sort.cc"
Expand All @@ -46,22 +65,6 @@
#include "src/runtime/tensor.cc"
#include "src/runtime/timer.cc"
#include "src/runtime/workspace_pool.cc"
// relax setup
#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc"
#include "3rdparty/tvm-ffi/src/ffi/container.cc"
#include "3rdparty/tvm-ffi/src/ffi/dtype.cc"
#include "3rdparty/tvm-ffi/src/ffi/error.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc"
#include "3rdparty/tvm-ffi/src/ffi/function.cc"
#include "3rdparty/tvm-ffi/src/ffi/object.cc"
#include "3rdparty/tvm-ffi/src/ffi/tensor.cc"
#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc"
#include "src/runtime/memory/memory_manager.cc"
#include "src/runtime/vm/attn_backend.cc"
#include "src/runtime/vm/builtin.cc"
Expand Down Expand Up @@ -130,20 +133,28 @@ void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::strin
const char* byte_data = bytes->data;
const size_t byte_size = bytes->size;
if (format == "f32-to-bf16" && dtype == "float32") {
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
TVM_FFI_ICHECK(cpu_arr.IsContiguous());
size_t size = 1;
for (int i = 0; i < cpu_arr->ndim; ++i) {
size *= cpu_arr->shape[i];
}
TVM_FFI_ICHECK_EQ(size, byte_size / 2);
for (size_t i = 0; i < size; ++i) {
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
// The "f32-to-bf16" format encodes a float32 tensor as packed bf16 (2
// bytes per element). When the byte_size matches that expectation, expand
// back to f32. If the byte_size matches the native float32 width
// (4 bytes per element), the payload is already raw float32 — fall through
// to the generic byte copy. This makes the loader tolerant of weight
// shards produced by older / alternate quantisation pipelines that retain
// the "f32-to-bf16" tag without performing the bf16 truncation.
if (size == byte_size / 2) {
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
for (size_t i = 0; i < size; ++i) {
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
}
return;
}
} else {
cpu_arr.CopyFromBytes(byte_data, byte_size);
}
cpu_arr.CopyFromBytes(byte_data, byte_size);
}

TVM_FFI_STATIC_INIT_BLOCK() {
Expand Down
Loading