Closed
18 changes: 18 additions & 0 deletions pytensor/tensor/rewriting/linalg/solvers.py
@@ -17,6 +17,7 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky
from pytensor.tensor.linalg.decomposition.lu import lu_factor
from pytensor.tensor.linalg.inverse import MatrixInverse
from pytensor.tensor.linalg.solvers.core import SolveBase
from pytensor.tensor.linalg.solvers.general import Solve, lu_solve
from pytensor.tensor.linalg.solvers.linear_control import (
@@ -162,6 +163,23 @@ def scalar_solve_to_division(fgraph, node):
return [new_out]


@register_stabilize
@node_rewriter([blockwise_of(SolveBase)])
def solve_of_inv_to_matmul(fgraph, node):
    """Replace solve(matrix_inverse(X), b) with X @ b.

    If A = inv(X), then solve(A, b) finds x such that A @ x = b,
    i.e., inv(X) @ x = b, so x = X @ b.
    """
    A, b = node.inputs

    match A.owner_op_and_inputs:
        case (Blockwise(MatrixInverse()), X):
            new_out = X @ b
            copy_stack_trace(node.outputs[0], new_out)
            return [new_out]


def decompose_A(A, assume_a, lower):
    if assume_a == "gen":
        return lu_factor(A)
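
The identity stated in the docstring of solve_of_inv_to_matmul can be checked numerically outside PyTensor. The following is a minimal NumPy sketch and is not part of the PR: the helper name solve_of_inv and the test shapes are illustrative assumptions. It covers a vector b, a matrix b, and a stacked X, which is roughly the batched situation Blockwise deals with.

```python
import numpy as np


def solve_of_inv(X, b):
    """Hypothetical helper: the unrewritten expression solve(inv(X), b)."""
    return np.linalg.solve(np.linalg.inv(X), b)


rng = np.random.default_rng(0)

# A diagonal shift keeps X comfortably invertible.
X = rng.random((4, 4)) + 4 * np.eye(4)
b_vec = rng.random(4)        # corresponds to b_ndim == 1
b_mat = rng.random((4, 3))   # corresponds to b_ndim == 2

# solve(inv(X), b) and X @ b agree for both kinds of right-hand side.
np.testing.assert_allclose(solve_of_inv(X, b_vec), X @ b_vec)
np.testing.assert_allclose(solve_of_inv(X, b_mat), X @ b_mat)

# Stacked (batched) matrices: the identity holds per batch entry.
X_batch = rng.random((2, 4, 4)) + 4 * np.eye(4)
b_batch = rng.random((2, 4, 3))
np.testing.assert_allclose(solve_of_inv(X_batch, b_batch), X_batch @ b_batch)
```

The rewritten form skips both the explicit inverse and the solve, leaving a single matrix product.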
30 changes: 30 additions & 0 deletions tests/tensor/rewriting/linalg/test_solvers.py
@@ -9,6 +9,8 @@
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph import ancestors
from pytensor.graph.rewriting.basic import WalkingGraphRewriter
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.scan.op import Scan
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky
@@ -24,8 +26,10 @@
from pytensor.tensor.rewriting.linalg.solvers import (
    reuse_decomposition_multiple_solves,
    scan_split_non_sequence_decomposition_and_solve,
    solve_of_inv_to_matmul,
)
from pytensor.tensor.type import matrix, tensor
from tests.unittest_tools import assert_equal_computations


def test_generic_solve_to_solve_triangular():
@@ -439,3 +443,29 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-4
np.testing.assert_allclose(resx0, resx1, rtol=rtol)


@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
    X = pt.dmatrix("X")
    b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")
    out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)

    # Just include the rewrite we are testing
Member

This comment is not true: rewrite_graph includes canonicalize by default.

    rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul)
    rewritten_out = rewrite_graph(out, custom_rewrite=rewriter)
Comment on lines +455 to +456
Member

I don't love the import + custom_rewrite. The pattern I had in mind was to just use rewrite_graph(out, include=('your_rewrite_name', )). If that doesn't work I'd rather it be reverted to what you had before for simplicity. But also fine with it staying this way if you don't want to keep going back and forth.

Member

Again, I think this PR is a great case for the discussion in #2103, which is still open-ended and looking for feedback. That way we can standardize how we test this sort of rewrite and not spend future time discussing it.

Contributor Author

I have updated the test suite for solve_of_inv_to_matmul to address the feedback regarding custom rewriters and numerical checks. During the refactoring, I encountered a failure that highlights why we need to stabilize both sides of the assertion when using assert_equal_computations.

The Issue

Even when the rewrite was triggered, X @ b initially produces a Matmul op. However, during the stabilize phase, PyTensor lowers this to a more specific Dot op.

Failure Example (b_ndim=1):

  • Rewritten Graph: Squeeze(Dot(X, ExpandDims(b)))
  • Expected Graph (Raw): Squeeze(Matmul(X, ExpandDims(b)))
  • Result: AssertionError: equal_computations failed (Dot vs Matmul).

The Fix

I updated the test to apply the stabilize group to both the output and the expected result. This ensures we are comparing the graphs in their final, canonicalized form:

# Restore stabilization to trigger the rewrite and canonicalize the internal Ops (Matmul -> Dot)
rewritten_out = rewrite_graph(out, include=["stabilize"])
expected = rewrite_graph(X @ b, include=["stabilize"])

assert_equal_computations([rewritten_out], [expected])

Member

Yes, matmul is eagerly rewritten as dot because the old rewrites handled dot directly; we still need to transition to matmul by default. But is the test good now?
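
The two graphs in the failure example above compute the same values; they differ only in which Op performs the product. As a rough NumPy analogy (NumPy functions stand in for the PyTensor Ops here, and it is an assumption that ExpandDims adds a trailing axis to b which Squeeze then removes), the b_ndim=1 case looks like this:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((4, 4))
b = rng.random(4)

# ExpandDims analogue: turn the vector into a single-column matrix.
b_col = np.expand_dims(b, axis=-1)  # shape (4,) -> (4, 1)

# Same product through two different routines, then the Squeeze analogue.
via_dot = np.squeeze(np.dot(X, b_col), axis=-1)
via_matmul = np.squeeze(np.matmul(X, b_col), axis=-1)

# Numerically there is nothing to distinguish the two routes.
np.testing.assert_allclose(via_dot, via_matmul)
np.testing.assert_allclose(via_dot, X @ b)
```

A structural comparison such as assert_equal_computations inspects the graph rather than the values, so it still fails when one side holds Dot and the other Matmul. Applying the same rewrite pass to both sides, as in the fix above, removes that mismatch.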


    # Verify the rewrite
    expected = X @ b
    assert_equal_computations([rewritten_out], [expected])

    # Numerical check
    rng = np.random.default_rng(42)
    X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype)
    b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype)

    f_opt = function([X, b], rewritten_out)
    res_opt = f_opt(X_val, b_val)
Comment on lines +467 to +468
Member

@ricardoV94 ricardoV94 May 2, 2026

EDIT (if you saw it): NVM

For f_opt there is no need to compile aggressively: you already locked the rewrite in, so a simple mode=Mode(linker="py", optimizer=None) allows the fastest numerical eval.

In this case I'm okay with not testing it numerically, up to you

    res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val)

    np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)
Comment on lines +462 to +471
Member

We don't need the numerical check if the structural check passes (you're just testing BLAS at that point -- I promise you BLAS works).
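
If the numerical check is kept, the choice of X_val matters more than the tolerance. Below is a minimal NumPy sketch of why the test adds 4 * I: entries drawn by rng.random lie in [0, 1), so in a 4x4 matrix each row's off-diagonal sum is below 3 while the shifted diagonal is at least 4. The matrix is therefore strictly diagonally dominant and invertible. The variable names mirror the test; the printed condition number is only there for inspection.

```python
import numpy as np

rng = np.random.default_rng(42)
X_val = rng.random((4, 4)) + np.eye(4) * 4

diag = np.abs(np.diag(X_val))
off_diag_row_sums = np.abs(X_val).sum(axis=1) - diag

# Strict diagonal dominance: every diagonal entry outweighs the rest of its row.
assert np.all(diag > off_diag_row_sums)

# A small condition number means inv() followed by solve() loses little precision.
print("condition number:", np.linalg.cond(X_val))

b_val = rng.random(4)
res = np.linalg.solve(np.linalg.inv(X_val), b_val)
np.testing.assert_allclose(res, X_val @ b_val, rtol=1e-7)
```

With a plain rng.random((4, 4)) matrix the same comparison could become flaky, since nothing then bounds how close to singular the draw is.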
