Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relax/backend/cpu_generic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
relax.transform.LowerAllocTensor(),
relax.transform.KillAfterLastUse(),
relax.transform.LowerRuntimeBuiltin(),
relax.transform.CanonicalizeShapeExpr(),
relax.transform.ComputePrimValue(),
relax.transform.VMShapeLower(),
relax.transform.AttachGlobalSymbol(),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/backend/cuda/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
relax.transform.LowerAllocTensor(),
relax.transform.KillAfterLastUse(),
relax.transform.LowerRuntimeBuiltin(),
relax.transform.CanonicalizeShapeExpr(),
relax.transform.ComputePrimValue(),
relax.transform.VMShapeLower(),
relax.transform.AttachGlobalSymbol(),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/backend/gpu_generic/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
relax.transform.LowerAllocTensor(),
relax.transform.KillAfterLastUse(),
relax.transform.LowerRuntimeBuiltin(),
relax.transform.CanonicalizeShapeExpr(),
relax.transform.ComputePrimValue(),
relax.transform.VMShapeLower(),
relax.transform.AttachGlobalSymbol(),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/backend/rocm/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume
relax.transform.LowerAllocTensor(),
relax.transform.KillAfterLastUse(),
relax.transform.LowerRuntimeBuiltin(),
relax.transform.CanonicalizeShapeExpr(),
relax.transform.ComputePrimValue(),
relax.transform.VMShapeLower(),
relax.transform.AttachGlobalSymbol(),
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
BundleModelParams,
CallTIRRewrite,
CanonicalizeBindings,
CanonicalizeShapeExpr,
CombineParallelMatmul,
ComputePrimValue,
ConvertLayout,
Expand Down
26 changes: 26 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,32 @@ def FoldConstant() -> tvm.ir.transform.Pass:
return _ffi_api.FoldConstant() # type: ignore


def CanonicalizeShapeExpr() -> tvm.ir.transform.Pass:
    """Replace compound dimensions of ``ShapeExpr`` with fresh symbolic variables.

    VMShapeLower can only handle a ShapeExpr whose dimensions are each either
    an ``IntImm`` (concrete integer constant) or a symbolic variable. Any other
    dimension (e.g. ``n + 1`` or ``4 * n * m``) found inside a binding block is
    rewritten as follows:

    1. A fresh symbolic variable named ``shape_expr_symbol_<k>`` is created
       with the dtype of the compound expression.
    2. A ``MatchCast`` is emitted that binds ``shape_expr_symbol_<k>_pv`` to a
       ``PrimValue`` holding the expression, which defines the new variable.
    3. The compound dimension is replaced by the new variable.

    Structurally equal expressions within one function share a single
    variable. ShapeExprs outside of binding blocks, such as those in function
    parameter annotations, are left unchanged.

    Example::

        # Before
        y = R.zeros(R.shape([n + 1]), dtype="float32")

        # After
        shape_expr_symbol_0_pv: R.Prim(value=shape_expr_symbol_0) = R.match_cast(
            R.prim_value(n + 1), R.Prim(value=shape_expr_symbol_0)
        )
        y = R.zeros(R.shape([shape_expr_symbol_0]), dtype="float32")

    Apply this pass before ComputePrimValue and before VMShapeLower.

    Returns
    -------
    ret : tvm.ir.transform.Pass
        The registered pass for canonicalizing ShapeExpr dimensions.
    """
    canonicalize_pass = _ffi_api.CanonicalizeShapeExpr()  # type: ignore
    return canonicalize_pass


