Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pytensor/xtensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pytensor.graph import node_rewriter
from pytensor.tensor import einsum
from pytensor.tensor.einsum import Einsum
from pytensor.tensor.rewriting.ofg import inline_ofg_node
from pytensor.tensor.shape import specify_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import Dot
Expand Down Expand Up @@ -41,6 +43,14 @@ def lower_dot(fgraph, node):
# Perform the einsum operation
out_tensor = einsum(einsum_str, x_tensor, y_tensor)

# Inline the Einsum OFG eagerly. `inline_optimized_einsum` only fires
# during `specialize`, but while the OFG is alive `ShapeFeature` calls
# `OpFromGraph.infer_shape` on every import, re-walking the inner graph
# each time. With many composed xtensor dots that dominates compile
# time. The 2-operand case has no path optimisation to defer.
if out_tensor.owner is not None and isinstance(out_tensor.owner.op, Einsum):
[out_tensor] = inline_ofg_node(out_tensor.owner)

# Reshape to match the output shape
out_tensor = specify_shape(out_tensor, out.type.shape)

Expand Down
33 changes: 31 additions & 2 deletions tests/xtensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,30 @@ def test_dot():
xr_assert_allclose(z_test, expected)


def test_dot_lowering_inlines_einsum_ofg():
"""``lower_dot`` must inline the ``Einsum`` OFG that ``pt.einsum`` wraps.

Leaving the OFG in place lets ``ShapeFeature.on_import`` call
``OpFromGraph.infer_shape`` on every node import during canonicalize,
which re-walks the inner graph and dominates compile time once several
xtensor dots are composed (e.g. multi-layer attention).
"""
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.traversal import io_toposort

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
z = x.dot(y)

lowered = rewrite_graph(z.values, include=("lower_xtensor",))
ofg_nodes = [n for n in io_toposort([], [lowered]) if isinstance(n.op, OpFromGraph)]
assert ofg_nodes == [], (
"lower_dot should inline the Einsum OpFromGraph eagerly; got: "
f"{[type(n.op).__name__ for n in ofg_nodes]}"
)


def test_dot_errors():
# No matching dimensions
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
Expand All @@ -339,10 +363,15 @@ def test_dot_errors():
fn = xr_function([x, y], z)
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))
y_test = DataArray(np.ones((4, 5)), dims=("b", "c"))
# Doesn't fail until the rewrite
# Doesn't fail until the rewrite. The exact message depends on which op
# raises (np.einsum vs np.dot vs the inlined Dot's runtime shape check).
with pytest.raises(
ValueError,
match=r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)",
match=(
r"Input operand 1 has a mismatch in its core dimension 0"
r"|incompatible array sizes for np.dot"
r"|Shape mismatch: x has"
),
):
fn(x_test, y_test)

Expand Down
Loading