diff --git a/pytensor/xtensor/rewriting/math.py b/pytensor/xtensor/rewriting/math.py index c767ec490e..bdbebb26e8 100644 --- a/pytensor/xtensor/rewriting/math.py +++ b/pytensor/xtensor/rewriting/math.py @@ -2,6 +2,8 @@ from pytensor.graph import node_rewriter from pytensor.tensor import einsum +from pytensor.tensor.einsum import Einsum +from pytensor.tensor.rewriting.ofg import inline_ofg_node from pytensor.tensor.shape import specify_shape from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.math import Dot @@ -41,6 +43,14 @@ def lower_dot(fgraph, node): # Perform the einsum operation out_tensor = einsum(einsum_str, x_tensor, y_tensor) + # Inline the Einsum OFG eagerly. `inline_optimized_einsum` only fires + # during `specialize`, but while the OFG is alive `ShapeFeature` calls + # `OpFromGraph.infer_shape` on every import, re-walking the inner graph + # each time. With many composed xtensor dots that dominates compile + # time. The 2-operand case has no path optimisation to defer. + if out_tensor.owner is not None and isinstance(out_tensor.owner.op, Einsum): + [out_tensor] = inline_ofg_node(out_tensor.owner) + # Reshape to match the output shape out_tensor = specify_shape(out_tensor, out.type.shape) diff --git a/tests/xtensor/test_math.py b/tests/xtensor/test_math.py index bb187a0fac..df69a7a22d 100644 --- a/tests/xtensor/test_math.py +++ b/tests/xtensor/test_math.py @@ -316,6 +316,30 @@ def test_dot(): xr_assert_allclose(z_test, expected) +def test_dot_lowering_inlines_einsum_ofg(): + """``lower_dot`` must inline the ``Einsum`` OFG that ``pt.einsum`` wraps. + + Leaving the OFG in place lets ``ShapeFeature.on_import`` call + ``OpFromGraph.infer_shape`` on every node import during canonicalize, + which re-walks the inner graph and dominates compile time once several + xtensor dots are composed (e.g. multi-layer attention). + """ + from pytensor.compile.builders import OpFromGraph + from pytensor.graph.rewriting.utils import rewrite_graph + from pytensor.graph.traversal import io_toposort + + x = xtensor("x", dims=("a", "b"), shape=(2, 3)) + y = xtensor("y", dims=("b", "c"), shape=(3, 4)) + z = x.dot(y) + + lowered = rewrite_graph(z.values, include=("lower_xtensor",)) + ofg_nodes = [n for n in io_toposort([], [lowered]) if isinstance(n.op, OpFromGraph)] + assert ofg_nodes == [], ( + "lower_dot should inline the Einsum OpFromGraph eagerly; got: " + f"{[type(n.op).__name__ for n in ofg_nodes]}" + ) + + def test_dot_errors(): # No matching dimensions x = xtensor("x", dims=("a", "b"), shape=(2, 3)) @@ -339,10 +363,15 @@ def test_dot_errors(): fn = xr_function([x, y], z) x_test = DataArray(np.ones((2, 3)), dims=("a", "b")) y_test = DataArray(np.ones((4, 5)), dims=("b", "c")) - # Doesn't fail until the rewrite + # Doesn't fail until the rewrite. The exact message depends on which op + # raises (np.einsum vs np.dot vs the inlined Dot's runtime shape check). with pytest.raises( ValueError, - match=r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)", + match=( + r"Input operand 1 has a mismatch in its core dimension 0" + r"|incompatible array sizes for np.dot" + r"|Shape mismatch: x has" + ), ): fn(x_test, y_test)