def ExpandTupleArguments() -> tvm.ir.transform.Pass:
"""Expand tuple arguments to internal functions

Expand Down
129 changes: 129 additions & 0 deletions src/relax/transform/canonicalize_shape_expr.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relax/transform/canonicalize_shape_expr.cc
* \brief Canonicalize ShapeExpr by replacing composite PrimExpr dimensions with symbolic vars.
*/

#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include <string>
#include <unordered_map>

namespace tvm {
namespace relax {

namespace {

/*!
 * \brief Check whether a shape dimension can be left as-is by the canonicalizer.
 * \param expr The dimension expression to classify.
 * \return true if `expr` is an IntImm or a tirx::Var node, false for any other expression.
 */
bool IsSimpleShapeDim(const PrimExpr& expr) {
  if (expr->IsInstance<tirx::VarNode>()) {
    return true;
  }
  return expr->IsInstance<IntImmNode>();
}

class ShapeExprCanonicalizer : public ExprMutator {
public:
using ExprMutator::VisitExpr_;

/*!
 * \brief Visit a binding block while flagging that ShapeExprs may be rewritten.
 *
 * The flag is restored to its previous value afterwards so nested visits unwind correctly.
 *
 * \param block The binding block to visit.
 * \return The block produced by the base-class visitor.
 */
BindingBlock VisitBindingBlock(const BindingBlock& block) final {
  const bool was_inside_block = inside_binding_block_;
  inside_binding_block_ = true;
  BindingBlock visited = ExprMutator::VisitBindingBlock(block);
  inside_binding_block_ = was_inside_block;
  return visited;
}
Comment on lines +47 to +53

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If a local function (lambda) is defined inside a binding block, inside_binding_block_ will remain true when visiting the local function's parameters and return type. This can lead to incorrect canonicalization of the nested function's parameters in the outer function's scope, emitting MatchCast bindings into the wrong block. We should override VisitExpr_(const FunctionNode* op) to temporarily reset inside_binding_block_ to false when entering a local function.

  // Reviewer's suggestion (unchanged part): set inside_binding_block_ while the block is
  // visited by the base class, then restore the value it had on entry.
  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
    bool prev = inside_binding_block_;
    inside_binding_block_ = true;
    BindingBlock ret = ExprMutator::VisitBindingBlock(block);
    inside_binding_block_ = prev;
    return ret;
  }

  // Reviewer's suggestion (new part): clear inside_binding_block_ while a function is
  // visited so that a local function defined inside a binding block does not have its
  // parameters/return type treated as if they were inside the outer block.
  Expr VisitExpr_(const FunctionNode* op) final {
    bool prev = inside_binding_block_;
    inside_binding_block_ = false;
    Expr ret = ExprMutator::VisitExpr_(op);
    inside_binding_block_ = prev;
    return ret;
  }


/*!
 * \brief Visit a function with the binding-block flag cleared.
 *
 * Clearing the flag keeps ShapeExprs reached outside the function's own binding blocks
 * from being rewritten, even when the function is itself defined inside a binding block
 * of an enclosing function. The previous flag value is restored on exit.
 *
 * \param op The function node to visit.
 * \return The expression produced by the base-class visitor.
 */
Expr VisitExpr_(const FunctionNode* op) final {
  const bool outer_flag = inside_binding_block_;
  inside_binding_block_ = false;
  Expr visited = ExprMutator::VisitExpr_(op);
  inside_binding_block_ = outer_flag;
  return visited;
}

/*!
 * \brief Replace every compound dimension of a ShapeExpr with a symbolic variable.
 *
 * ShapeExprs reached while the visitor is not inside a binding block are returned
 * unchanged. Otherwise each dimension that is not an IntImm or tirx::Var is swapped for
 * the variable returned by GetOrCreateSymbol, in dimension order.
 *
 * \param op The ShapeExpr node to canonicalize.
 * \return The original ShapeExpr if nothing was replaced, else a new ShapeExpr with the
 *         same span and the rewritten dimensions.
 */
Expr VisitExpr_(const ShapeExprNode* op) final {
  if (!inside_binding_block_) {
    return ffi::GetRef<ShapeExpr>(op);
  }

  bool replaced_any = false;
  ffi::Array<PrimExpr> canonical_dims;
  for (const PrimExpr& dim : op->values) {
    if (!IsSimpleShapeDim(dim)) {
      // Compound dimension: bind it to a symbol (emits a MatchCast on first sight).
      replaced_any = true;
      canonical_dims.push_back(GetOrCreateSymbol(dim));
    } else {
      canonical_dims.push_back(dim);
    }
  }

  if (replaced_any) {
    return ShapeExpr(canonical_dims, op->span);
  }
  // Nothing changed, so hand back the original node rather than a copy.
  return ffi::GetRef<ShapeExpr>(op);
}

private:
/*!
 * \brief Return the symbolic variable standing in for a compound dimension expression.
 *
 * On the first request for an expression (compared structurally), a variable named
 * "shape_expr_symbol_<k>" with the expression's dtype is created and cached, and a
 * MatchCast binding "shape_expr_symbol_<k>_pv" to PrimValue(expr) is emitted through the
 * block builder. Later requests for a structurally equal expression return the cached
 * variable without emitting anything.
 *
 * \param expr The compound dimension expression.
 * \return The symbolic variable associated with `expr`.
 */
tirx::Var GetOrCreateSymbol(const PrimExpr& expr) {
  auto cached = expr_to_var_.find(expr);
  const bool already_known = cached != expr_to_var_.end();
  if (already_known) {
    return cached->second;
  }

  std::string symbol_name = "shape_expr_symbol_" + std::to_string(symbol_counter_++);
  tirx::Var symbol(symbol_name, expr->dtype);
  expr_to_var_.emplace(expr, symbol);

  // The MatchCast against Prim(value=symbol) is what introduces `symbol` into the IR.
  PrimStructInfo symbol_sinfo(symbol);
  Var binding_var(symbol_name + "_pv", symbol_sinfo);
  builder_->EmitNormalized(MatchCast(binding_var, PrimValue(expr), symbol_sinfo));
  return symbol;
}

// Next numeric suffix for generated "shape_expr_symbol_<k>" names.
int symbol_counter_ = 0;
// True only while a binding block is being visited; ShapeExprs seen while false are kept.
bool inside_binding_block_ = false;
// Compound dimension (keyed by structural hash/equality) -> symbol already created for it.
// NOTE(review): this cache is never cleared or scoped within one canonicalizer instance,
// so a symbol whose MatchCast was emitted in one block may be reused in a later block or
// nested function -- confirm the defining MatchCast is always in scope at such reuse sites.
std::unordered_map<PrimExpr, tirx::Var, ffi::StructuralHash, ffi::StructuralEqual> expr_to_var_;
Comment thread
cchung100m marked this conversation as resolved.
};
Comment thread
cchung100m marked this conversation as resolved.

} // namespace

namespace transform {

/*!
 * \brief Create the function pass that canonicalizes compound ShapeExpr dimensions.
 * \return A function-level pass named "CanonicalizeShapeExpr" with opt_level 0 and no
 *         required passes.
 */
Pass CanonicalizeShapeExpr() {
  // Each callback invocation uses its own canonicalizer, so the symbol counter and the
  // expression cache start fresh for every function that is processed.
  auto canonicalize_func = [](Function func, IRModule /*mod*/, PassContext /*ctx*/) {
    ShapeExprCanonicalizer canonicalizer;
    return Downcast<Function>(canonicalizer(std::move(func)));
  };
  return CreateFunctionPass(/*pass_function=*/canonicalize_func,
                            /*opt_level=*/0,
                            /*pass_name=*/"CanonicalizeShapeExpr",
                            /*required=*/{});
}

// Register the pass constructor as the global function
// "relax.transform.CanonicalizeShapeExpr" during static initialization.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("relax.transform.CanonicalizeShapeExpr",
                                        CanonicalizeShapeExpr);
}

} // namespace transform
} // namespace relax
} // namespace tvm
101 changes: 101 additions & 0 deletions tests/python/relax/test_transform_canonicalize_shape_expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm.testing
from tvm import relax, tirx
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T


# Input module for the tests below: `main` allocates a 1-D float32 tensor whose
# single dimension is the compound expression 4 * (x_0 * x_1 * x_2 * x_3),
# built from the four symbolic dimensions of the parameter `x`.
@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")):
        R.func_attr({"relax.force_pure": True})
        x_0, x_1, x_2, x_3 = T.int64(), T.int64(), T.int64(), T.int64()
        out: R.Tensor((T.int64(4) * (x_0 * x_1 * x_2 * x_3),), "float32") = R.zeros(
            R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)]), dtype="float32"
        )
        return out


def test_canonicalize_shape_expr_removes_composite_dims():
    """After the pass, every visited ShapeExpr dimension is an IntImm or a Var."""
    simple_dim_types = (tirx.IntImm, tirx.Var)
    offending_dims = []

    def _collect_compound_dims(node):
        # Only ShapeExpr nodes carry dimensions; ignore every other expression.
        if not isinstance(node, relax.ShapeExpr):
            return
        offending_dims.extend(
            dim for dim in node.values if not isinstance(dim, simple_dim_types)
        )

    canonicalized = relax.transform.CanonicalizeShapeExpr()(Before)
    relax.analysis.post_order_visit(canonicalized["main"], _collect_compound_dims)
    assert not offending_dims


def test_canonicalize_shape_expr_unblocks_vm_shape_lower():
    """CanonicalizeShapeExpr -> ComputePrimValue -> VMShapeLower runs on `Before`.

    The lowered module must contain a global whose name includes
    "compute_symbolic_expr".
    """
    lowered = Before
    for make_pass in (
        relax.transform.CanonicalizeShapeExpr,
        relax.transform.ComputePrimValue,
        relax.transform.VMShapeLower,
    ):
        lowered = make_pass()(lowered)

    global_names = [gv.name_hint for gv in lowered.get_global_vars()]
    assert any("compute_symbolic_expr" in name for name in global_names)


# Input module whose only compound dimension ("A + B") appears in the parameter
# annotation; the body itself uses the constant shape [1].
@I.ir_module
class ParamCompoundShape:
    @R.function
    def main(x: R.Tensor(("A", "B", "A + B"), "float32")) -> R.Tensor((1,), "float32"):
        out: R.Tensor((1,), "float32") = R.zeros(R.shape([1]), dtype="float32")
        return out


def test_canonicalize_shape_expr_skips_parameter_struct_info():
    """The compound dimension in the parameter's struct info survives the pass."""
    canonicalized = relax.transform.CanonicalizeShapeExpr()(ParamCompoundShape)
    first_param = canonicalized["main"].params[0]
    dims = first_param.struct_info.shape.values

    compound_dims = [dim for dim in dims if not isinstance(dim, (tirx.IntImm, tirx.Var))]
    assert compound_dims


