Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pytensor/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,14 @@ def elemwise_scalar_op_has_c_code(
# Already part of the subgraph
continue

if node_bitflag & all_subgraphs_bitset:
# Already part of another subgraph
if is_ancestor:
unfuseable_ancestors_bitset |= node_ancestors_bitset
else:
unfuseable_clients_bitset |= node_bitflag
continue

if is_ancestor:
if node_bitflag & unfuseable_ancestors_bitset:
# An unfuseable ancestor of the subgraph depends on this node, can't fuse
Expand Down Expand Up @@ -827,7 +835,8 @@ def elemwise_scalar_op_has_c_code(
ancestors_bitsets |= (
(node, node_ancestors_bitset | subgraph_and_ancestors)
for node, node_ancestors_bitset in ancestors_bitsets.items()
if node_ancestors_bitset & subgraph_bitset
if (node_ancestors_bitset & subgraph_bitset)
and not (nodes_bitflags[node] & subgraph_bitset)
)

# Collect the subgraph
Expand Down
43 changes: 43 additions & 0 deletions tests/tensor/rewriting/test_fusion_cycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.rewriting.elemwise import FusionOptimizer


def test_fusion_no_cycle_after_multi_output():
"""Verify that multi-output fusions do not cause cycles in later discoveries.

A bug in the ancestors_bitsets update logic for multi-output subgraphs
incorrectly marked ancestors within the subgraph as depending on their
own descendants.
"""
in_a = pt.vector("a")
a = pt.exp(in_a)
b = pt.exp(a)
c = pt.log(a)
d = pt.exp(b)
e = pt.exp(c)

fgraph = FunctionGraph([in_a], [d, e])
optimizer = FusionOptimizer()
# Should not raise ValueError: graph contains cycles
optimizer.apply(fgraph)


def test_fusion_cycle_diamond():
"""Test fusion in a diamond-like structure with mixed fuseability.

This structure can trigger cycles if subgraphs overlap or are non-convex.
"""
x = pt.matrix("x")
a = pt.exp(x)
# Path 1: fuseable
b = pt.exp(a)
# Path 2: unfuseable in the middle
c = a[0]
d = pt.exp(c)
# Sink
out = b + d

fgraph = FunctionGraph([x], [out])
# Should not raise ValueError: graph contains cycles
FusionOptimizer().apply(fgraph)
Loading