From 0cc426a8ae236751a05731865002963c232f778b Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 3 May 2026 22:22:13 -0500 Subject: [PATCH 01/10] Add pt.block helper returning nested Joins Pure-Python helper that walks a nested-list np.block-style structure and returns nested concatenate outputs directly, validating uniform leaf depth and promoting ranks via atleast_Nd. --- pytensor/tensor/basic.py | 149 +++++++++++++++++++++++++++++++++---- tests/tensor/test_basic.py | 85 +++++++++++++++++++++ 2 files changed, 219 insertions(+), 15 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7d690a9e9c..0c11c36fe5 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3020,28 +3020,146 @@ def vertical_stack(*args): return concatenate(_args, axis=0) -def is_flat(var, ndim=1): - """ - Verifies the dimensionality of the var is equal to - ndim. This method is usually called after flatten method on a - variable, where the first ndim-1 dimension size(s) of the variable - is kept intact, and the last dimension size of the variable is made - equal to the multiplication of its remaining dimension size(s), such that - the variable would end up with as many dimension as ndim. +def _block_check_depths_match(arrays, parent_index=()): + """Walk a nested block-list and check every leaf sits at the same depth. Parameters ---------- - var : pytensor.tensor.var.TensorVariable - the pytensor var on which the dimensionality is checked. + arrays : list or array_like + Nested block-list to validate. + parent_index : tuple of int, optional + Indices accumulated from the root, used in error messages. Default ``()``. - ndim : int - the expected dimensionality of var. + Returns + ------- + structure : nested tuple of None + Tree shape with ``None`` at each leaf position. + leaf_depth : int + Depth at which every leaf sits. + max_leaf_ndim : int + Largest ``ndim`` across all leaves. + """ + if isinstance(arrays, list): + if not arrays: + raise ValueError("Block: empty list is not allowed") + children = [] + first_leaf_depth = None + max_ndim = 0 + for i, child in enumerate(arrays): + child_struct, child_leaf_depth, child_ndim = _block_check_depths_match( + child, (*parent_index, i) + ) + if first_leaf_depth is None: + first_leaf_depth = child_leaf_depth + elif first_leaf_depth != child_leaf_depth: + raise ValueError( + "Block: all leaves must be at the same nesting depth " + f"(got depth {child_leaf_depth} at index {(*parent_index, i)}, " + f"expected {first_leaf_depth})" + ) + if child_ndim > max_ndim: + max_ndim = child_ndim + children.append(child_struct) + return tuple(children), first_leaf_depth, max_ndim + elif isinstance(arrays, tuple): + raise TypeError("Block: tuples are not allowed as nested containers; use lists") + else: + leaf = as_tensor_variable(arrays) + return None, len(parent_index), leaf.type.ndim + + +def block(arrays): + """Assemble a tensor from nested lists of blocks, like ``numpy.block``. + + Parameters + ---------- + arrays : nested list of array_like + Tensors at the leaves, lists at the interior. Every leaf must sit at + the same nesting depth ``d``; the concatenation spans the last ``d`` + axes. Returns ------- - bool - the comparison result of var's dim - and the expected outdim. + result : TensorVariable + Assembled block tensor. A bare tensor (no list wrapping) returns as + ``atleast_1d(arrays)``. + + Examples + -------- + .. testcode:: + + import numpy as np + import pytensor.tensor as pt + + A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]])) + B = pt.as_tensor_variable(np.array([[5], [6]])) + C = pt.as_tensor_variable(np.array([[7, 8]])) + D = pt.as_tensor_variable(np.array([[9]])) + M = pt.block([[A, B], [C, D]]) + print(M.eval()) + + .. testoutput:: + + [[1 2 5] + [3 4 6] + [7 8 9]] + """ + structure, _, _ = _block_check_depths_match(arrays) + + if structure is None: + return atleast_Nd(arrays, n=1) + + flat = [] + + def _gather(node): + if isinstance(node, list): + for child in node: + _gather(child) + else: + flat.append(as_tensor_variable(node)) + + _gather(arrays) + + def _structure_depth(structure): + if structure is None: + return 0 + return 1 + _structure_depth(structure[0]) + + list_ndim = _structure_depth(structure) + result_ndim = builtins.max(list_ndim, builtins.max(inp.type.ndim for inp in flat)) + promoted = [atleast_Nd(inp, n=result_ndim) for inp in flat] + + def _unflatten_structure(flat, structure): + """Rebuild a nested list from ``flat`` consumed in pre-order against ``structure``.""" + it = iter(flat) + + def _build(s): + if s is None: + return next(it) + return [_build(child) for child in s] + + return _build(structure) + + nested = _unflatten_structure(promoted, structure) + + def _recurse(node, depth): + if depth == list_ndim: + return node + children = [_recurse(child, depth + 1) for child in node] + return concatenate(children, axis=-(list_ndim - depth)) + + return _recurse(nested, 0) + + +def is_flat(var, ndim=1): + """Return ``True`` when ``var`` has exactly ``ndim`` dimensions. + + Parameters + ---------- + var : TensorVariable + Variable to inspect. + ndim : int + Expected number of dimensions. Default 1. """ return var.ndim == ndim @@ -4558,6 +4676,7 @@ def ix_(*args): "atleast_2d", "atleast_3d", "atleast_Nd", + "block", "cast", "choose", "concatenate", diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 2ab01c047d..e96353dbd9 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -41,6 +41,7 @@ arange, as_tensor_variable, atleast_Nd, + block, cast, choose, constant, @@ -3632,6 +3633,90 @@ def test_grad_3d(self, offset, axis1, axis2): utt.verify_grad(diag_fn, [x], rng=rng) +class TestBlock: + """Tests that ``pt.block`` matches ``np.block``.""" + + def _check(self, arrays_pt, arrays_np, *, expected_ndim=None): + out = block(arrays_pt) + if expected_ndim is not None: + assert out.ndim == expected_ndim + np.testing.assert_allclose(out.eval(), np.block(arrays_np)) + + def test_depth_1(self): + a = np.arange(3.0) + b = np.arange(4.0) + self._check([a, b], [a, b], expected_ndim=1) + + def test_depth_2_block_matrix(self): + rng = np.random.default_rng(0) + A = rng.standard_normal((2, 3)) + B = rng.standard_normal((2, 4)) + C = rng.standard_normal((5, 3)) + D = rng.standard_normal((5, 4)) + self._check([[A, B], [C, D]], [[A, B], [C, D]], expected_ndim=2) + + def test_depth_3(self): + rng = np.random.default_rng(0) + x = rng.standard_normal((2, 3, 4)) + self._check([[[x, x], [x, x]]], [[[x, x], [x, x]]], expected_ndim=3) + + def test_scalar_promotion(self): + # Depth-2 list of scalars → 2-D output, like np.block. + self._check([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], expected_ndim=2) + + def test_mixed_rank_promotion(self): + # 1-D leaf in a depth-2 list gets atleast_2d-promoted on the left. + rng = np.random.default_rng(0) + v = rng.standard_normal(3) + A = rng.standard_normal((2, 3)) + self._check([[v], [A]], [[v], [A]], expected_ndim=2) + + def test_ragged_widths(self): + # Same depth, different number of children per row — np.block allows it. + rng = np.random.default_rng(0) + A = rng.standard_normal((2, 1)) + B = rng.standard_normal((2, 2)) + C = rng.standard_normal((2, 3)) + D = rng.standard_normal((3, 6)) + self._check([[A, B, C], [D]], [[A, B, C], [D]], expected_ndim=2) + + def test_single_leaf_passthrough(self): + # np.block(arr) returns atleast_1d(arr). + out = block(5.0) + assert out.ndim == 1 + np.testing.assert_array_equal(out.eval(), np.block(5.0)) + + def test_depth_mismatch_raises(self): + A = pytensor.tensor.matrix("A") + with pytest.raises(ValueError, match="same nesting depth"): + block([[A, A], A]) + + def test_tuple_rejected(self): + A = pytensor.tensor.matrix("A") + with pytest.raises(TypeError, match="tuples are not allowed"): + block((A, A)) + + def test_empty_list_raises(self): + with pytest.raises(ValueError, match="empty list"): + block([]) + + def test_returns_nested_join(self): + # `block` is a pure helper: the output is the result of a Join of Joins, + # not a wrapping Op. + A = pytensor.tensor.matrix("A") + B = pytensor.tensor.matrix("B") + out = block([[A, B], [B, A]]) + # Outer concat along axis -2 (rows). + assert isinstance(out.owner.op, ptb.Join) + outer_axis = int(out.owner.inputs[0].data) + assert outer_axis == out.ndim - 2 + # Each row is itself a Join along axis -1 (columns). + for inner in out.owner.inputs[1:]: + assert isinstance(inner.owner.op, ptb.Join) + inner_axis = int(inner.owner.inputs[0].data) + assert inner_axis == inner.ndim - 1 + + class TestAllocDiag: # TODO: Separate perform, grad and infer_shape tests From 44bef5923ee31746d75187eab1de1b7cc3a7bb15 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 3 May 2026 22:27:30 -0500 Subject: [PATCH 02/10] Add local_dot_of_join: decompose dot through Join Push dot inside a Join, splitting the other operand by leaf widths (or heights) and emitting per-leaf dots that sum or concat to the original result. Conservative: skips when partition dims are dynamic. --- pytensor/tensor/rewriting/math.py | 91 +++++++++++++ tests/tensor/rewriting/test_math.py | 202 ++++++++++++++++++++++++++++ 2 files changed, 293 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b04892d6ce..179e887aa1 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -37,6 +37,7 @@ ones_like, register_infer_shape, split, + stack, switch, zeros, zeros_like, @@ -231,6 +232,96 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): return {client.outputs[0]: new_output} +@register_stabilize +@node_rewriter([Join]) +def local_dot_of_join(fgraph, node): + r"""Push ``dot`` inside a :class:`Join`, decomposing the matmul into per-leaf products. + + When ``Join`` runs along the matmul-contracted axis, ``Y`` is split by symbolic per-leaf sizes and + the per-leaf products are summed. Otherwise each leaf multiplies ``Y`` directly and the results are concatenated. + + Walks through chains of left-``expand_dims`` ``DimShuffle`` nodes between the Join and the matmul + (Blockwise stacks pads this way). + """ + try: + join_axis = int( + get_underlying_scalar_constant_value( + node.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + + out_ndim = node.outputs[0].type.ndim + if join_axis < 0: + join_axis += out_ndim + # Only the last two axes participate in matmul; other axes are batch. + if join_axis not in (out_ndim - 1, out_ndim - 2): + return None + + leaves = node.inputs[1:] + if len(leaves) < 2: + return None + + join_out = node.outputs[0] + # Translate Join's axis (in its own ndim) to a "matmul axis" tag: + # join_matmul_axis = -1 -> Join concatenates along the inner mat axis + # join_matmul_axis = -2 -> Join concatenates along the outer mat axis + join_matmul_axis = join_axis - out_ndim # -1 or -2 + + def _walk_to_matmul(var): + """Yield ``(matmul_node, input_idx)`` for every Dot/_matmul reachable + from ``var`` through a chain of left-expand-dims DimShuffles.""" + for client, input_idx in fgraph.clients[var]: + if client.op in (_dot, _matmul): + yield client, input_idx + elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims: + yield from _walk_to_matmul(client.outputs[0]) + + replacements: dict = {} + for client, client_idx in _walk_to_matmul(join_out): + if client.outputs[0] in replacements: + # ``dot(J, J)`` reaches the same matmul via both inputs, and + # ``ds(ds(J))`` (chained DimShuffles) reaches it via multiple + # paths. Either way: decompose once, let the next pass handle + # any side still wrapping the (now-fewer-clients) Join. + continue + + other = client.inputs[1 - client_idx] + dot_op = client.op + old_out = client.outputs[0] + + if client_idx == 0: + # Join @ other + if join_matmul_axis == -1: + widths = stack([leaf.shape[-1] for leaf in leaves]) + other_chunks = split(other, splits_size=widths, axis=-2) + terms = [ + dot_op(leaf, chunk) for leaf, chunk in zip(leaves, other_chunks) + ] + new_output = add(*terms) + else: + terms = [dot_op(leaf, other) for leaf in leaves] + new_output = concat_with_broadcast(terms, axis=-2) + else: + # other @ Join + if join_matmul_axis == -1: + terms = [dot_op(other, leaf) for leaf in leaves] + new_output = concat_with_broadcast(terms, axis=-1) + else: + heights = stack([leaf.shape[-2] for leaf in leaves]) + other_chunks = split(other, splits_size=heights, axis=-1) + terms = [ + dot_op(chunk, leaf) for chunk, leaf in zip(other_chunks, leaves) + ] + new_output = add(*terms) + + copy_stack_trace(old_out, new_output) + replacements[old_out] = new_output + + return replacements or None + + @register_canonicalize @node_rewriter([Dot, _matmul]) def local_lift_transpose_through_dot(fgraph, node): diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 9106163c2b..48be68d87c 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -5116,3 +5116,205 @@ def test_rewrite_does_not_apply(self): original, include=("canonicalize", "stabilize", "specialize") ) assert_equal_computations([rewritten], [original]) + + +class TestDotOfJoin: + @staticmethod + def _n_dots(fn): + return sum( + isinstance(n.op, Dot | Dot22) + or (isinstance(n.op, Blockwise) and isinstance(n.op.core_op, Dot | Dot22)) + for n in fn.maker.fgraph.toposort() + ) + + @staticmethod + def _join_consumes_originals(fn, originals): + """Whether any Join in the graph still has any of ``originals`` as inputs. + + Used to confirm the original Join was decomposed (not just shuffled).""" + original_set = set(originals) + for n in fn.maker.fgraph.toposort(): + from pytensor.tensor.basic import Join + + if isinstance(n.op, Join) and any( + inp in original_set for inp in n.inputs[1:] + ): + return True + return False + + def test_join_lhs_axis_neg1(self): + # [A | B] @ y -> A @ y[:n_a] + B @ y[n_a:] + a = pt.tensor("a", shape=(3, 4)) + b = pt.tensor("b", shape=(3, 5)) + y = pt.tensor("y", shape=(9, 6)) + M = pt.concatenate([a, b], axis=-1) + fn = pytensor.function([a, b, y], M @ y, mode=rewrite_mode) + assert self._n_dots(fn) == 2 + assert not self._join_consumes_originals(fn, [a, b]) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 4)) + b_v = rng.standard_normal((3, 5)) + y_v = rng.standard_normal((9, 6)) + np.testing.assert_allclose( + fn(a_v, b_v, y_v), + np.concatenate([a_v, b_v], axis=-1) @ y_v, + atol=1e-12, + ) + + def test_join_lhs_axis_neg2(self): + # [[A], [B]] @ y -> concat([A @ y, B @ y], -2) + a = pt.tensor("a", shape=(3, 4)) + b = pt.tensor("b", shape=(2, 4)) + y = pt.tensor("y", shape=(4, 6)) + M = pt.concatenate([a, b], axis=-2) + fn = pytensor.function([a, b, y], M @ y, mode=rewrite_mode) + assert self._n_dots(fn) == 2 + assert not self._join_consumes_originals(fn, [a, b]) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 4)) + b_v = rng.standard_normal((2, 4)) + y_v = rng.standard_normal((4, 6)) + np.testing.assert_allclose( + fn(a_v, b_v, y_v), + np.concatenate([a_v, b_v], axis=-2) @ y_v, + atol=1e-12, + ) + + def test_join_rhs_axis_neg1(self): + # y @ [A | B] -> concat([y @ A, y @ B], -1) + a = pt.tensor("a", shape=(4, 3)) + b = pt.tensor("b", shape=(4, 5)) + y = pt.tensor("y", shape=(7, 4)) + M = pt.concatenate([a, b], axis=-1) + fn = pytensor.function([a, b, y], y @ M, mode=rewrite_mode) + assert self._n_dots(fn) == 2 + assert not self._join_consumes_originals(fn, [a, b]) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((4, 3)) + b_v = rng.standard_normal((4, 5)) + y_v = rng.standard_normal((7, 4)) + np.testing.assert_allclose( + fn(a_v, b_v, y_v), + y_v @ np.concatenate([a_v, b_v], axis=-1), + atol=1e-12, + ) + + def test_join_rhs_axis_neg2(self): + # y @ [[A], [B]] -> y[..., :n_a] @ A + y[..., n_a:] @ B + a = pt.tensor("a", shape=(3, 5)) + b = pt.tensor("b", shape=(2, 5)) + y = pt.tensor("y", shape=(7, 5)) + M = pt.concatenate([a, b], axis=-2) + fn = pytensor.function([a, b, y], y @ M, mode=rewrite_mode) + assert self._n_dots(fn) == 2 + assert not self._join_consumes_originals(fn, [a, b]) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 5)) + b_v = rng.standard_normal((2, 5)) + y_v = rng.standard_normal((7, 5)) + np.testing.assert_allclose( + fn(a_v, b_v, y_v), + y_v @ np.concatenate([a_v, b_v], axis=-2), + atol=1e-12, + ) + + def test_unknown_widths_lhs_decomposes(self): + # LHS axis=-1 splits y by symbolic leaf widths; fires on dynamic shapes. + a = pt.matrix("a") + b = pt.matrix("b") + y = pt.matrix("y") + M = pt.concatenate([a, b], axis=-1) + fn = pytensor.function([a, b, y], M @ y, mode=rewrite_mode) + assert self._n_dots(fn) == 2 + assert not self._join_consumes_originals(fn, [a, b]) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 4)) + b_v = rng.standard_normal((3, 5)) + y_v = rng.standard_normal((9, 6)) + np.testing.assert_allclose( + fn(a_v, b_v, y_v), + np.concatenate([a_v, b_v], axis=-1) @ y_v, + atol=1e-12, + ) + + def test_unknown_heights_rhs_decomposes(self): + # RHS axis=-2 splits y by symbolic leaf heights; fires on dynamic shapes. + a = pt.matrix("a") + b = pt.matrix("b") + y = pt.matrix("y") + M = pt.concatenate([a, b], axis=-2) + fn = pytensor.function([a, b, y], y @ M, mode=rewrite_mode) + assert self._n_dots(fn) == 2 + assert not self._join_consumes_originals(fn, [a, b]) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 5)) + b_v = rng.standard_normal((2, 5)) + y_v = rng.standard_normal((7, 5)) + np.testing.assert_allclose( + fn(a_v, b_v, y_v), + y_v @ np.concatenate([a_v, b_v], axis=-2), + atol=1e-12, + ) + + def test_single_input_join_skipped(self): + # A "Join" with one input is a no-op; skip. + a = pt.tensor("a", shape=(3, 4)) + y = pt.tensor("y", shape=(4, 5)) + M = pt.join(-1, a) # single-input join + fn = pytensor.function([a, y], M @ y, mode=rewrite_mode) + assert self._n_dots(fn) == 1 + + @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) + def test_depth_2_block_at_x(self, left_multiply): + # ``pt.block([[a, b], [c, d]])`` produces nested Joins. Iterated + # ``local_dot_of_join`` decomposes both levels into leaf-level dots. + a = pt.tensor("a", shape=(2, 3)) + b = pt.tensor("b", shape=(2, 4)) + c = pt.tensor("c", shape=(5, 3)) + d = pt.tensor("d", shape=(5, 4)) + M = pt.block([[a, b], [c, d]]) + + if left_multiply: + other = pt.tensor("other", shape=(7, 6)) + out = M @ other + else: + other = pt.tensor("other", shape=(8, 7)) + out = other @ M + + fn = pytensor.function([a, b, c, d, other], out, mode=rewrite_mode) + # 4 leaf-level dots, regardless of side. + assert self._n_dots(fn) == 4 + + rng = np.random.default_rng(0) + a_v = rng.standard_normal(a.type.shape) + b_v = rng.standard_normal(b.type.shape) + c_v = rng.standard_normal(c.type.shape) + d_v = rng.standard_normal(d.type.shape) + other_v = rng.standard_normal(other.type.shape) + ref_M = np.block([[a_v, b_v], [c_v, d_v]]) + expected = ref_M @ other_v if left_multiply else other_v @ ref_M + np.testing.assert_allclose( + fn(a_v, b_v, c_v, d_v, other_v), + expected, + atol=1e-12, + rtol=1e-12, + ) + + def test_depth_2_unknown_shapes_decomposes(self): + # Symbolic split sizes: even with fully dynamic shapes, both Join + # levels decompose down to leaf-level dots. + a = pt.matrix("a") + b = pt.matrix("b") + c = pt.matrix("c") + d = pt.matrix("d") + x = pt.matrix("x") + out = pt.block([[a, b], [c, d]]) @ x + fn = pytensor.function([a, b, c, d, x], out, mode=rewrite_mode) + assert self._n_dots(fn) == 4 + From f50007c7424937b2c1f27727e1416528525d82f6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 3 May 2026 22:29:23 -0500 Subject: [PATCH 03/10] Add local_transpose_of_join: push mT inside Join Matrix-transpose distributes through Join by transposing each leaf and swapping concatenation axis when it's one of the last two. --- pytensor/tensor/rewriting/math.py | 42 +++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 179e887aa1..979ef120c3 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -33,6 +33,7 @@ diagonal, expand_dims, get_underlying_scalar_constant_value, + join, moveaxis, ones_like, register_infer_shape, @@ -322,6 +323,47 @@ def _walk_to_matmul(var): return replacements or None +@register_canonicalize +@register_stabilize +@node_rewriter([DimShuffle]) +def local_transpose_of_join(fgraph, node): + r"""Rewrite Join(axis, *inputs).mT to Join(axis, *[inp.mT for inp in inputs]) + + Swap axis=-1 <-> axis=-2 when axis is one of the matrix axes and leave batch axes unchanged. + """ + if not node.op.is_matrix_transpose: + return None + + [src] = node.inputs + if src.owner is None or not isinstance(src.owner.op, Join): + return None + + try: + join_axis = int( + get_underlying_scalar_constant_value( + src.owner.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + + src_ndim = src.type.ndim + if join_axis < 0: + join_axis += src_ndim + + if join_axis == src_ndim - 1: + new_axis = src_ndim - 2 + elif join_axis == src_ndim - 2: + new_axis = src_ndim - 1 + else: + new_axis = join_axis # batch axis, mT doesn't touch it + + transposed_inputs = [inp.mT for inp in src.owner.inputs[1:]] + new_out = join(new_axis, *transposed_inputs) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + @register_canonicalize @node_rewriter([Dot, _matmul]) def local_lift_transpose_through_dot(fgraph, node): From ad8b3682a01c40ed2719737e155f53df897b48dc Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 3 May 2026 22:32:15 -0500 Subject: [PATCH 04/10] Add local_nested_join_to_block_diagonal canonicalization Recognize a square 2-D nested-Join with statically-zero off-diagonals and rewrite to BlockDiagonal so the existing block-diag rewrites can fire on user-written or rewrite-induced concat patterns. --- pytensor/tensor/rewriting/math.py | 87 ++++++++++++++++++++++++++++- tests/tensor/rewriting/test_math.py | 48 ++++++++++++++++ 2 files changed, 134 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 979ef120c3..affc983280 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -46,7 +46,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast -from pytensor.tensor.linalg.constructors import BlockDiagonal +from pytensor.tensor.linalg.constructors import BlockDiagonal, block_diag from pytensor.tensor.math import ( Dot, Prod, @@ -364,6 +364,91 @@ def local_transpose_of_join(fgraph, node): return [new_out] +@register_canonicalize +@register_stabilize +@node_rewriter([Join]) +def local_nested_join_to_block_diagonal(fgraph, node): + r"""Recognize a square 2-D block-matrix-shaped concatenation whose + off-diagonal entries are statically zero, and rewrite to :func:`block_diag`. + + Detects ``Join(axis=-2, *Join(axis=-1, ...))`` -- an outer row-concat whose + every input is itself a column-concat -- with uniform structure (n x n + square grid) and statically-zero off-diagonal leaves. Replaces with + ``BlockDiagonal`` to unlock its targeted rewrites (det, diag, trace, dot, + solve pushdowns). + """ + out_ndim = node.outputs[0].type.ndim + if out_ndim < 2: + return None + + try: + outer_axis = int( + get_underlying_scalar_constant_value( + node.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + if outer_axis < 0: + outer_axis += out_ndim + if outer_axis != out_ndim - 2: + return None + + rows = node.inputs[1:] + n_rows = len(rows) + if n_rows < 2: + return None + + leaves = [] + n_cols = None + for row in rows: + if row.owner is None or not isinstance(row.owner.op, Join): + return None + try: + inner_axis = int( + get_underlying_scalar_constant_value( + row.owner.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + if inner_axis < 0: + inner_axis += row.type.ndim + if inner_axis != row.type.ndim - 1: + return None + row_leaves = list(row.owner.inputs[1:]) + if n_cols is None: + n_cols = len(row_leaves) + elif len(row_leaves) != n_cols: + return None # ragged + leaves.append(row_leaves) + + # Square grid only. + if n_rows != n_cols: + return None + + diag_blocks = [] + for i in range(n_rows): + for j in range(n_cols): + leaf = leaves[i][j] + if i == j: + diag_blocks.append(leaf) + elif ( + get_underlying_scalar_constant_value( + leaf, only_process_constants=False, raise_not_constant=False + ) + != 0 + ): + return None + + if len(diag_blocks) < 2: + return None + + new_out = block_diag(*diag_blocks) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + @register_canonicalize @node_rewriter([Dot, _matmul]) def local_lift_transpose_through_dot(fgraph, node): diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 48be68d87c..1f0662b371 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -5118,6 +5118,54 @@ def test_rewrite_does_not_apply(self): assert_equal_computations([rewritten], [original]) +class TestNestedJoinToBlockDiagonal: + @staticmethod + def _has_block_diagonal(fn): + return any( + isinstance(n.op, Blockwise) and isinstance(n.op.core_op, BlockDiagonal) + for n in fn.maker.fgraph.toposort() + ) + + def test_zeros_off_diagonal_canonicalizes(self): + # Square nested-Join with statically-zero off-diagonals -> BlockDiagonal. + a = pt.tensor("a", shape=(3, 3)) + d = pt.tensor("d", shape=(4, 4)) + M = pt.block([[a, pt.zeros((3, 4))], [pt.zeros((4, 3)), d]]) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 3)) + d_v = rng.standard_normal((4, 4)) + fn = pytensor.function([a, d], M, mode=rewrite_mode) + np.testing.assert_allclose( + fn(a_v, d_v), + np.block([[a_v, np.zeros((3, 4))], [np.zeros((4, 3)), d_v]]), + atol=1e-12, + rtol=1e-12, + ) + + def test_nonzero_off_diagonal_skips(self): + # Off-diagonal isn't statically zero -> don't canonicalize. + a = pt.tensor("a", shape=(3, 3)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + M = pt.block([[a, b], [c, d]]) + + fn = pytensor.function([a, b, c, d], M, mode=rewrite_mode) + assert not self._has_block_diagonal(fn) + + def test_non_square_skips(self): + # 2x3 grid (non-square) -> not a candidate. + a = pt.tensor("a", shape=(3, 4)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(3, 4)) + z = pt.zeros((3, 4)) + M = pt.block([[a, z, z], [b, z, c]]) + + fn = pytensor.function([a, b, c], M, mode=rewrite_mode) + assert not self._has_block_diagonal(fn) + + class TestDotOfJoin: @staticmethod def _n_dots(fn): From c7d7a11d6b6997cff4681f1ab1a03209782f6631 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 3 May 2026 22:35:54 -0500 Subject: [PATCH 05/10] Add local_split_of_join lifting Push Split through Join: matching axis with matching sizes returns the Join inputs directly; different axis distributes the Split per input. Unblocks Block@Block and X@S@X.T leaf-level decompositions. --- pytensor/tensor/rewriting/math.py | 110 +++++++++++++++++++++ tests/tensor/rewriting/test_math.py | 147 ++++++++++++++++++++++++++++ 2 files changed, 257 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index affc983280..7bc83e1d0c 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -24,6 +24,7 @@ Alloc, Join, MakeVector, + Split, alloc, alloc_diag, as_tensor_variable, @@ -449,6 +450,115 @@ def local_nested_join_to_block_diagonal(fgraph, node): return [new_out] +def _const_int_vector(var): + """Extract a Python ``list[int]`` from a vector :class:`Variable` if its + contents are statically known. Handles ``Constant`` and ``MakeVector`` of + scalar constants. Returns ``None`` otherwise. + """ + if isinstance(var, Constant): + try: + arr = np.asarray(var.data) + except Exception: + return None + if arr.ndim != 1: + return None + return [int(x) for x in arr] + if var.owner is not None and isinstance(var.owner.op, MakeVector): + try: + return [ + int(get_underlying_scalar_constant_value(inp, raise_not_constant=True)) + for inp in var.owner.inputs + ] + except NotScalarConstantError: + return None + return None + + +@register_canonicalize +@register_stabilize +@node_rewriter([Split]) +def local_split_of_join(fgraph, node): + r"""Push :class:`Split` through :class:`Join`. + + Two cases are handled: + + - **Same axis, matching sizes.** ``Split(Join(a, X_0, ..., X_k), + [s_0, ..., s_k], axis=a)`` with ``s_i == X_i.shape[a]`` returns the + ``Join``'s inputs directly. The split exactly undoes the concat. + + - **Different axis.** ``Split(Join(a, X_0, ..., X_k), [s_0, ...], axis=b)`` + with ``a != b`` distributes the split through the join: each cut output + becomes ``Join(a, *[Split(X_i, axis=b)[k] for i])``. Slicing along an + orthogonal axis commutes with concatenation. + + Together these unblock the cascades that show up after dot-of-Join + decomposition (e.g. ``Block @ Block``, ``X @ S @ X.T``): the resulting + ``Split(Join(...))`` patterns collapse to per-leaf operations instead of + materializing the assembled intermediate. + """ + x, axis_var, splits_size_var = node.inputs + if x.owner is None or not isinstance(x.owner.op, Join): + return None + + try: + split_axis = int( + get_underlying_scalar_constant_value(axis_var, raise_not_constant=True) + ) + join_axis = int( + get_underlying_scalar_constant_value( + x.owner.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + + out_ndim = x.type.ndim + if split_axis < 0: + split_axis += out_ndim + if join_axis < 0: + join_axis += out_ndim + + join_inputs = list(x.owner.inputs[1:]) + n_splits = len(node.outputs) + + if split_axis == join_axis: + # Matching axis: return Join's inputs directly when sizes line up. + if len(join_inputs) != n_splits: + return None + join_sizes = [inp.type.shape[join_axis] for inp in join_inputs] + if any(s is None for s in join_sizes): + return None + split_sizes = _const_int_vector(splits_size_var) + if split_sizes is None: + return None + if join_sizes != split_sizes: + return None + for inp in join_inputs: + copy_stack_trace(node.outputs[0], inp) + return list(join_inputs) + + # Different axis: distribute Split through Join. + per_input_splits = [ + split( + inp, + splits_size=splits_size_var, + n_splits=n_splits, + axis=split_axis, + ) + for inp in join_inputs + ] + new_outputs = [ + join( + join_axis, + *[per_input_splits[i][k] for i in range(len(join_inputs))], + ) + for k in range(n_splits) + ] + for new_out in new_outputs: + copy_stack_trace(node.outputs[0], new_out) + return new_outputs + + @register_canonicalize @node_rewriter([Dot, _matmul]) def local_lift_transpose_through_dot(fgraph, node): diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 1f0662b371..60fed3e033 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -5366,3 +5366,150 @@ def test_depth_2_unknown_shapes_decomposes(self): fn = pytensor.function([a, b, c, d, x], out, mode=rewrite_mode) assert self._n_dots(fn) == 4 + def test_depth_2_block_at_block(self): + # ``pt.block(X) @ pt.block(Y)`` cascades through the rewrite. With + # split-of-join lifting, every leaf product is exposed: 4 result + # entries x 2 inner products = 8 leaf-level dots. + a = pt.tensor("a", shape=(3, 3)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + X = pt.block([[a, b], [c, d]]) + fn = pytensor.function([a, b, c, d], X @ X, mode=rewrite_mode) + assert self._n_dots(fn) == 8 + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 3)) + b_v = rng.standard_normal((3, 4)) + c_v = rng.standard_normal((4, 3)) + d_v = rng.standard_normal((4, 4)) + X_v = np.block([[a_v, b_v], [c_v, d_v]]) + np.testing.assert_allclose( + fn(a_v, b_v, c_v, d_v), X_v @ X_v, atol=1e-12, rtol=1e-12 + ) + + @pytest.mark.parametrize( + "lhs_t, rhs_t", + [(False, True), (True, False), (True, True)], + ids=["X@X.T", "X.T@X", "X.T@X.T"], + ) + def test_depth_2_block_at_block_with_transpose(self, lhs_t, rhs_t): + # Matrix-transposed nested-Joins get canonicalized via + # ``local_transpose_of_join`` (transpose pushes to leaves). Then the + # ``Block @ Block`` cascade applies. + a = pt.tensor("a", shape=(3, 3)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + X = pt.block([[a, b], [c, d]]) + lhs = X.mT if lhs_t else X + rhs = X.mT if rhs_t else X + fn = pytensor.function([a, b, c, d], lhs @ rhs, mode=rewrite_mode) + assert self._n_dots(fn) == 8 + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 3)) + b_v = rng.standard_normal((3, 4)) + c_v = rng.standard_normal((4, 3)) + d_v = rng.standard_normal((4, 4)) + X_v = np.block([[a_v, b_v], [c_v, d_v]]) + ref_lhs = X_v.T if lhs_t else X_v + ref_rhs = X_v.T if rhs_t else X_v + np.testing.assert_allclose( + fn(a_v, b_v, c_v, d_v), + ref_lhs @ ref_rhs, + atol=1e-12, + rtol=1e-12, + ) + + +class TestSplitOfJoin: + @staticmethod + def _has_split(fn): + from pytensor.tensor.basic import Split + + return any(isinstance(n.op, Split) for n in fn.maker.fgraph.toposort()) + + def test_matching_axis_matching_sizes_returns_inputs(self): + # Split(Join(axis=-2, A, B), [3, 4], axis=-2) -> A, B directly. + a = pt.tensor("a", shape=(3, 5)) + b = pt.tensor("b", shape=(4, 5)) + joined = pt.concatenate([a, b], axis=-2) + out_a, out_b = pt.split(joined, splits_size=[3, 4], n_splits=2, axis=-2) + + fn = pytensor.function([a, b], [out_a, out_b], mode=rewrite_mode) + assert not self._has_split(fn) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 5)) + b_v = rng.standard_normal((4, 5)) + ra, rb = fn(a_v, b_v) + np.testing.assert_array_equal(ra, a_v) + np.testing.assert_array_equal(rb, b_v) + + def test_matching_axis_mismatched_sizes_skips(self): + # Split sizes don't match Join input sizes: don't fire. + a = pt.tensor("a", shape=(3, 5)) + b = pt.tensor("b", shape=(4, 5)) + joined = pt.concatenate([a, b], axis=-2) + # 3+4=7 total, but split as [2, 5] -- boundaries don't align with [3, 4]. + out_a, out_b = pt.split(joined, splits_size=[2, 5], n_splits=2, axis=-2) + + fn = pytensor.function([a, b], [out_a, out_b], mode=rewrite_mode) + assert self._has_split(fn) + + def test_different_axis_distributes(self): + # Split(Join(axis=-2, A, B), [2, 3], axis=-1) distributes as + # Join(-2, Split(A, [2,3], -1)[k], Split(B, [2,3], -1)[k]) per k. + a = pt.tensor("a", shape=(3, 5)) + b = pt.tensor("b", shape=(4, 5)) + joined = pt.concatenate([a, b], axis=-2) # shape (7, 5) + c0, c1 = pt.split(joined, splits_size=[2, 3], n_splits=2, axis=-1) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 5)) + b_v = rng.standard_normal((4, 5)) + joined_v = np.concatenate([a_v, b_v], axis=-2) + ref0 = joined_v[:, :2] + ref1 = joined_v[:, 2:] + + fn = pytensor.function([a, b], [c0, c1], mode=rewrite_mode) + r0, r1 = fn(a_v, b_v) + np.testing.assert_allclose(r0, ref0) + np.testing.assert_allclose(r1, ref1) + + def test_x_at_s_at_xT_no_intermediate_materialization(self): + # X @ S @ X.T with X = Block. The (M, N) intermediate of X @ S no + # longer materializes -- no Split feeds from a Join (which would mean + # we just rebuilt and re-split a Block-shaped tensor). + a = pt.tensor("a", shape=(3, 3)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + S = pt.tensor("S", shape=(7, 7)) + X = pt.block([[a, b], [c, d]]) + fn = pytensor.function([a, b, c, d, S], X @ S @ X.mT, mode=rewrite_mode) + + from pytensor.tensor.basic import Join, Split + + for n in fn.maker.fgraph.toposort(): + if isinstance(n.op, Split): + src = n.inputs[0] + assert src.owner is None or not isinstance(src.owner.op, Join), ( + f"Split feeds from a Join -- intermediate concat not " + f"decomposed: {n}" + ) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 3)) + b_v = rng.standard_normal((3, 4)) + c_v = rng.standard_normal((4, 3)) + d_v = rng.standard_normal((4, 4)) + S_v = rng.standard_normal((7, 7)) + X_v = np.block([[a_v, b_v], [c_v, d_v]]) + np.testing.assert_allclose( + fn(a_v, b_v, c_v, d_v, S_v), + X_v @ S_v @ X_v.T, + atol=1e-12, + rtol=1e-12, + ) From 4f9462f2db794e40e3519cf7f250dd6434769fb6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 3 May 2026 22:36:34 -0500 Subject: [PATCH 06/10] Add 2x2 nested-Join helpers to linalg.rewriting.utils Shared by upcoming block-triangular solve / det rewrites: is_static_zero predicate and match_2x2_nested_join structural matcher. --- pytensor/tensor/rewriting/linalg/utils.py | 80 +++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg/utils.py b/pytensor/tensor/rewriting/linalg/utils.py index 7b5645250a..471938f4ea 100644 --- a/pytensor/tensor/rewriting/linalg/utils.py +++ b/pytensor/tensor/rewriting/linalg/utils.py @@ -14,11 +14,14 @@ from pytensor.scalar.basic import Mul from pytensor.tensor.basic import ( Eye, + Join, TensorVariable, atleast_Nd, diagonal, + get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.linalg.inverse import MatrixInverse, MatrixPinv from pytensor.tensor.math import variadic_mul from pytensor.tensor.rewriting.basic import ( @@ -43,6 +46,83 @@ } +def match_2x2_nested_join(var): + """Match ``Join(-2, Join(-1, A_11, A_12), Join(-1, A_21, A_22))`` — a 2x2 + block-matrix-shaped concat. + + Returns ``[[A_11, A_12], [A_21, A_22]]`` when: + + - The outer ``Join`` concatenates along the row axis (``ndim - 2``). + - Both inner ``Join`` ops concatenate along the column axis (``ndim - 1``). + - The grid is uniform 2x2. + - All four leaves' relevant dims are statically known and the diagonal + blocks are square; row heights and column widths line up. + + Else returns ``None``. + """ + if var.owner is None or not isinstance(var.owner.op, Join): + return None + + out_ndim = var.type.ndim + if out_ndim < 2: + return None + + try: + outer_axis = int( + get_underlying_scalar_constant_value( + var.owner.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + if outer_axis < 0: + outer_axis += out_ndim + if outer_axis != out_ndim - 2: + return None + + rows = var.owner.inputs[1:] + if len(rows) != 2: + return None + + leaves = [] + for row in rows: + if row.owner is None or not isinstance(row.owner.op, Join): + return None + try: + inner_axis = int( + get_underlying_scalar_constant_value( + row.owner.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + if inner_axis < 0: + inner_axis += row.type.ndim + if inner_axis != row.type.ndim - 1: + return None + row_leaves = list(row.owner.inputs[1:]) + if len(row_leaves) != 2: + return None + leaves.append(row_leaves) + + [[A_11, A_12], [A_21, A_22]] = leaves + + m1 = A_11.type.shape[-2] + m2 = A_22.type.shape[-2] + n1 = A_11.type.shape[-1] + n2 = A_22.type.shape[-1] + if any(s is None for s in (m1, m2, n1, n2)): + return None + if m1 != n1 or m2 != n2: + return None # diagonal blocks not square + if A_12.type.shape[-2] != m1 or A_12.type.shape[-1] != n2: + return None + if A_21.type.shape[-2] != m2 or A_21.type.shape[-1] != n1: + return None + + return leaves + + def matrix_diagonal_product(x): return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1) From 2db6c10da63be9471157b4fa17730876932b6eeb Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 19:59:07 -0400 Subject: [PATCH 07/10] Tighten docstrings on Join-rewrite chain --- pytensor/tensor/rewriting/linalg/utils.py | 15 ++------ pytensor/tensor/rewriting/math.py | 46 +++++++---------------- 2 files changed, 18 insertions(+), 43 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg/utils.py b/pytensor/tensor/rewriting/linalg/utils.py index 471938f4ea..09e07182a3 100644 --- a/pytensor/tensor/rewriting/linalg/utils.py +++ b/pytensor/tensor/rewriting/linalg/utils.py @@ -47,18 +47,11 @@ def match_2x2_nested_join(var): - """Match ``Join(-2, Join(-1, A_11, A_12), Join(-1, A_21, A_22))`` — a 2x2 - block-matrix-shaped concat. + """Return ``[[A_11, A_12], [A_21, A_22]]`` if ``var`` is a 2x2 nested ``Join``, else ``None``. - Returns ``[[A_11, A_12], [A_21, A_22]]`` when: - - - The outer ``Join`` concatenates along the row axis (``ndim - 2``). - - Both inner ``Join`` ops concatenate along the column axis (``ndim - 1``). - - The grid is uniform 2x2. - - All four leaves' relevant dims are statically known and the diagonal - blocks are square; row heights and column widths line up. - - Else returns ``None``. + Requires the outer ``Join`` along ``ndim - 2``, both inner ``Join`` ops along + ``ndim - 1``, statically-known row heights and column widths that line up, and + square diagonal blocks. """ if var.owner is None or not isinstance(var.owner.op, Join): return None diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 7bc83e1d0c..e43be2e50f 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -328,9 +328,10 @@ def _walk_to_matmul(var): @register_stabilize @node_rewriter([DimShuffle]) def local_transpose_of_join(fgraph, node): - r"""Rewrite Join(axis, *inputs).mT to Join(axis, *[inp.mT for inp in inputs]) + r"""Rewrite ``Join(axis, *xs).mT`` to ``Join(swapped_axis, *[x.mT for x in xs])``. - Swap axis=-1 <-> axis=-2 when axis is one of the matrix axes and leave batch axes unchanged. + Swap ``axis=-1`` and ``axis=-2`` when axis is one of the matrix axes; leave batch + axes unchanged. """ if not node.op.is_matrix_transpose: return None @@ -357,10 +358,9 @@ def local_transpose_of_join(fgraph, node): elif join_axis == src_ndim - 2: new_axis = src_ndim - 1 else: - new_axis = join_axis # batch axis, mT doesn't touch it + new_axis = join_axis # batch axis -- mT doesn't touch it - transposed_inputs = [inp.mT for inp in src.owner.inputs[1:]] - new_out = join(new_axis, *transposed_inputs) + new_out = join(new_axis, *[inp.mT for inp in src.owner.inputs[1:]]) copy_stack_trace(node.outputs[0], new_out) return [new_out] @@ -369,14 +369,10 @@ def local_transpose_of_join(fgraph, node): @register_stabilize @node_rewriter([Join]) def local_nested_join_to_block_diagonal(fgraph, node): - r"""Recognize a square 2-D block-matrix-shaped concatenation whose - off-diagonal entries are statically zero, and rewrite to :func:`block_diag`. - - Detects ``Join(axis=-2, *Join(axis=-1, ...))`` -- an outer row-concat whose - every input is itself a column-concat -- with uniform structure (n x n - square grid) and statically-zero off-diagonal leaves. Replaces with - ``BlockDiagonal`` to unlock its targeted rewrites (det, diag, trace, dot, - solve pushdowns). + r"""Rewrite a square ``n x n`` nested ``Join`` with zero off-diagonals to :func:`block_diag`. + + Matches ``Join(-2, *Join(-1, ...))`` -- an outer row-concat whose every input is a + column-concat -- forming a square grid with statically-zero off-diagonal leaves. """ out_ndim = node.outputs[0].type.ndim if out_ndim < 2: @@ -451,10 +447,7 @@ def local_nested_join_to_block_diagonal(fgraph, node): def _const_int_vector(var): - """Extract a Python ``list[int]`` from a vector :class:`Variable` if its - contents are statically known. Handles ``Constant`` and ``MakeVector`` of - scalar constants. Returns ``None`` otherwise. - """ + """Return a ``list`` of ints from a 1-D ``var`` whose entries are statically known, else ``None``.""" if isinstance(var, Constant): try: arr = np.asarray(var.data) @@ -480,21 +473,10 @@ def _const_int_vector(var): def local_split_of_join(fgraph, node): r"""Push :class:`Split` through :class:`Join`. - Two cases are handled: - - - **Same axis, matching sizes.** ``Split(Join(a, X_0, ..., X_k), - [s_0, ..., s_k], axis=a)`` with ``s_i == X_i.shape[a]`` returns the - ``Join``'s inputs directly. The split exactly undoes the concat. - - - **Different axis.** ``Split(Join(a, X_0, ..., X_k), [s_0, ...], axis=b)`` - with ``a != b`` distributes the split through the join: each cut output - becomes ``Join(a, *[Split(X_i, axis=b)[k] for i])``. Slicing along an - orthogonal axis commutes with concatenation. - - Together these unblock the cascades that show up after dot-of-Join - decomposition (e.g. ``Block @ Block``, ``X @ S @ X.T``): the resulting - ``Split(Join(...))`` patterns collapse to per-leaf operations instead of - materializing the assembled intermediate. + Same axis with matching sizes: ``Split(Join(a, *X), [|X_i|_a], axis=a)`` returns the + join's inputs directly. Different axis: ``Split(Join(a, *X), s, axis=b)`` distributes + to ``Join(a, *[Split(X_i, s, b)[k]])`` per cut -- slicing an orthogonal axis commutes + with concatenation. """ x, axis_var, splits_size_var = node.inputs if x.owner is None or not isinstance(x.owner.op, Join): From 0f98a4fd0f0a12172ed2900383c0185b881dcbcc Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 20:36:01 -0400 Subject: [PATCH 08/10] Strengthen TestNestedJoinToBlockDiagonal positive case --- tests/tensor/rewriting/test_math.py | 34 ++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 60fed3e033..fb00e8ce10 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -156,6 +156,9 @@ rewrite_mode = "FAST_RUN" rewrite_mode = get_mode(rewrite_mode) +# Paired with ``rewrite_mode`` in rewrite tests as the un-rewritten reference. +no_opt_mode = Mode(linker="py", optimizer=None) + dimshuffle_lift = out2in(local_dimshuffle_lift) _stabilize_rewrites = RewriteDatabaseQuery(include=["fast_run"]) @@ -5122,47 +5125,44 @@ class TestNestedJoinToBlockDiagonal: @staticmethod def _has_block_diagonal(fn): return any( - isinstance(n.op, Blockwise) and isinstance(n.op.core_op, BlockDiagonal) + isinstance(n.op, BlockDiagonal) + or (isinstance(n.op, Blockwise) and isinstance(n.op.core_op, BlockDiagonal)) for n in fn.maker.fgraph.toposort() ) def test_zeros_off_diagonal_canonicalizes(self): - # Square nested-Join with statically-zero off-diagonals -> BlockDiagonal. a = pt.tensor("a", shape=(3, 3)) d = pt.tensor("d", shape=(4, 4)) M = pt.block([[a, pt.zeros((3, 4))], [pt.zeros((4, 3)), d]]) + ref_fn = pytensor.function([a, d], M, mode=no_opt_mode) + rewr_fn = pytensor.function([a, d], M, mode=rewrite_mode) + assert self._has_block_diagonal(rewr_fn) + rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 3)) - d_v = rng.standard_normal((4, 4)) - fn = pytensor.function([a, d], M, mode=rewrite_mode) + values = [rng.standard_normal(s) for s in ((3, 3), (4, 4))] np.testing.assert_allclose( - fn(a_v, d_v), - np.block([[a_v, np.zeros((3, 4))], [np.zeros((4, 3)), d_v]]), - atol=1e-12, - rtol=1e-12, + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) def test_nonzero_off_diagonal_skips(self): - # Off-diagonal isn't statically zero -> don't canonicalize. a = pt.tensor("a", shape=(3, 3)) b = pt.tensor("b", shape=(3, 4)) c = pt.tensor("c", shape=(4, 3)) d = pt.tensor("d", shape=(4, 4)) - M = pt.block([[a, b], [c, d]]) - - fn = pytensor.function([a, b, c, d], M, mode=rewrite_mode) + fn = pytensor.function( + [a, b, c, d], pt.block([[a, b], [c, d]]), mode=rewrite_mode + ) assert not self._has_block_diagonal(fn) def test_non_square_skips(self): - # 2x3 grid (non-square) -> not a candidate. a = pt.tensor("a", shape=(3, 4)) b = pt.tensor("b", shape=(3, 4)) c = pt.tensor("c", shape=(3, 4)) z = pt.zeros((3, 4)) - M = pt.block([[a, z, z], [b, z, c]]) - - fn = pytensor.function([a, b, c], M, mode=rewrite_mode) + fn = pytensor.function( + [a, b, c], pt.block([[a, z, z], [b, z, c]]), mode=rewrite_mode + ) assert not self._has_block_diagonal(fn) From 5c9be42384d1a3207a78c3ec68d8e0bc1a90afaa Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 20:39:22 -0400 Subject: [PATCH 09/10] Make test_x_at_s_at_xT_no_intermediate_materialization non-vacuous --- tests/tensor/rewriting/test_math.py | 67 +++++++++++------------------ 1 file changed, 26 insertions(+), 41 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index fb00e8ce10..a4cecb58bb 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -31,7 +31,7 @@ from pytensor.graph.traversal import ancestors from pytensor.printing import debugprint, pprint from pytensor.scalar import PolyGamma, Psi, TriGamma -from pytensor.tensor.basic import Alloc, constant, join, second, switch +from pytensor.tensor.basic import Alloc, Join, Split, constant, join, second, switch from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blockwise import Blockwise @@ -5448,68 +5448,53 @@ def test_matching_axis_matching_sizes_returns_inputs(self): np.testing.assert_array_equal(rb, b_v) def test_matching_axis_mismatched_sizes_skips(self): - # Split sizes don't match Join input sizes: don't fire. + # Split sizes [2, 5] don't align with the Join's [3, 4] boundaries. a = pt.tensor("a", shape=(3, 5)) b = pt.tensor("b", shape=(4, 5)) - joined = pt.concatenate([a, b], axis=-2) - # 3+4=7 total, but split as [2, 5] -- boundaries don't align with [3, 4]. - out_a, out_b = pt.split(joined, splits_size=[2, 5], n_splits=2, axis=-2) + out_a, out_b = pt.split( + pt.concatenate([a, b], axis=-2), splits_size=[2, 5], n_splits=2, axis=-2 + ) fn = pytensor.function([a, b], [out_a, out_b], mode=rewrite_mode) assert self._has_split(fn) def test_different_axis_distributes(self): - # Split(Join(axis=-2, A, B), [2, 3], axis=-1) distributes as - # Join(-2, Split(A, [2,3], -1)[k], Split(B, [2,3], -1)[k]) per k. a = pt.tensor("a", shape=(3, 5)) b = pt.tensor("b", shape=(4, 5)) - joined = pt.concatenate([a, b], axis=-2) # shape (7, 5) - c0, c1 = pt.split(joined, splits_size=[2, 3], n_splits=2, axis=-1) + c0, c1 = pt.split( + pt.concatenate([a, b], axis=-2), splits_size=[2, 3], n_splits=2, axis=-1 + ) + + ref_fn = pytensor.function([a, b], [c0, c1], mode=no_opt_mode) + rewr_fn = pytensor.function([a, b], [c0, c1], mode=rewrite_mode) rng = np.random.default_rng(0) a_v = rng.standard_normal((3, 5)) b_v = rng.standard_normal((4, 5)) - joined_v = np.concatenate([a_v, b_v], axis=-2) - ref0 = joined_v[:, :2] - ref1 = joined_v[:, 2:] - - fn = pytensor.function([a, b], [c0, c1], mode=rewrite_mode) - r0, r1 = fn(a_v, b_v) - np.testing.assert_allclose(r0, ref0) - np.testing.assert_allclose(r1, ref1) + for r, e in zip(rewr_fn(a_v, b_v), ref_fn(a_v, b_v), strict=True): + np.testing.assert_allclose(r, e) def test_x_at_s_at_xT_no_intermediate_materialization(self): - # X @ S @ X.T with X = Block. The (M, N) intermediate of X @ S no - # longer materializes -- no Split feeds from a Join (which would mean - # we just rebuilt and re-split a Block-shaped tensor). + # Block-triangular X gives ``local_dot_of_join`` something to fire on, producing + # Splits of S. Invariant: none of them feed from a Join (otherwise + # ``local_split_of_join`` failed to collapse a rebuilt intermediate). a = pt.tensor("a", shape=(3, 3)) - b = pt.tensor("b", shape=(3, 4)) c = pt.tensor("c", shape=(4, 3)) d = pt.tensor("d", shape=(4, 4)) S = pt.tensor("S", shape=(7, 7)) - X = pt.block([[a, b], [c, d]]) - fn = pytensor.function([a, b, c, d, S], X @ S @ X.mT, mode=rewrite_mode) + X = pt.block([[a, pt.zeros((3, 4))], [c, d]]) - from pytensor.tensor.basic import Join, Split + ref_fn = pytensor.function([a, c, d, S], X @ S @ X.mT, mode=no_opt_mode) + rewr_fn = pytensor.function([a, c, d, S], X @ S @ X.mT, mode=rewrite_mode) - for n in fn.maker.fgraph.toposort(): - if isinstance(n.op, Split): - src = n.inputs[0] - assert src.owner is None or not isinstance(src.owner.op, Join), ( - f"Split feeds from a Join -- intermediate concat not " - f"decomposed: {n}" - ) + splits = [n for n in rewr_fn.maker.fgraph.toposort() if isinstance(n.op, Split)] + assert splits, "expected the rewrite chain to introduce at least one Split" + for n in splits: + src = n.inputs[0] + assert src.owner is None or not isinstance(src.owner.op, Join) rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 3)) - b_v = rng.standard_normal((3, 4)) - c_v = rng.standard_normal((4, 3)) - d_v = rng.standard_normal((4, 4)) - S_v = rng.standard_normal((7, 7)) - X_v = np.block([[a_v, b_v], [c_v, d_v]]) + values = [rng.standard_normal(s) for s in ((3, 3), (4, 3), (4, 4), (7, 7))] np.testing.assert_allclose( - fn(a_v, b_v, c_v, d_v, S_v), - X_v @ S_v @ X_v.T, - atol=1e-12, - rtol=1e-12, + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) From 0ca81918ff57afdd1330bb61185cc250756938b6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 26 May 2026 20:41:27 -0400 Subject: [PATCH 10/10] Gate local_dot_of_join on structured leaves --- pytensor/tensor/rewriting/math.py | 248 ++++++++++++++---- tests/tensor/rewriting/test_math.py | 392 ++++++++++++++-------------- 2 files changed, 396 insertions(+), 244 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e43be2e50f..18ef586fbc 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -9,7 +9,7 @@ import pytensor.scalar.basic as ps import pytensor.scalar.math as ps_math -from pytensor.assumptions import DIAGONAL, check_assumption +from pytensor.assumptions import DIAGONAL, SELECTION, check_assumption from pytensor.graph.basic import Constant, Variable from pytensor.graph.rewriting.basic import ( NodeRewriter, @@ -22,6 +22,7 @@ from pytensor.graph.rewriting.utils import get_clients_at_depth from pytensor.tensor.basic import ( Alloc, + Eye, Join, MakeVector, Split, @@ -234,47 +235,200 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node): return {client.outputs[0]: new_output} -@register_stabilize -@node_rewriter([Join]) -def local_dot_of_join(fgraph, node): - r"""Push ``dot`` inside a :class:`Join`, decomposing the matmul into per-leaf products. +def _join_matmul_axis(var): + """Return ``-1`` or ``-2`` if ``var`` is a :class:`Join` along a matmul axis, else ``None``.""" + owner = var.owner + if owner is None or not isinstance(owner.op, Join): + return None + try: + axis = int( + get_underlying_scalar_constant_value( + owner.inputs[0], raise_not_constant=True + ) + ) + except NotScalarConstantError: + return None + ndim = var.type.ndim + if axis < 0: + axis += ndim + if axis == ndim - 1: + return -1 + if axis == ndim - 2: + return -2 + return None - When ``Join`` runs along the matmul-contracted axis, ``Y`` is split by symbolic per-leaf sizes and - the per-leaf products are summed. Otherwise each leaf multiplies ``Y`` directly and the results are concatenated. - Walks through chains of left-``expand_dims`` ``DimShuffle`` nodes between the Join and the matmul - (Blockwise stacks pads this way). - """ +def _is_static_zero(var): + """``True`` if ``var`` folds to scalar zero (sees through ``Alloc`` and ``DimShuffle``).""" try: - join_axis = int( + return ( get_underlying_scalar_constant_value( - node.inputs[0], raise_not_constant=True + var, only_process_constants=False, raise_not_constant=True ) + == 0 ) except NotScalarConstantError: - return None + return False - out_ndim = node.outputs[0].type.ndim - if join_axis < 0: - join_axis += out_ndim - # Only the last two axes participate in matmul; other axes are batch. - if join_axis not in (out_ndim - 1, out_ndim - 2): + +def _is_static_identity(var): + """``True`` if ``var`` is statically a square identity ``I_n``.""" + if var.type.ndim != 2: + return False + owner = var.owner + if owner is not None and isinstance(owner.op, Eye): + n, m, k = owner.inputs + try: + k_val = int( + get_underlying_scalar_constant_value(k, raise_not_constant=True) + ) + except NotScalarConstantError: + return False + if k_val != 0: + return False + try: + n_val = int( + get_underlying_scalar_constant_value(n, raise_not_constant=True) + ) + m_val = int( + get_underlying_scalar_constant_value(m, raise_not_constant=True) + ) + except NotScalarConstantError: + return n is m + return n_val == m_val + if isinstance(var, Constant): + arr = np.asarray(var.data) + return ( + arr.ndim == 2 + and arr.shape[0] == arr.shape[1] + and np.array_equal(arr, np.eye(arr.shape[0], dtype=arr.dtype)) + ) + return False + + +def _block_kind(fgraph, var): + """Classify a matmul block as ``"zero"``, ``"identity"``, ``"selection"``, or ``"dense"``.""" + # Identity short-circuits SELECTION so ``I @ x`` folds to ``x`` instead of going + # through ``selection_dot_to_indexing``'s scatter. + if _is_static_zero(var): + return "zero" + if _is_static_identity(var): + return "identity" + if var.type.ndim >= 2 and check_assumption(fgraph, var, SELECTION): + return "selection" + return "dense" + + +def _decomposition_saves(fgraph, var): + """``True`` if some leaf in the nested-join tree at ``var`` is non-dense.""" + kind = _block_kind(fgraph, var) + if kind in ("zero", "identity", "selection"): + return True + if _join_matmul_axis(var) is not None: + return any(_decomposition_saves(fgraph, leaf) for leaf in var.owner.inputs[1:]) + return False + + +def _decompose_contracted(fgraph, leaves, other, dot_op, join_is_left): + """Decompose ``dot`` over a contracted-axis ``Join`` by splitting ``other`` per leaf.""" + size_axis = -1 if join_is_left else -2 + sizes = stack([leaf.shape[size_axis] for leaf in leaves]) + chunks = split( + other, + splits_size=sizes, + n_splits=len(leaves), + axis=-2 if join_is_left else -1, + ) + + def _sided(left, right): + return dot_op(left, right) if join_is_left else dot_op(right, left) + + terms = [] + dense_leaves: list = [] + dense_chunks: list = [] + for leaf, chunk in zip(leaves, chunks, strict=True): + match _block_kind(fgraph, leaf): + case "zero": + continue + case "identity": + terms.append(chunk) + case "selection": + terms.append(_sided(leaf, chunk)) + case _: + dense_leaves.append(leaf) + dense_chunks.append(chunk) + + if dense_leaves: + if len(dense_leaves) == 1: + terms.append(_sided(dense_leaves[0], dense_chunks[0])) + else: + # Re-join survivors into one smaller GEMM; gate guarantees at least one + # leaf was dropped, so this is strictly smaller than the parent. + d_leaf = join(-1 if join_is_left else -2, *dense_leaves) + d_chunk = join(-2 if join_is_left else -1, *dense_chunks) + terms.append(_sided(d_leaf, d_chunk)) + + if not terms: + return zeros_like(_sided(leaves[0], chunks[0])) + if len(terms) == 1: + return terms[0] + return add(*terms) + + +def _decompose_output( + fgraph, leaves, other, dot_op, join_is_left, concat_axis, out_dtype +): + """Decompose ``dot`` over an output-axis ``Join`` by emitting one block per leaf.""" + + def _sided(left, right): + return dot_op(left, right) if join_is_left else dot_op(right, left) + + blocks = [] + for leaf in leaves: + match _block_kind(fgraph, leaf): + case "zero": + blocks.append(zeros_like(_sided(leaf, other))) + case "identity": + blocks.append( + other if other.type.dtype == out_dtype else other.astype(out_dtype) + ) + case _: + blocks.append(_sided(leaf, other)) + # Dense leaves stay separate; re-joining them would reconstruct the parent ``Join``. + return concat_with_broadcast(blocks, axis=concat_axis) + + +@register_stabilize +@node_rewriter([Join]) +def local_dot_of_join(fgraph, node): + r"""Push ``dot`` inside a :class:`Join`, gated on the presence of structured leaves. + + Fires per matmul client only when some leaf of the (possibly nested) ``Join`` is a + static zero, the identity, or a :data:`SELECTION` matrix -- the categories where the + leaf's product reduces below a GEMM (drop, fold to chunk, or downstream + ``selection_dot_to_indexing`` collapses to indexing). All-dense block matmuls keep + their single GEMM. + + Along the contracted axis surviving dense leaves are re-joined into one smaller GEMM; + along an output axis blocks stay separate. Walks through left-``expand_dims`` + ``DimShuffle`` chains between the Join and the matmul (``Blockwise`` pads this way). + """ + matmul_axis = _join_matmul_axis(node.outputs[0]) + if matmul_axis is None: return None - leaves = node.inputs[1:] + leaves = list(node.inputs[1:]) if len(leaves) < 2: return None join_out = node.outputs[0] - # Translate Join's axis (in its own ndim) to a "matmul axis" tag: - # join_matmul_axis = -1 -> Join concatenates along the inner mat axis - # join_matmul_axis = -2 -> Join concatenates along the outer mat axis - join_matmul_axis = join_axis - out_ndim # -1 or -2 + if not _decomposition_saves(fgraph, join_out): + return None def _walk_to_matmul(var): - """Yield ``(matmul_node, input_idx)`` for every Dot/_matmul reachable - from ``var`` through a chain of left-expand-dims DimShuffles.""" for client, input_idx in fgraph.clients[var]: + if isinstance(client, str): + continue if client.op in (_dot, _matmul): yield client, input_idx elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims: @@ -282,41 +436,27 @@ def _walk_to_matmul(var): replacements: dict = {} for client, client_idx in _walk_to_matmul(join_out): - if client.outputs[0] in replacements: - # ``dot(J, J)`` reaches the same matmul via both inputs, and - # ``ds(ds(J))`` (chained DimShuffles) reaches it via multiple - # paths. Either way: decompose once, let the next pass handle - # any side still wrapping the (now-fewer-clients) Join. + old_out = client.outputs[0] + if old_out in replacements: + # ``dot(J, J)`` and chained DimShuffles can reach the same matmul twice. continue + join_is_left = client_idx == 0 other = client.inputs[1 - client_idx] dot_op = client.op - old_out = client.outputs[0] + out_dtype = old_out.type.dtype - if client_idx == 0: - # Join @ other - if join_matmul_axis == -1: - widths = stack([leaf.shape[-1] for leaf in leaves]) - other_chunks = split(other, splits_size=widths, axis=-2) - terms = [ - dot_op(leaf, chunk) for leaf, chunk in zip(leaves, other_chunks) - ] - new_output = add(*terms) - else: - terms = [dot_op(leaf, other) for leaf in leaves] - new_output = concat_with_broadcast(terms, axis=-2) + contracted = (join_is_left and matmul_axis == -1) or ( + not join_is_left and matmul_axis == -2 + ) + if contracted: + new_output = _decompose_contracted( + fgraph, leaves, other, dot_op, join_is_left + ) else: - # other @ Join - if join_matmul_axis == -1: - terms = [dot_op(other, leaf) for leaf in leaves] - new_output = concat_with_broadcast(terms, axis=-1) - else: - heights = stack([leaf.shape[-2] for leaf in leaves]) - other_chunks = split(other, splits_size=heights, axis=-1) - terms = [ - dot_op(chunk, leaf) for chunk, leaf in zip(other_chunks, leaves) - ] - new_output = add(*terms) + new_output = _decompose_output( + fgraph, leaves, other, dot_op, join_is_left, matmul_axis, out_dtype + ) copy_stack_trace(old_out, new_output) replacements[old_out] = new_output diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index a4cecb58bb..2739c6213f 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -32,7 +32,7 @@ from pytensor.printing import debugprint, pprint from pytensor.scalar import PolyGamma, Psi, TriGamma from pytensor.tensor.basic import Alloc, Join, Split, constant, join, second, switch -from pytensor.tensor.blas import Dot22, Gemv +from pytensor.tensor.blas import BatchedDot, Dot22, Dot22Scalar, Gemm, Gemv, Ger from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -117,6 +117,7 @@ simplify_mul, ) from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape +from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1 from pytensor.tensor.type import ( TensorType, cmatrix, @@ -5166,276 +5167,287 @@ def test_non_square_skips(self): assert not self._has_block_diagonal(fn) +_DOT_OPS = (Dot, Dot22, Dot22Scalar, Gemm, Gemv, Ger, CGemv, BatchedDot) + + class TestDotOfJoin: @staticmethod def _n_dots(fn): + # CVM mode fuses ``dot + add`` into Gemm and lowers Dot22 to Dot22Scalar / + # Gemv / Ger; count every matmul-like Op (and its Blockwise wrappings) so the + # assertion holds across linkers. return sum( - isinstance(n.op, Dot | Dot22) - or (isinstance(n.op, Blockwise) and isinstance(n.op.core_op, Dot | Dot22)) + isinstance(n.op, _DOT_OPS) + or (isinstance(n.op, Blockwise) and isinstance(n.op.core_op, _DOT_OPS)) for n in fn.maker.fgraph.toposort() ) @staticmethod - def _join_consumes_originals(fn, originals): - """Whether any Join in the graph still has any of ``originals`` as inputs. - - Used to confirm the original Join was decomposed (not just shuffled).""" - original_set = set(originals) - for n in fn.maker.fgraph.toposort(): - from pytensor.tensor.basic import Join - - if isinstance(n.op, Join) and any( - inp in original_set for inp in n.inputs[1:] - ): - return True - return False - - def test_join_lhs_axis_neg1(self): - # [A | B] @ y -> A @ y[:n_a] + B @ y[n_a:] + def _has_op(fn, op_type): + return any(isinstance(n.op, op_type) for n in fn.maker.fgraph.toposort()) + + @pytest.mark.parametrize( + "side, axis, a_shape, b_shape, y_shape", + [ + ("lhs", -1, (3, 4), (3, 5), (9, 6)), + ("lhs", -2, (3, 4), (2, 4), (4, 6)), + ("rhs", -1, (4, 3), (4, 5), (7, 4)), + ("rhs", -2, (3, 5), (2, 5), (7, 5)), + ], + ids=["lhs_axis-1", "lhs_axis-2", "rhs_axis-1", "rhs_axis-2"], + ) + def test_all_dense_preserved(self, side, axis, a_shape, b_shape, y_shape): + a = pt.tensor("a", shape=a_shape) + b = pt.tensor("b", shape=b_shape) + y = pt.tensor("y", shape=y_shape) + M = pt.concatenate([a, b], axis=axis) + out = M @ y if side == "lhs" else y @ M + + ref_fn = pytensor.function([a, b, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, y], out, mode=rewrite_mode) + + # Join survival is the proof the rewrite chose not to fire. + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in (a_shape, b_shape, y_shape)] + np.testing.assert_allclose(rewr_fn(*values), ref_fn(*values), atol=1e-12) + + def test_unknown_shapes_still_preserved(self): + a = pt.matrix("a") + b = pt.matrix("b") + y = pt.matrix("y") + out = pt.concatenate([a, b], axis=-1) @ y + + ref_fn = pytensor.function([a, b, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((3, 4), (3, 5), (9, 6))] + np.testing.assert_allclose(rewr_fn(*values), ref_fn(*values), atol=1e-12) + + def test_single_input_join_skipped(self): a = pt.tensor("a", shape=(3, 4)) - b = pt.tensor("b", shape=(3, 5)) - y = pt.tensor("y", shape=(9, 6)) - M = pt.concatenate([a, b], axis=-1) - fn = pytensor.function([a, b, y], M @ y, mode=rewrite_mode) - assert self._n_dots(fn) == 2 - assert not self._join_consumes_originals(fn, [a, b]) + y = pt.tensor("y", shape=(4, 5)) + out = pt.join(-1, a) @ y + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 rng = np.random.default_rng(0) a_v = rng.standard_normal((3, 4)) - b_v = rng.standard_normal((3, 5)) - y_v = rng.standard_normal((9, 6)) + y_v = rng.standard_normal((4, 5)) + np.testing.assert_allclose(rewr_fn(a_v, y_v), ref_fn(a_v, y_v), atol=1e-12) + + @pytest.mark.parametrize( + "side, axis, a_shape, zero_shape, y_shape", + [ + ("lhs", -1, (3, 4), (3, 5), (9, 6)), + ("rhs", -2, (3, 5), (2, 5), (7, 5)), + ], + ids=["lhs_axis-1", "rhs_axis-2"], + ) + def test_zero_leaf_contracted_drops_dot( + self, side, axis, a_shape, zero_shape, y_shape + ): + a = pt.tensor("a", shape=a_shape) + y = pt.tensor("y", shape=y_shape) + M = pt.concatenate([a, pt.zeros(zero_shape)], axis=axis) + out = M @ y if side == "lhs" else y @ M + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + + # Split + no Join is the signature of contracted-axis decomposition. + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Split) + assert not self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal(a_shape) + y_v = rng.standard_normal(y_shape) np.testing.assert_allclose( - fn(a_v, b_v, y_v), - np.concatenate([a_v, b_v], axis=-1) @ y_v, - atol=1e-12, + rewr_fn(a_v, y_v), ref_fn(a_v, y_v), atol=1e-12, rtol=1e-12 ) - def test_join_lhs_axis_neg2(self): - # [[A], [B]] @ y -> concat([A @ y, B @ y], -2) - a = pt.tensor("a", shape=(3, 4)) - b = pt.tensor("b", shape=(2, 4)) - y = pt.tensor("y", shape=(4, 6)) - M = pt.concatenate([a, b], axis=-2) - fn = pytensor.function([a, b, y], M @ y, mode=rewrite_mode) - assert self._n_dots(fn) == 2 - assert not self._join_consumes_originals(fn, [a, b]) + @pytest.mark.parametrize( + "side, axis, a_shape, zero_shape, y_shape", + [ + ("lhs", -2, (3, 4), (2, 4), (4, 6)), + ("rhs", -1, (4, 3), (4, 5), (7, 4)), + ], + ids=["lhs_axis-2", "rhs_axis-1"], + ) + def test_zero_leaf_output_emits_zero_block( + self, side, axis, a_shape, zero_shape, y_shape + ): + # Output-axis case: rewrite emits ``Join(Dot(a, other), zero_block)``. A Join + # still appears, so only numerical equivalence pins the rewrite's effect. + a = pt.tensor("a", shape=a_shape) + y = pt.tensor("y", shape=y_shape) + M = pt.concatenate([a, pt.zeros(zero_shape)], axis=axis) + out = M @ y if side == "lhs" else y @ M + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 4)) - b_v = rng.standard_normal((2, 4)) - y_v = rng.standard_normal((4, 6)) + a_v = rng.standard_normal(a_shape) + y_v = rng.standard_normal(y_shape) np.testing.assert_allclose( - fn(a_v, b_v, y_v), - np.concatenate([a_v, b_v], axis=-2) @ y_v, - atol=1e-12, + rewr_fn(a_v, y_v), ref_fn(a_v, y_v), atol=1e-12, rtol=1e-12 ) - def test_join_rhs_axis_neg1(self): - # y @ [A | B] -> concat([y @ A, y @ B], -1) + def test_identity_leaf_contracted_folds_to_chunk(self): a = pt.tensor("a", shape=(4, 3)) - b = pt.tensor("b", shape=(4, 5)) - y = pt.tensor("y", shape=(7, 4)) - M = pt.concatenate([a, b], axis=-1) - fn = pytensor.function([a, b, y], y @ M, mode=rewrite_mode) - assert self._n_dots(fn) == 2 - assert not self._join_consumes_originals(fn, [a, b]) + y = pt.tensor("y", shape=(3 + 4, 6)) + out = pt.join(-1, a, pt.eye(4)) @ y + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 + assert not self._has_op(rewr_fn, Join) rng = np.random.default_rng(0) - a_v = rng.standard_normal((4, 3)) - b_v = rng.standard_normal((4, 5)) - y_v = rng.standard_normal((7, 4)) + values = [rng.standard_normal(s) for s in ((4, 3), (7, 6))] np.testing.assert_allclose( - fn(a_v, b_v, y_v), - y_v @ np.concatenate([a_v, b_v], axis=-1), - atol=1e-12, + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) - def test_join_rhs_axis_neg2(self): - # y @ [[A], [B]] -> y[..., :n_a] @ A + y[..., n_a:] @ B + def test_identity_leaf_output_replaces_with_operand(self): a = pt.tensor("a", shape=(3, 5)) - b = pt.tensor("b", shape=(2, 5)) - y = pt.tensor("y", shape=(7, 5)) - M = pt.concatenate([a, b], axis=-2) - fn = pytensor.function([a, b, y], y @ M, mode=rewrite_mode) - assert self._n_dots(fn) == 2 - assert not self._join_consumes_originals(fn, [a, b]) + y = pt.tensor("y", shape=(5, 6)) + out = pt.concatenate([a, pt.eye(5)], axis=-2) @ y + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 5)) - b_v = rng.standard_normal((2, 5)) - y_v = rng.standard_normal((7, 5)) + values = [rng.standard_normal(s) for s in ((3, 5), (5, 6))] np.testing.assert_allclose( - fn(a_v, b_v, y_v), - y_v @ np.concatenate([a_v, b_v], axis=-2), - atol=1e-12, + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) - def test_unknown_widths_lhs_decomposes(self): - # LHS axis=-1 splits y by symbolic leaf widths; fires on dynamic shapes. - a = pt.matrix("a") - b = pt.matrix("b") - y = pt.matrix("y") - M = pt.concatenate([a, b], axis=-1) - fn = pytensor.function([a, b, y], M @ y, mode=rewrite_mode) - assert self._n_dots(fn) == 2 - assert not self._join_consumes_originals(fn, [a, b]) + def test_selection_leaf_collapses_to_indexing(self): + # SELECTION leaf becomes its own Dot, then ``selection_dot_to_indexing`` turns + # it into an advanced index. ``y @ S`` hits the gather path. + b = pt.tensor("b", shape=(4, 5)) + perm = pt.constant(np.array([2, 0, 3, 1], dtype=np.int64)) + S = pt.eye(4)[:, perm] + y = pt.tensor("y", shape=(7, 4)) + out = y @ pt.concatenate([b, S], axis=-1) - rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 4)) - b_v = rng.standard_normal((3, 5)) - y_v = rng.standard_normal((9, 6)) - np.testing.assert_allclose( - fn(a_v, b_v, y_v), - np.concatenate([a_v, b_v], axis=-1) @ y_v, - atol=1e-12, - ) + ref_fn = pytensor.function([b, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([b, y], out, mode=rewrite_mode) - def test_unknown_heights_rhs_decomposes(self): - # RHS axis=-2 splits y by symbolic leaf heights; fires on dynamic shapes. - a = pt.matrix("a") - b = pt.matrix("b") - y = pt.matrix("y") - M = pt.concatenate([a, b], axis=-2) - fn = pytensor.function([a, b, y], y @ M, mode=rewrite_mode) - assert self._n_dots(fn) == 2 - assert not self._join_consumes_originals(fn, [a, b]) + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, AdvancedSubtensor | AdvancedSubtensor1) rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 5)) - b_v = rng.standard_normal((2, 5)) - y_v = rng.standard_normal((7, 5)) + values = [rng.standard_normal(s) for s in ((4, 5), (7, 4))] np.testing.assert_allclose( - fn(a_v, b_v, y_v), - y_v @ np.concatenate([a_v, b_v], axis=-2), - atol=1e-12, + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) - def test_single_input_join_skipped(self): - # A "Join" with one input is a no-op; skip. - a = pt.tensor("a", shape=(3, 4)) - y = pt.tensor("y", shape=(4, 5)) - M = pt.join(-1, a) # single-input join - fn = pytensor.function([a, y], M @ y, mode=rewrite_mode) - assert self._n_dots(fn) == 1 - - @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"]) - def test_depth_2_block_at_x(self, left_multiply): - # ``pt.block([[a, b], [c, d]])`` produces nested Joins. Iterated - # ``local_dot_of_join`` decomposes both levels into leaf-level dots. + def test_all_dense_block_preserved(self): a = pt.tensor("a", shape=(2, 3)) b = pt.tensor("b", shape=(2, 4)) c = pt.tensor("c", shape=(5, 3)) d = pt.tensor("d", shape=(5, 4)) - M = pt.block([[a, b], [c, d]]) + other = pt.tensor("other", shape=(7, 6)) + out = pt.block([[a, b], [c, d]]) @ other - if left_multiply: - other = pt.tensor("other", shape=(7, 6)) - out = M @ other - else: - other = pt.tensor("other", shape=(8, 7)) - out = other @ M + ref_fn = pytensor.function([a, b, c, d, other], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, c, d, other], out, mode=rewrite_mode) - fn = pytensor.function([a, b, c, d, other], out, mode=rewrite_mode) - # 4 leaf-level dots, regardless of side. - assert self._n_dots(fn) == 4 + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Join) rng = np.random.default_rng(0) - a_v = rng.standard_normal(a.type.shape) - b_v = rng.standard_normal(b.type.shape) - c_v = rng.standard_normal(c.type.shape) - d_v = rng.standard_normal(d.type.shape) - other_v = rng.standard_normal(other.type.shape) - ref_M = np.block([[a_v, b_v], [c_v, d_v]]) - expected = ref_M @ other_v if left_multiply else other_v @ ref_M + values = [rng.standard_normal(t.type.shape) for t in (a, b, c, d, other)] np.testing.assert_allclose( - fn(a_v, b_v, c_v, d_v, other_v), - expected, - atol=1e-12, - rtol=1e-12, + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) - def test_depth_2_unknown_shapes_decomposes(self): - # Symbolic split sizes: even with fully dynamic shapes, both Join - # levels decompose down to leaf-level dots. - a = pt.matrix("a") - b = pt.matrix("b") - c = pt.matrix("c") - d = pt.matrix("d") - x = pt.matrix("x") - out = pt.block([[a, b], [c, d]]) @ x - fn = pytensor.function([a, b, c, d, x], out, mode=rewrite_mode) - assert self._n_dots(fn) == 4 - - def test_depth_2_block_at_block(self): - # ``pt.block(X) @ pt.block(Y)`` cascades through the rewrite. With - # split-of-join lifting, every leaf product is exposed: 4 result - # entries x 2 inner products = 8 leaf-level dots. + def test_block_triangular_decomposes(self): + # Top row's zero drops its product; bottom row stays as one GEMM. 2 Dots. a = pt.tensor("a", shape=(3, 3)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + other = pt.tensor("other", shape=(7, 6)) + out = pt.block([[a, pt.zeros((3, 4))], [c, d]]) @ other + + ref_fn = pytensor.function([a, c, d, other], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, c, d, other], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 2 + + rng = np.random.default_rng(0) + values = [rng.standard_normal(t.type.shape) for t in (a, c, d, other)] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + def test_block_with_identity_decomposes(self): + # Top row's identity folds to operand chunk; bottom row stays as one GEMM. b = pt.tensor("b", shape=(3, 4)) c = pt.tensor("c", shape=(4, 3)) d = pt.tensor("d", shape=(4, 4)) - X = pt.block([[a, b], [c, d]]) - fn = pytensor.function([a, b, c, d], X @ X, mode=rewrite_mode) - assert self._n_dots(fn) == 8 + other = pt.tensor("other", shape=(7, 6)) + out = pt.block([[pt.eye(3), b], [c, d]]) @ other + + ref_fn = pytensor.function([b, c, d, other], out, mode=no_opt_mode) + rewr_fn = pytensor.function([b, c, d, other], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 2 rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 3)) - b_v = rng.standard_normal((3, 4)) - c_v = rng.standard_normal((4, 3)) - d_v = rng.standard_normal((4, 4)) - X_v = np.block([[a_v, b_v], [c_v, d_v]]) + values = [rng.standard_normal(t.type.shape) for t in (b, c, d, other)] np.testing.assert_allclose( - fn(a_v, b_v, c_v, d_v), X_v @ X_v, atol=1e-12, rtol=1e-12 + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) @pytest.mark.parametrize( "lhs_t, rhs_t", - [(False, True), (True, False), (True, True)], - ids=["X@X.T", "X.T@X", "X.T@X.T"], + [(False, False), (False, True), (True, False), (True, True)], + ids=["X@X", "X@X.T", "X.T@X", "X.T@X.T"], ) - def test_depth_2_block_at_block_with_transpose(self, lhs_t, rhs_t): - # Matrix-transposed nested-Joins get canonicalized via - # ``local_transpose_of_join`` (transpose pushes to leaves). Then the - # ``Block @ Block`` cascade applies. + def test_dense_block_at_dense_block_preserved(self, lhs_t, rhs_t): a = pt.tensor("a", shape=(3, 3)) b = pt.tensor("b", shape=(3, 4)) c = pt.tensor("c", shape=(4, 3)) d = pt.tensor("d", shape=(4, 4)) X = pt.block([[a, b], [c, d]]) - lhs = X.mT if lhs_t else X - rhs = X.mT if rhs_t else X - fn = pytensor.function([a, b, c, d], lhs @ rhs, mode=rewrite_mode) - assert self._n_dots(fn) == 8 + out = (X.mT if lhs_t else X) @ (X.mT if rhs_t else X) + + ref_fn = pytensor.function([a, b, c, d], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) + + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Join) rng = np.random.default_rng(0) - a_v = rng.standard_normal((3, 3)) - b_v = rng.standard_normal((3, 4)) - c_v = rng.standard_normal((4, 3)) - d_v = rng.standard_normal((4, 4)) - X_v = np.block([[a_v, b_v], [c_v, d_v]]) - ref_lhs = X_v.T if lhs_t else X_v - ref_rhs = X_v.T if rhs_t else X_v + values = [rng.standard_normal(s) for s in ((3, 3), (3, 4), (4, 3), (4, 4))] np.testing.assert_allclose( - fn(a_v, b_v, c_v, d_v), - ref_lhs @ ref_rhs, - atol=1e-12, - rtol=1e-12, + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 ) class TestSplitOfJoin: @staticmethod def _has_split(fn): - from pytensor.tensor.basic import Split - return any(isinstance(n.op, Split) for n in fn.maker.fgraph.toposort()) def test_matching_axis_matching_sizes_returns_inputs(self): - # Split(Join(axis=-2, A, B), [3, 4], axis=-2) -> A, B directly. a = pt.tensor("a", shape=(3, 5)) b = pt.tensor("b", shape=(4, 5)) - joined = pt.concatenate([a, b], axis=-2) - out_a, out_b = pt.split(joined, splits_size=[3, 4], n_splits=2, axis=-2) + out_a, out_b = pt.split( + pt.concatenate([a, b], axis=-2), splits_size=[3, 4], n_splits=2, axis=-2 + ) fn = pytensor.function([a, b], [out_a, out_b], mode=rewrite_mode) assert not self._has_split(fn)