From ef9fb978eb28f06487f4f200a3746f26d3244f5e Mon Sep 17 00:00:00 2001 From: Alessandro Gentili Date: Tue, 26 May 2026 23:09:23 -0500 Subject: [PATCH] Add det(inv(X)) -> 1/det(X) and log/sign(reciprocal) rewrites Closes #2083 --- pytensor/tensor/rewriting/linalg/summary.py | 11 +++ pytensor/tensor/rewriting/math.py | 91 ++++++++++++++++++- tests/tensor/rewriting/linalg/test_summary.py | 35 +++++++ tests/tensor/rewriting/test_math.py | 80 ++++++++++++++++ 4 files changed, 216 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg/summary.py b/pytensor/tensor/rewriting/linalg/summary.py index d702a064cb..2e41fb3d1b 100644 --- a/pytensor/tensor/rewriting/linalg/summary.py +++ b/pytensor/tensor/rewriting/linalg/summary.py @@ -18,6 +18,7 @@ from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor from pytensor.tensor.linalg.decomposition.qr import QR from pytensor.tensor.linalg.decomposition.svd import SVD +from pytensor.tensor.linalg.inverse import MatrixInverse from pytensor.tensor.linalg.summary import SLogDet, det from pytensor.tensor.math import Prod, log, prod from pytensor.tensor.rewriting.basic import ( @@ -192,6 +193,16 @@ def det_of_triangular(fgraph, node): return [det_val] +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def det_of_inv(fgraph, node): + """Replace det(matrix_inverse(X)) with reciprocal(det(X)).""" + match node.inputs[0].owner_op_and_inputs: + case (Blockwise(MatrixInverse()), X): + return [1 / det(X)] + + @register_specialize @node_rewriter([det]) def slogdet_specialization(fgraph, node): diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b04892d6ce..f7839446fe 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -112,7 +112,7 @@ from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.shape import Shape, Shape_i, specify_shape -from pytensor.tensor.subtensor import Subtensor +from pytensor.tensor.subtensor import Subtensor, _is_provably_positive from pytensor.tensor.type import ( complex_dtypes, uint_dtypes, @@ -689,6 +689,95 @@ def local_exp_log_nan_switch(fgraph, node): return [new_out] +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log_reciprocal(fgraph, node): + """Rewrite log(reciprocal(x)) -> -log(x).""" + (inp,) = node.inputs + if ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, ps.Reciprocal) + ): + return [neg(log(inp.owner.inputs[0]))] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log_div(fgraph, node): + """Rewrite log(a / b) -> log(a) - log(b) when a or b is provably positive. + + The provably-positive side is typically a constant or a shape, which the + surrounding pipeline constant-folds. + """ + (inp,) = node.inputs + if not ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, ps.TrueDiv) + ): + return None + + num, den = inp.owner.inputs + if _is_provably_positive(num, strict=True) or _is_provably_positive( + den, strict=True + ): + return [log(num) - log(den)] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([sign]) +def local_sign_reciprocal(fgraph, node): + """Rewrite sign(reciprocal(x)) -> sign(x).""" + (inp,) = node.inputs + if ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, ps.Reciprocal) + ): + return [sign(inp.owner.inputs[0])] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([sign]) +def local_sign_div(fgraph, node): + """Rewrite sign(a / b) using a known-sign numerator or denominator. + + Provably positive side -> ``sign(other)``; negative constant side -> + ``-sign(other)``. Bails out otherwise. + """ + (inp,) = node.inputs + if not ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, ps.TrueDiv) + ): + return None + + num, den = inp.owner.inputs + + if _is_provably_positive(num, strict=True): + return [sign(den)] + if _is_provably_positive(den, strict=True): + return [sign(num)] + + for side, other in ((num, den), (den, num)): + try: + val = get_underlying_scalar_constant_value(side) + except NotScalarConstantError: + continue + if np.all(val < 0): + return [neg(sign(other))] + + @register_canonicalize @register_specialize @node_rewriter([Sum]) diff --git a/tests/tensor/rewriting/linalg/test_summary.py b/tests/tensor/rewriting/linalg/test_summary.py index 8f61da491c..517e178ebc 100644 --- a/tests/tensor/rewriting/linalg/test_summary.py +++ b/tests/tensor/rewriting/linalg/test_summary.py @@ -459,3 +459,38 @@ def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn): expected = expected_fn(x) rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) assert_equal_computations([rewritten], [expected]) + + +def test_det_of_inv(): + x = pt.tensor("x", shape=(3, 3)) + out = det(pt.linalg.inv(x)) + expected = pt.as_tensor(1.0, dtype="float64") / det(x) + rewritten = rewrite_graph(out, include=["canonicalize", "stabilize"]) + assert_equal_computations([rewritten], [expected]) + + +def test_slogdet_of_inv(): + x = pt.dmatrix("x") + # slogdet(inv(x)) -> (sign, logabsdet) + sign_inv, logabsdet_inv = pt.linalg.slogdet(pt.linalg.inv(x)) + + # expected: (sign(det(x)), -logabsdet(det(x))) + # det(inv(x)) = 1/det(x), so sign is same. + # logabsdet(inv(x)) = log(abs(1/det(x))) = -log(abs(det(x))) + sign_x, logabsdet_x = pt.linalg.slogdet(x) + expected_sign = sign_x + expected_logabsdet = -logabsdet_x + + # We need stabilize for det_of_inv and log_reciprocal + # and specialize for slogdet_specialization + rewritten_sign, rewritten_logabsdet = rewrite_graph( + [sign_inv, logabsdet_inv], include=["canonicalize", "stabilize", "specialize"] + ) + + expected_sign_opt, expected_logabsdet_opt = rewrite_graph( + [expected_sign, expected_logabsdet], + include=["canonicalize", "stabilize", "specialize"], + ) + + assert_equal_computations([rewritten_sign], [expected_sign_opt]) + assert_equal_computations([rewritten_logabsdet], [expected_logabsdet_opt]) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 9106163c2b..7cd7ee286b 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -5116,3 +5116,83 @@ def test_rewrite_does_not_apply(self): original, include=("canonicalize", "stabilize", "specialize") ) assert_equal_computations([rewritten], [original]) + + +def test_log_reciprocal(): + x = pt.dscalar("x") + out = pt.log(pt.reciprocal(x)) + expected = -pt.log(x) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + assert_equal_computations([rewritten], [expected]) + + +def test_sign_reciprocal(): + x = pt.dscalar("x") + out = pt.sign(pt.reciprocal(x)) + expected = pt.sign(x) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + assert_equal_computations([rewritten], [expected]) + + +@pytest.mark.parametrize( + "build, expected_fn", + [ + (lambda x: pt.log(3.0 / x), lambda x: pt.log(3.0) - pt.log(x)), + (lambda x: pt.log(x / 3.0), lambda x: pt.log(x) - pt.log(3.0)), + (lambda x: pt.log(1.0 / x), lambda x: -pt.log(x)), + ], + ids=["pos_const_num", "pos_const_den", "one_over_x"], +) +def test_log_div_positive_constant(build, expected_fn): + x = pt.dscalar("x") + rewritten = rewrite_graph( + build(x), include=["canonicalize", "stabilize", "specialize"] + ) + expected = rewrite_graph( + expected_fn(x), include=["canonicalize", "stabilize", "specialize"] + ) + assert_equal_computations([rewritten], [expected]) + + +def test_log_div_non_constant_not_rewritten(): + x = pt.dscalar("x") + y = pt.dscalar("y") + out = pt.log(x / y) + rewritten = rewrite_graph(out, include=["canonicalize", "stabilize", "specialize"]) + # No constant to peel off — graph should still contain a true_div. + nodes = [v.owner for v in ancestors([rewritten]) if v.owner] + assert any( + isinstance(getattr(node.op, "scalar_op", None), ps.TrueDiv) for node in nodes + ) + + +@pytest.mark.parametrize( + "build, expected_fn", + [ + (lambda x: pt.sign(3.0 / x), lambda x: pt.sign(x)), + (lambda x: pt.sign(-3.0 / x), lambda x: -pt.sign(x)), + (lambda x: pt.sign(x / 3.0), lambda x: pt.sign(x)), + (lambda x: pt.sign(x / -3.0), lambda x: -pt.sign(x)), + ], + ids=["pos_num", "neg_num", "pos_den", "neg_den"], +) +def test_sign_div_constant(build, expected_fn): + x = pt.dscalar("x") + rewritten = rewrite_graph( + build(x), include=["canonicalize", "stabilize", "specialize"] + ) + expected = rewrite_graph( + expected_fn(x), include=["canonicalize", "stabilize", "specialize"] + ) + assert_equal_computations([rewritten], [expected]) + + +def test_sign_div_non_constant_not_rewritten(): + x = pt.dscalar("x") + y = pt.dscalar("y") + out = pt.sign(x / y) + rewritten = rewrite_graph(out, include=["canonicalize", "stabilize", "specialize"]) + nodes = [v.owner for v in ancestors([rewritten]) if v.owner] + assert any( + isinstance(getattr(node.op, "scalar_op", None), ps.TrueDiv) for node in nodes + )