diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 6cdd98e11d..9e0f5fc54d 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -721,6 +721,14 @@ def elemwise_scalar_op_has_c_code( # Already part of the subgraph continue + if node_bitflag & all_subgraphs_bitset: + # Already part of another subgraph + if is_ancestor: + unfuseable_ancestors_bitset |= node_ancestors_bitset + else: + unfuseable_clients_bitset |= node_bitflag + continue + if is_ancestor: if node_bitflag & unfuseable_ancestors_bitset: # An unfuseable ancestor of the subgraph depends on this node, can't fuse @@ -827,7 +835,8 @@ def elemwise_scalar_op_has_c_code( ancestors_bitsets |= ( (node, node_ancestors_bitset | subgraph_and_ancestors) for node, node_ancestors_bitset in ancestors_bitsets.items() - if node_ancestors_bitset & subgraph_bitset + if (node_ancestors_bitset & subgraph_bitset) + and not (nodes_bitflags[node] & subgraph_bitset) ) # Collect the subgraph diff --git a/tests/tensor/rewriting/test_fusion_cycle.py b/tests/tensor/rewriting/test_fusion_cycle.py new file mode 100644 index 0000000000..dffa543cfd --- /dev/null +++ b/tests/tensor/rewriting/test_fusion_cycle.py @@ -0,0 +1,43 @@ +import pytensor.tensor as pt +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.rewriting.elemwise import FusionOptimizer + + +def test_fusion_no_cycle_after_multi_output(): + """Verify that multi-output fusions do not cause cycles in later discoveries. + + A bug in the ancestors_bitsets update logic for multi-output subgraphs + incorrectly marked ancestors within the subgraph as depending on their + own descendants. + """ + in_a = pt.vector("a") + a = pt.exp(in_a) + b = pt.exp(a) + c = pt.log(a) + d = pt.exp(b) + e = pt.exp(c) + + fgraph = FunctionGraph([in_a], [d, e]) + optimizer = FusionOptimizer() + # Should not raise ValueError: graph contains cycles + optimizer.apply(fgraph) + + +def test_fusion_cycle_diamond(): + """Test fusion in a diamond-like structure with mixed fuseability. + + This structure can trigger cycles if subgraphs overlap or are non-convex. + """ + x = pt.matrix("x") + a = pt.exp(x) + # Path 1: fuseable + b = pt.exp(a) + # Path 2: unfuseable in the middle + c = a[0] + d = pt.exp(c) + # Sink + out = b + d + + fgraph = FunctionGraph([x], [out]) + # Should not raise ValueError: graph contains cycles + FusionOptimizer().apply(fgraph)