-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Relax][Backend] Fix TVM crashes with default relax pipeline when opt_level=1: InternalError: Check failed: (slot->value_computed) is false #18491
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
cchung100m
wants to merge
13
commits into
apache:main
Choose a base branch
from
cchung100m:issue-17876
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
af88d40
[#17876]Fix TVM crashes with default relax pipeline when opt_level=1:…
cchung100m a6487f6
Add test case: test_composite_shape_expression
cchung100m 10094bb
Revert the incorrect solution
cchung100m 458e9ba
new pass: ShapeExprCanonicalizer
cchung100m 89a5e0a
Add unit test case for new pass ShapeExprCanonicalizer
cchung100m 3eb8b5f
Refactor: CanonicalizeShapeExpr
cchung100m 865bb7b
Refactor the canonicalize_shape_expr.cc
cchung100m c51a135
Fix lint error UP038
cchung100m 599d593
Fix include error
cchung100m e17ecef
Fix GetOrCreateSymbol crash
cchung100m a31b7c2
Remove the unnecessary handling
cchung100m b997045
Fix builder_->EmitNormalized outside an active binding block would crash
cchung100m 9d421ed
Add test case: test_canonicalize_shape_expr_nested_function_no_block_…
cchung100m File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| /*! | ||
| * \file src/relax/transform/canonicalize_shape_expr.cc | ||
| * \brief Canonicalize ShapeExpr by replacing composite PrimExpr dimensions with symbolic vars. | ||
| */ | ||
|
|
||
| #include <tvm/ffi/reflection/registry.h> | ||
| #include <tvm/relax/analysis.h> | ||
| #include <tvm/relax/expr.h> | ||
| #include <tvm/relax/expr_functor.h> | ||
| #include <tvm/relax/transform.h> | ||
|
|
||
| #include <string> | ||
| #include <unordered_map> | ||
|
|
||
| namespace tvm { | ||
| namespace relax { | ||
|
|
||
| namespace { | ||
|
|
||
| bool IsSimpleShapeDim(const PrimExpr& expr) { | ||
| return expr->IsInstance<IntImmNode>() || expr->IsInstance<tirx::VarNode>(); | ||
| } | ||
|
|
||
| class ShapeExprCanonicalizer : public ExprMutator { | ||
| public: | ||
| using ExprMutator::VisitExpr_; | ||
|
|
||
| BindingBlock VisitBindingBlock(const BindingBlock& block) final { | ||
| bool prev = inside_binding_block_; | ||
| inside_binding_block_ = true; | ||
| BindingBlock ret = ExprMutator::VisitBindingBlock(block); | ||
| inside_binding_block_ = prev; | ||
| return ret; | ||
| } | ||
|
|
||
| Expr VisitExpr_(const FunctionNode* op) final { | ||
| bool prev = inside_binding_block_; | ||
| inside_binding_block_ = false; | ||
| Expr ret = ExprMutator::VisitExpr_(op); | ||
| inside_binding_block_ = prev; | ||
| return ret; | ||
| } | ||
|
|
||
| Expr VisitExpr_(const ShapeExprNode* op) final { | ||
| if (!inside_binding_block_) { | ||
| return ffi::GetRef<ShapeExpr>(op); | ||
| } | ||
|
|
||
| ffi::Array<PrimExpr> new_values; | ||
| bool changed = false; | ||
| for (const PrimExpr& dim : op->values) { | ||
| if (IsSimpleShapeDim(dim)) { | ||
| new_values.push_back(dim); | ||
| continue; | ||
| } | ||
|
|
||
| changed = true; | ||
| new_values.push_back(GetOrCreateSymbol(dim)); | ||
| } | ||
|
|
||
| if (!changed) { | ||
| return ffi::GetRef<ShapeExpr>(op); | ||
| } | ||
| return ShapeExpr(new_values, op->span); | ||
| } | ||
|
|
||
| private: | ||
| tirx::Var GetOrCreateSymbol(const PrimExpr& expr) { | ||
| auto it = expr_to_var_.find(expr); | ||
| if (it != expr_to_var_.end()) { | ||
| return it->second; | ||
| } | ||
|
|
||
| std::string base_name = "shape_expr_symbol_" + std::to_string(symbol_counter_++); | ||
| tirx::Var sym_var(base_name, expr->dtype); | ||
| expr_to_var_.emplace(expr, sym_var); | ||
|
|
||
| PrimStructInfo target_sinfo(sym_var); | ||
| Var match_var(base_name + "_pv", target_sinfo); | ||
| builder_->EmitNormalized(MatchCast(match_var, PrimValue(expr), target_sinfo)); | ||
| return sym_var; | ||
| } | ||
|
|
||
| int symbol_counter_ = 0; | ||
| bool inside_binding_block_ = false; | ||
| std::unordered_map<PrimExpr, tirx::Var, ffi::StructuralHash, ffi::StructuralEqual> expr_to_var_; | ||
|
cchung100m marked this conversation as resolved.
|
||
| }; | ||
|
cchung100m marked this conversation as resolved.
|
||
|
|
||
| } // namespace | ||
|
|
||
| namespace transform { | ||
|
|
||
| Pass CanonicalizeShapeExpr() { | ||
| auto pass_func = [](Function f, IRModule m, PassContext pc) { | ||
| return Downcast<Function>(ShapeExprCanonicalizer()(std::move(f))); | ||
| }; | ||
| return CreateFunctionPass(/*pass_function=*/pass_func, | ||
| /*opt_level=*/0, | ||
| /*pass_name=*/"CanonicalizeShapeExpr", | ||
| /*required=*/{}); | ||
| } | ||
|
|
||
| TVM_FFI_STATIC_INIT_BLOCK() { | ||
| namespace refl = tvm::ffi::reflection; | ||
| refl::GlobalDef().def("relax.transform.CanonicalizeShapeExpr", CanonicalizeShapeExpr); | ||
| } | ||
|
|
||
| } // namespace transform | ||
| } // namespace relax | ||
| } // namespace tvm | ||
101 changes: 101 additions & 0 deletions
101
tests/python/relax/test_transform_canonicalize_shape_expr.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| import tvm.testing | ||
| from tvm import relax, tirx | ||
| from tvm.script import ir as I | ||
| from tvm.script import relax as R | ||
| from tvm.script import tirx as T | ||
|
|
||
|
|
||
| @I.ir_module | ||
| class Before: | ||
| @R.function | ||
| def main(x: R.Tensor(("x_0", "x_1", "x_2", "x_3"), "float32")): | ||
| R.func_attr({"relax.force_pure": True}) | ||
| x_0, x_1, x_2, x_3 = T.int64(), T.int64(), T.int64(), T.int64() | ||
| out: R.Tensor((T.int64(4) * (x_0 * x_1 * x_2 * x_3),), "float32") = R.zeros( | ||
| R.shape([T.int64(4) * (x_0 * x_1 * x_2 * x_3)]), dtype="float32" | ||
| ) | ||
| return out | ||
|
|
||
|
|
||
| def test_canonicalize_shape_expr_removes_composite_dims(): | ||
| mod = relax.transform.CanonicalizeShapeExpr()(Before) | ||
| composite_dims = [] | ||
|
|
||
| def _visit(expr): | ||
| if isinstance(expr, relax.ShapeExpr): | ||
| for dim in expr.values: | ||
| if not isinstance(dim, (tirx.IntImm, tirx.Var)): | ||
| composite_dims.append(dim) | ||
|
|
||
| relax.analysis.post_order_visit(mod["main"], _visit) | ||
| assert not composite_dims | ||
|
|
||
|
|
||
| def test_canonicalize_shape_expr_unblocks_vm_shape_lower(): | ||
| mod = Before | ||
| mod = relax.transform.CanonicalizeShapeExpr()(mod) | ||
| mod = relax.transform.ComputePrimValue()(mod) | ||
| mod = relax.transform.VMShapeLower()(mod) | ||
|
|
||
| assert any("compute_symbolic_expr" in gv.name_hint for gv in mod.get_global_vars()) | ||
|
|
||
|
|
||
| @I.ir_module | ||
| class ParamCompoundShape: | ||
| @R.function | ||
| def main(x: R.Tensor(("A", "B", "A + B"), "float32")) -> R.Tensor((1,), "float32"): | ||
| out: R.Tensor((1,), "float32") = R.zeros(R.shape([1]), dtype="float32") | ||
| return out | ||
|
|
||
|
|
||
| def test_canonicalize_shape_expr_skips_parameter_struct_info(): | ||
| mod = relax.transform.CanonicalizeShapeExpr()(ParamCompoundShape) | ||
| param_shape = mod["main"].params[0].struct_info.shape | ||
|
|
||
| assert any(not isinstance(dim, (tirx.IntImm, tirx.Var)) for dim in param_shape.values) | ||
|
|
||
|
|
||
| @I.ir_module | ||
| class NestedParamCompound: | ||
| @R.function | ||
| def main(x: R.Tensor(("n",), "float32")): | ||
| @R.function | ||
| def inner(y: R.Tensor(("a", "b", "a + b"), "float32")) -> R.Tensor((1,), "float32"): | ||
| out: R.Tensor((1,), "float32") = R.zeros(R.shape([1]), dtype="float32") | ||
| return out | ||
|
|
||
| return inner | ||
|
|
||
|
|
||
| def test_canonicalize_shape_expr_nested_function_no_block_pollution(): | ||
| mod = relax.transform.CanonicalizeShapeExpr()(NestedParamCompound) | ||
|
|
||
| leaked = [ | ||
| binding | ||
| for block in mod["main"].body.blocks | ||
| for binding in block.bindings | ||
| if isinstance(binding, relax.MatchCast) | ||
| ] | ||
|
|
||
| assert not leaked | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tvm.testing.main() | ||
|
Comment on lines
+74
to
+101
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a test case to verify that nested functions do not have their parameters incorrectly canonicalized in the outer function's scope. @I.ir_module
class NestedFunc:
@R.function
def main(x: R.Tensor(("n",), "float32")):
with R.dataflow():
@R.function
def local_func(y: R.Tensor(("n + 1",), "float32")):
out = R.zeros(R.shape([n + 1]), dtype="float32")
return out
res = local_func(x)
R.output(res)
return res
def test_canonicalize_shape_expr_nested_function():
mod = relax.transform.CanonicalizeShapeExpr()(NestedFunc)
local_func = mod["main"].body.blocks[0].bindings[0].value
param_shape = local_func.params[0].struct_info.shape
assert any(not isinstance(dim, (tirx.IntImm, tirx.Var)) for dim in param_shape.values)
if __name__ == "__main__":
tvm.testing.main() |
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If a local function (lambda) is defined inside a binding block,
inside_binding_block_will remaintruewhen visiting the local function's parameters and return type. This can lead to incorrect canonicalization of the nested function's parameters in the outer function's scope, emittingMatchCastbindings into the wrong block. We should overrideVisitExpr_(const FunctionNode* op)to temporarily resetinside_binding_block_tofalsewhen entering a local function.