diff --git a/pytensor/tensor/rewriting/linalg/solvers.py b/pytensor/tensor/rewriting/linalg/solvers.py index 01ad379e1f..658722171a 100644 --- a/pytensor/tensor/rewriting/linalg/solvers.py +++ b/pytensor/tensor/rewriting/linalg/solvers.py @@ -17,6 +17,7 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky from pytensor.tensor.linalg.decomposition.lu import lu_factor +from pytensor.tensor.linalg.inverse import MatrixInverse from pytensor.tensor.linalg.solvers.core import SolveBase from pytensor.tensor.linalg.solvers.general import Solve, lu_solve from pytensor.tensor.linalg.solvers.linear_control import ( @@ -162,6 +163,23 @@ def scalar_solve_to_division(fgraph, node): return [new_out] +@register_stabilize +@node_rewriter([blockwise_of(SolveBase)]) +def solve_of_inv_to_matmul(fgraph, node): + """Replace solve(matrix_inverse(X), b) with X @ b. + + If A = inv(X), then solve(A, b) finds x such that A @ x = b, + i.e., inv(X) @ x = b, so x = X @ b. + """ + A, b = node.inputs + + match A.owner_op_and_inputs: + case (Blockwise(MatrixInverse()), X): + new_out = X @ b + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + def decompose_A(A, assume_a, lower): if assume_a == "gen": return lu_factor(A) diff --git a/tests/tensor/rewriting/linalg/test_solvers.py b/tests/tensor/rewriting/linalg/test_solvers.py index 6f554da1fd..fe36d886cf 100644 --- a/tests/tensor/rewriting/linalg/test_solvers.py +++ b/tests/tensor/rewriting/linalg/test_solvers.py @@ -9,6 +9,8 @@ from pytensor.configdefaults import config from pytensor.gradient import grad from pytensor.graph import ancestors +from pytensor.graph.rewriting.basic import WalkingGraphRewriter +from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.scan.op import Scan from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky @@ -24,8 +26,10 @@ from pytensor.tensor.rewriting.linalg.solvers import ( reuse_decomposition_multiple_solves, scan_split_non_sequence_decomposition_and_solve, + solve_of_inv_to_matmul, ) from pytensor.tensor.type import matrix, tensor +from tests.unittest_tools import assert_equal_computations def test_generic_solve_to_solve_triangular(): @@ -439,3 +443,29 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): resx1 = fn_opt(A_test, x0_test) rtol = 1e-7 if config.floatX == "float64" else 1e-4 np.testing.assert_allclose(resx0, resx1, rtol=rtol) + + +@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}") +def test_solve_of_inv_to_matmul(b_ndim): + X = pt.dmatrix("X") + b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b") + out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim) + + # Just include the rewrite we are testing + rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul) + rewritten_out = rewrite_graph(out, custom_rewrite=rewriter) + + # Verify the rewrite + expected = X @ b + assert_equal_computations([rewritten_out], [expected]) + + # Numerical check + rng = np.random.default_rng(42) + X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype) + b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype) + + f_opt = function([X, b], rewritten_out) + res_opt = f_opt(X_val, b_val) + res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val) + + np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)