diff --git a/python/tvm/relax/backend/cpu_generic/pipeline.py b/python/tvm/relax/backend/cpu_generic/pipeline.py index d0b819cea7f8..1a2bd5446f26 100644 --- a/python/tvm/relax/backend/cpu_generic/pipeline.py +++ b/python/tvm/relax/backend/cpu_generic/pipeline.py @@ -56,6 +56,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/cuda/pipeline.py b/python/tvm/relax/backend/cuda/pipeline.py index e3d8b66c1271..f82ac7173a24 100644 --- a/python/tvm/relax/backend/cuda/pipeline.py +++ b/python/tvm/relax/backend/cuda/pipeline.py @@ -66,6 +66,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/gpu_generic/pipeline.py b/python/tvm/relax/backend/gpu_generic/pipeline.py index a86ec6f31109..32ebd28400e4 100644 --- a/python/tvm/relax/backend/gpu_generic/pipeline.py +++ b/python/tvm/relax/backend/gpu_generic/pipeline.py @@ -65,6 +65,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/backend/rocm/pipeline.py b/python/tvm/relax/backend/rocm/pipeline.py index a57172829a91..64dae9f78f57 100644 --- a/python/tvm/relax/backend/rocm/pipeline.py +++ b/python/tvm/relax/backend/rocm/pipeline.py @@ -65,6 +65,7 @@ def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume relax.transform.LowerAllocTensor(), relax.transform.KillAfterLastUse(), relax.transform.LowerRuntimeBuiltin(), + relax.transform.CanonicalizeShapeExpr(), relax.transform.ComputePrimValue(), relax.transform.VMShapeLower(), relax.transform.AttachGlobalSymbol(), diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index c3188adf5027..bf42d371b66c 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -29,6 +29,7 @@ BundleModelParams, CallTIRRewrite, CanonicalizeBindings, + CanonicalizeShapeExpr, CombineParallelMatmul, ComputePrimValue, ConvertLayout, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index a291fb973730..49bad34a10e6 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -739,6 +739,32 @@ def FoldConstant() -> tvm.ir.transform.Pass: return _ffi_api.FoldConstant() # type: ignore +def CanonicalizeShapeExpr() -> tvm.ir.transform.Pass: + """Canonicalize ShapeExpr by replacing compound PrimExpr with fresh symbolic variables. + + VMShapeLower can only handle ShapeExpr where each dimension is either: + - IntImm (concrete integer constant) + - tir::Var (symbolic variable from function parameters or match_cast) + + This pass transforms compound PrimExpr (e.g., n+1, 4*n*m) by: + 1. Creating a fresh tir::Var for each compound expression + 2. Emitting a MatchCast that binds the fresh var to a PrimValue computing the expression + 3. Replacing the compound expression in ShapeExpr with the fresh var + + Example transformation: + Before: y = R.zeros(R.shape([n + 1]), dtype="float32") + After: _s0_pv: R.Prim(value=_s0) = R.match_cast(R.prim_value(n+1), R.Prim(value=_s0)) + y = R.zeros(R.shape([_s0]), dtype="float32") + + This pass should be applied before ComputePrimValue and before VMShapeLower. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CanonicalizeShapeExpr() # type: ignore + + def ExpandTupleArguments() -> tvm.ir.transform.Pass: """Expand tuple arguments to internal functions diff --git a/src/relax/transform/canonicalize_shape_expr.cc b/src/relax/transform/canonicalize_shape_expr.cc new file mode 100644 index 000000000000..10186e9c78af --- /dev/null +++ b/src/relax/transform/canonicalize_shape_expr.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/transform/canonicalize_shape_expr.cc + * \brief Canonicalize ShapeExpr by replacing composite PrimExpr dimensions with symbolic vars. + */ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +namespace { + +bool IsSimpleShapeDim(const PrimExpr& expr) { + return expr->IsInstance() || expr->IsInstance(); +} + +class ShapeExprCanonicalizer : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + bool prev = inside_binding_block_; + inside_binding_block_ = true; + BindingBlock ret = ExprMutator::VisitBindingBlock(block); + inside_binding_block_ = prev; + return ret; + } + + Expr VisitExpr_(const FunctionNode* op) final { + bool prev = inside_binding_block_; + inside_binding_block_ = false; + Expr ret = ExprMutator::VisitExpr_(op); + inside_binding_block_ = prev; + return ret; + } + + Expr VisitExpr_(const ShapeExprNode* op) final { + if (!inside_binding_block_) { + return ffi::GetRef(op); + } + + ffi::Array new_values; + bool changed = false; + for (const PrimExpr& dim : op->values) { + if (IsSimpleShapeDim(dim)) { + new_values.push_back(dim); + continue; + } + + changed = true; + new_values.push_back(GetOrCreateSymbol(dim)); + } + + if (!changed) { + return ffi::GetRef(op); + } + return ShapeExpr(new_values, op->span); + } + + private: + tirx::Var GetOrCreateSymbol(const PrimExpr& expr) { + auto it = expr_to_var_.find(expr); + if (it != expr_to_var_.end()) { + return it->second; + } + + std::string base_name = "shape_expr_symbol_" + std::to_string(symbol_counter_++); + tirx::Var sym_var(base_name, expr->dtype); + expr_to_var_.emplace(expr, sym_var); + + PrimStructInfo target_sinfo(sym_var); + Var match_var(base_name + "_pv", target_sinfo); + builder_->EmitNormalized(MatchCast(match_var, PrimValue(expr), target_sinfo)); + return sym_var; + } + + int symbol_counter_ = 0; + bool inside_binding_block_ = false; + std::unordered_map expr_to_var_; +}; + +} // namespace + +namespace transform { + +Pass CanonicalizeShapeExpr() { + auto pass_func = [](Function f, IRModule m, PassContext pc) { + return Downcast(ShapeExprCanonicalizer()(std::move(f))); + }; + return CreateFunctionPass(/*pass_function=*/pass_func, + /*opt_level=*/0, + /*pass_name=*/"CanonicalizeShapeExpr", + /*required=*/{}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.transform.CanonicalizeShapeExpr", CanonicalizeShapeExpr); +} + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_canonicalize_shape_expr.py b/tests/python/relax/test_transform_canonicalize_shape_expr.py new file mode 100644 index 000000000000..8ecfee954059 --- /dev/null +++ b/tests/python/relax/test_transform_canonicalize_shape_expr.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm.testing +from tvm import relax, tirx +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tirx as T + + +@I.ir_module +class Before: + @R.function + def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")): + R.func_attr({"relax.force_pure": True}) + x_0, x_1, x_2, x_3 = T.int64(), T.int64(), T.int64(), T.int64() + out: R.Tensor((T.int64(4) * (x_0 * x_1 * x_2 * x_3),), "float32") = R.zeros( + R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)]), dtype="float32" + ) + return out + + +def test_canonicalize_shape_expr_removes_composite_dims(): + mod = relax.transform.CanonicalizeShapeExpr()(Before) + composite_dims = [] + + def _visit(expr): + if isinstance(expr, relax.ShapeExpr): + for dim in expr.values: + if not isinstance(dim, (tirx.IntImm, tirx.Var)): + composite_dims.append(dim) + + relax.analysis.post_order_visit(mod["main"], _visit) + assert not composite_dims + + +def test_canonicalize_shape_expr_unblocks_vm_shape_lower(): + mod = Before + mod = relax.transform.CanonicalizeShapeExpr()(mod) + mod = relax.transform.ComputePrimValue()(mod) + mod = relax.transform.VMShapeLower()(mod) + + assert any("compute_symbolic_expr" in gv.name_hint for gv in mod.get_global_vars()) + + +@I.ir_module +class ParamCompoundShape: + @R.function + def main(x: R.Tensor(("A", "B", "A + B"), "float32")) -> R.Tensor((1,), "float32"): + out: R.Tensor((1,), "float32") = R.zeros(R.shape([1]), dtype="float32") + return out + + +def test_canonicalize_shape_expr_skips_parameter_struct_info(): + mod = relax.transform.CanonicalizeShapeExpr()(ParamCompoundShape) + param_shape = mod["main"].params[0].struct_info.shape + + assert any(not isinstance(dim, (tirx.IntImm, tirx.Var)) for dim in param_shape.values) + + +@I.ir_module +class NestedParamCompound: + @R.function + def main(x: R.Tensor(("n",), "float32")): + @R.function + def inner(y: R.Tensor(("a", "b", "a + b"), "float32")) -> R.Tensor((1,), "float32"): + out: R.Tensor((1,), "float32") = R.zeros(R.shape([1]), dtype="float32") + return out + + return inner + + +def test_canonicalize_shape_expr_nested_function_no_block_pollution(): + mod = relax.transform.CanonicalizeShapeExpr()(NestedParamCompound) + + leaked = [ + binding + for block in mod["main"].body.blocks + for binding in block.bindings + if isinstance(binding, relax.MatchCast) + ] + + assert not leaked + + +if __name__ == "__main__": + tvm.testing.main()