-
Notifications
You must be signed in to change notification settings - Fork 186
Rewrite solve(matrix_inverse(X), b) → X @ b #2101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,8 @@ | |
| from pytensor.configdefaults import config | ||
| from pytensor.gradient import grad | ||
| from pytensor.graph import ancestors | ||
| from pytensor.graph.rewriting.basic import WalkingGraphRewriter | ||
| from pytensor.graph.rewriting.utils import rewrite_graph | ||
| from pytensor.scan.op import Scan | ||
| from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape | ||
| from pytensor.tensor.linalg.decomposition.cholesky import Cholesky, cholesky | ||
|
|
@@ -24,8 +26,10 @@ | |
| from pytensor.tensor.rewriting.linalg.solvers import ( | ||
| reuse_decomposition_multiple_solves, | ||
| scan_split_non_sequence_decomposition_and_solve, | ||
| solve_of_inv_to_matmul, | ||
| ) | ||
| from pytensor.tensor.type import matrix, tensor | ||
| from tests.unittest_tools import assert_equal_computations | ||
|
|
||
|
|
||
| def test_generic_solve_to_solve_triangular(): | ||
|
|
@@ -439,3 +443,29 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): | |
| resx1 = fn_opt(A_test, x0_test) | ||
| rtol = 1e-7 if config.floatX == "float64" else 1e-4 | ||
| np.testing.assert_allclose(resx0, resx1, rtol=rtol) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}") | ||
| def test_solve_of_inv_to_matmul(b_ndim): | ||
| X = pt.dmatrix("X") | ||
| b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b") | ||
| out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim) | ||
|
|
||
| # Just include the rewrite we are testing | ||
| rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul) | ||
| rewritten_out = rewrite_graph(out, custom_rewrite=rewriter) | ||
|
Comment on lines
+455
to
+456
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't love the import + custom_rewrite. The pattern I had in mind was to just use
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again I think this PR is a great discussion for #2103 which is still open-ended and in ask of feedback. So we standardize how we want to test this sort of rewrites and don't need to waste future time discussing it
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have updated the test suite for The IssueEven when the rewrite was triggered, Failure Example (
The FixI updated the test to apply the # Restore stabilization to trigger the rewrite and canonicalize the internal Ops (Matmul -> Dot)
rewritten_out = rewrite_graph(out, include=["stabilize"])
expected = rewrite_graph(X @ b, include=["stabilize"])
assert_equal_computations([rewritten_out], [expected])
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes matmul is eagerly rewritten as dot because old rewrites handled dot directly, we still need to transition to matmul by default. But now test is good? |
||
|
|
||
| # Verify the rewrite | ||
| expected = X @ b | ||
| assert_equal_computations([rewritten_out], [expected]) | ||
|
|
||
| # Numerical check | ||
| rng = np.random.default_rng(42) | ||
| X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype) | ||
| b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype) | ||
|
|
||
| f_opt = function([X, b], rewritten_out) | ||
| res_opt = f_opt(X_val, b_val) | ||
|
Comment on lines
+467
to
+468
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. EDIT (if you saw it): NVM f_opt, no need to compile aggressively, you already locked the rewrite in, so now a simple In this case I'm okay with not testing it numerically, up to you |
||
| res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val) | ||
|
|
||
| np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7) | ||
|
Comment on lines
+462
to
+471
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need the numerical check if the structural check passes (you're just testing BLAS at that point -- i promise you BLAS works) |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment is not true, rewrite_graph includes canonicalize by default