[Relax][Backend] Fix TVM crashes with default relax pipeline when opt_level=1: InternalError: Check failed: (slot->value_computed) is false#18491
Conversation
1058392 to
95466f1
Compare
| if (it != slot_map_.end() && !it->second->value_computed) { | ||
| // If it's a variable, mark it as ready for computation | ||
| if (expr.as<tir::VarNode>()) { | ||
| it->second->value_computed = true; |
There was a problem hiding this comment.
Marking arbitrary tir::VarNode instances as value_computed = true in VisitExpr_ is incorrect. In VMShapeLower, only Relax ShapeVars (from function parameters/match_cast) or IntImm constants can be safely marked as computed, because only these have runtime values accessible to the VM.
I think correct way is to fix this earlier in the pipeline (shape inference/canonicalization) by symbolizing composite PrimExpr (introduce Relax ShapeVars) or simplifying them before VMShapeLower.
There was a problem hiding this comment.
Hi @tlopex
Thanks for the suggestion.
Should we extend ComputePrimValue() to handle ShapeExpr nodes with compound PrimExpr?
https://github.com/apache/tvm/blob/main/src/relax/transform/compute_prim_value.cc#L65
There was a problem hiding this comment.
Well, ComputePrimValue is intended only for evaluating statically evaluable PrimExpr into IntImm (constant folding), so I think extending ComputePrimValue would not address the root issue.
The real problem is that VMShapeLower cannot consume composite PrimExpr directly. The correct solution here should be to canonicalize ShapeExpr earlier by introducing a Relax ShapeVar binding for any non-trivial PrimExpr.
Just like:
# 1. Compute the symbolic value first (Canonicalization)
s1 = R.prim_value(n + 1)
# 2. Pass the computed var to the shape (VMShapeLower is happy now)
lv = R.call_tir(cls.func, (x,), R.shape([s1]), dtype="float32")
533de8a to
b98b5a7
Compare
a59ed45 to
b3f7800
Compare
8cbe397 to
4089fec
Compare
4100b4c to
bbd6752
Compare
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a new Relax pass, CanonicalizeShapeExpr, which canonicalizes compound shape expressions by replacing them with fresh symbolic variables, and integrates it into several backend pipelines. The review feedback highlights a potential crash when canonicalizing function parameters due to emitting bindings outside of an active block, a Python compatibility issue in the tests when using the union operator in isinstance, and a minor typo in the Python docstring.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
…vel=1: InternalError: Check failed: (slot->value_computed) is false
9d4e9c3 to
95a5610
Compare
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a new pass, CanonicalizeShapeExpr, which canonicalizes compound shape expressions (e.g., n + 1) by replacing them with fresh symbolic variables and emitting MatchCast bindings. This pass is integrated into the finalize pipelines for CPU, CUDA, GPU, and ROCm backends, and includes Python bindings and unit tests. The review feedback points out a critical issue where the function's return struct info (op->ret_struct_info) is not canonicalized, which can cause downstream passes to crash. Additionally, the reviewer notes that calling builder_->EmitNormalized outside of an active binding block (such as when visiting the return struct info) will crash the compiler, and suggests tracking whether the visitor is inside a binding block to safely collect and prepend out-of-block bindings.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces the CanonicalizeShapeExpr pass in TVM Relax, which replaces compound PrimExpr dimensions in ShapeExpr with fresh symbolic variables to prevent issues during downstream shape lowering. The pass is integrated into the CPU, CUDA, GPU, and ROCm compilation pipelines. The review feedback highlights a potential bug where nested local functions might have their parameters incorrectly canonicalized within the outer scope, suggesting a fix to reset the state when entering a FunctionNode. Additionally, the feedback points out an unused MatchCast binding variable and recommends adding a test case to verify the behavior of nested functions.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| BindingBlock VisitBindingBlock(const BindingBlock& block) final { | ||
| bool prev = inside_binding_block_; | ||
| inside_binding_block_ = true; | ||
| BindingBlock ret = ExprMutator::VisitBindingBlock(block); | ||
| inside_binding_block_ = prev; | ||
| return ret; | ||
| } |
There was a problem hiding this comment.
If a local function (lambda) is defined inside a binding block, inside_binding_block_ will remain true when visiting the local function's parameters and return type. This can lead to incorrect canonicalization of the nested function's parameters in the outer function's scope, emitting MatchCast bindings into the wrong block. We should override VisitExpr_(const FunctionNode* op) to temporarily reset inside_binding_block_ to false when entering a local function.
BindingBlock VisitBindingBlock(const BindingBlock& block) final {
bool prev = inside_binding_block_;
inside_binding_block_ = true;
BindingBlock ret = ExprMutator::VisitBindingBlock(block);
inside_binding_block_ = prev;
return ret;
}
Expr VisitExpr_(const FunctionNode* op) final {
bool prev = inside_binding_block_;
inside_binding_block_ = false;
Expr ret = ExprMutator::VisitExpr_(op);
inside_binding_block_ = prev;
return ret;
}| builder_->EmitNormalized(MatchCast(match_var, PrimValue(expr), target_sinfo)); | ||
| MatchCast binding(match_var, PrimValue(expr), target_sinfo); |
There was a problem hiding this comment.
The local variable binding of type MatchCast is constructed but never used. Since builder_->EmitNormalized already creates and emits the MatchCast binding, this line is completely redundant and can be safely removed.
| builder_->EmitNormalized(MatchCast(match_var, PrimValue(expr), target_sinfo)); | |
| MatchCast binding(match_var, PrimValue(expr), target_sinfo); | |
| builder_->EmitNormalized(MatchCast(match_var, PrimValue(expr), target_sinfo)); |
|
|
||
| if __name__ == "__main__": | ||
| tvm.testing.main() |
There was a problem hiding this comment.
Add a test case to verify that nested functions do not have their parameters incorrectly canonicalized in the outer function's scope.
@I.ir_module
class NestedFunc:
@R.function
def main(x: R.Tensor(("n",), "float32")):
with R.dataflow():
@R.function
def local_func(y: R.Tensor(("n + 1",), "float32")):
out = R.zeros(R.shape([n + 1]), dtype="float32")
return out
res = local_func(x)
R.output(res)
return res
def test_canonicalize_shape_expr_nested_function():
mod = relax.transform.CanonicalizeShapeExpr()(NestedFunc)
local_func = mod["main"].body.blocks[0].bindings[0].value
param_shape = local_func.params[0].struct_info.shape
assert any(not isinstance(dim, (tirx.IntImm, tirx.Var)) for dim in param_shape.values)
if __name__ == "__main__":
tvm.testing.main()
Hi Commiters,
This PR is trying to fix issues #17876. Any suggestions would be appreciated if you are available.
Root Cause
VMShapeLowercrashed when processingShapeExprcontaining compositePrimExprthat weren't computed yet.Solution
ModifiedVisitExpr_(const ShapeExprNode* op)invm_shape_lower.ccto:1. Mark uncomputed variables as ready for computation2. Trigger EmitOutstandingPrimExprCompute() to resolve the dependency chain3. Ensure all expressions are computed before callingMakeSymbolicShapeArgAdded test case:test_composite_shape_expression_fix()to prevent future occurrences.symbolizing the composite PrimExpr in the pipeline: canonicalization