# Input module with a local function `inner` defined inside `main`. The only
# compound dimension ("a + b") sits in `inner`'s parameter annotation; both
# function bodies use simple shapes only.
@I.ir_module
class NestedParamCompound:
    @R.function
    def main(x: R.Tensor(("n",), "float32")):
        @R.function
        def inner(y: R.Tensor(("a", "b", "a + b"), "float32")) -> R.Tensor((1,), "float32"):
            out: R.Tensor((1,), "float32") = R.zeros(R.shape([1]), dtype="float32")
            return out

        return inner


def test_canonicalize_shape_expr_nested_function_no_block_pollution():
    """No MatchCast may be added to `main`'s own blocks for the nested function."""
    canonicalized = relax.transform.CanonicalizeShapeExpr()(NestedParamCompound)

    leaked_match_casts = []
    for block in canonicalized["main"].body.blocks:
        for binding in block.bindings:
            if isinstance(binding, relax.MatchCast):
                leaked_match_casts.append(binding)

    assert not leaked_match_casts


# Allow running this test file directly as a script via tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
Comment on lines +74 to +101

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Add a test case to verify that nested functions do not have their parameters incorrectly canonicalized in the outer function's scope.

# Reviewer-suggested input: a local function with a compound parameter dimension
# ("n + 1") defined inside a dataflow block of `main`.
# NOTE(review): `n` is used as a Python name in R.shape([n + 1]) without a visible
# `n = T.int64()` declaration (compare `Before`) -- confirm this parses as written.
@I.ir_module
class NestedFunc:
    @R.function
    def main(x: R.Tensor(("n",), "float32")):
        with R.dataflow():
            @R.function
            def local_func(y: R.Tensor(("n + 1",), "float32")):
                out = R.zeros(R.shape([n + 1]), dtype="float32")
                return out
            res = local_func(x)
            R.output(res)
        return res


def test_canonicalize_shape_expr_nested_function():
    """The nested function's parameter must keep its compound dimension after the pass."""
    mod = relax.transform.CanonicalizeShapeExpr()(NestedFunc)
    # Assumes the local function is the value of the first binding in main's first
    # block -- TODO confirm this holds after the pass has run.
    local_func = mod["main"].body.blocks[0].bindings[0].value
    param_shape = local_func.params[0].struct_info.shape
    assert any(not isinstance(dim, (tirx.IntImm, tirx.Var)) for dim in param_shape.values)


# Allow running the suggested test file directly as a script via tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()

Loading