diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 7d690a9e9c..0c11c36fe5 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -3020,28 +3020,170 @@ def vertical_stack(*args):
     return concatenate(_args, axis=0)
 
 
-def is_flat(var, ndim=1):
-    """
-    Verifies the dimensionality of the var is equal to
-    ndim. This method is usually called after flatten method on a
-    variable, where the first ndim-1 dimension size(s) of the variable
-    is kept intact, and the last dimension size of the variable is made
-    equal to the multiplication of its remaining dimension size(s), such that
-    the variable would end up with as many dimension as ndim.
+def _block_check_depths_match(arrays, parent_index=()):
+    """Walk a nested block-list and check every leaf sits at the same depth.
 
     Parameters
     ----------
-    var : pytensor.tensor.var.TensorVariable
-        the pytensor var on which the dimensionality is checked.
+    arrays : list or array_like
+        Nested block-list to validate.
+    parent_index : tuple of int, optional
+        Indices accumulated from the root, used in error messages. Default ``()``.
 
-    ndim : int
-        the expected dimensionality of var.
+    Returns
+    -------
+    structure : nested tuple of None
+        Tree shape with ``None`` at each leaf position.
+    leaf_depth : int
+        Depth at which every leaf sits.
+    max_leaf_ndim : int
+        Largest ``ndim`` across all leaves.
+
+    Raises
+    ------
+    ValueError
+        If any list is empty, or if two leaves sit at different nesting depths.
+    TypeError
+        If a tuple appears anywhere in the nesting.
+    """
+    if isinstance(arrays, list):
+        if not arrays:
+            raise ValueError("Block: empty list is not allowed")
+        children = []
+        first_leaf_depth = None
+        max_ndim = 0
+        for i, child in enumerate(arrays):
+            child_struct, child_leaf_depth, child_ndim = _block_check_depths_match(
+                child, (*parent_index, i)
+            )
+            if first_leaf_depth is None:
+                first_leaf_depth = child_leaf_depth
+            elif first_leaf_depth != child_leaf_depth:
+                raise ValueError(
+                    "Block: all leaves must be at the same nesting depth "
+                    f"(got depth {child_leaf_depth} at index {(*parent_index, i)}, "
+                    f"expected {first_leaf_depth})"
+                )
+            if child_ndim > max_ndim:
+                max_ndim = child_ndim
+            children.append(child_struct)
+        return tuple(children), first_leaf_depth, max_ndim
+    elif isinstance(arrays, tuple):
+        raise TypeError("Block: tuples are not allowed as nested containers; use lists")
+    else:
+        # Leaf: its depth is the number of list levels above it.
+        leaf = as_tensor_variable(arrays)
+        return None, len(parent_index), leaf.type.ndim
+
+
+def block(arrays):
+    """Assemble a tensor from nested lists of blocks, like ``numpy.block``.
+
+    Parameters
+    ----------
+    arrays : nested list of array_like
+        Tensors at the leaves, lists at the interior. Every leaf must sit at
+        the same nesting depth ``d``; the concatenation spans the last ``d``
+        axes.
 
     Returns
     -------
-    bool
-        the comparison result of var's dim
-        and the expected outdim.
+    result : TensorVariable
+        Assembled block tensor. A bare tensor (no list wrapping) returns as
+        ``atleast_1d(arrays)``.
+
+    Raises
+    ------
+    ValueError
+        If a list is empty, or if the leaves sit at different nesting depths.
+    TypeError
+        If a tuple is used in place of a list.
+
+    Examples
+    --------
+    .. testcode::
+
+        import numpy as np
+        import pytensor.tensor as pt
+
+        A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
+        B = pt.as_tensor_variable(np.array([[5], [6]]))
+        C = pt.as_tensor_variable(np.array([[7, 8]]))
+        D = pt.as_tensor_variable(np.array([[9]]))
+        M = pt.block([[A, B], [C, D]])
+        print(M.eval())
+
+    .. testoutput::
+
+        [[1 2 5]
+         [3 4 6]
+         [7 8 9]]
+    """
+    structure, _, _ = _block_check_depths_match(arrays)
+
+    if structure is None:
+        return atleast_Nd(arrays, n=1)
+
+    flat = []
+
+    def _gather(node):
+        if isinstance(node, list):
+            for child in node:
+                _gather(child)
+        else:
+            flat.append(as_tensor_variable(node))
+
+    _gather(arrays)
+
+    def _structure_depth(structure):
+        if structure is None:
+            return 0
+        return 1 + _structure_depth(structure[0])
+
+    list_ndim = _structure_depth(structure)
+    result_ndim = builtins.max(list_ndim, builtins.max(inp.type.ndim for inp in flat))
+    # Raise every leaf to ``result_ndim`` dimensions so that each list level has an
+    # axis of its own to concatenate along.
+    promoted = [atleast_Nd(inp, n=result_ndim) for inp in flat]
+
+    def _unflatten_structure(flat, structure):
+        """Rebuild a nested list from ``flat`` consumed in pre-order against ``structure``."""
+        it = iter(flat)
+
+        def _build(s):
+            if s is None:
+                return next(it)
+            return [_build(child) for child in s]
+
+        return _build(structure)
+
+    nested = _unflatten_structure(promoted, structure)
+
+    def _recurse(node, depth):
+        if depth == list_ndim:
+            return node
+        # Innermost lists concatenate along the last axis; each enclosing level uses
+        # the axis one position further left.
+        children = [_recurse(child, depth + 1) for child in node]
+        return concatenate(children, axis=-(list_ndim - depth))
+
+    return _recurse(nested, 0)
+
+
+def is_flat(var, ndim=1):
+    """Return ``True`` when ``var`` has exactly ``ndim`` dimensions.
+
+    Parameters
+    ----------
+    var : TensorVariable
+        Variable to inspect.
+    ndim : int
+        Expected number of dimensions. Default 1.
+
+    Returns
+    -------
+    bool
+        Whether ``var.ndim == ndim``.
     """
     return var.ndim == ndim
 
@@ -4558,6 +4700,7 @@ def ix_(*args):
     "atleast_2d",
     "atleast_3d",
     "atleast_Nd",
+    "block",
     "cast",
     "choose",
     "concatenate",
diff --git a/pytensor/tensor/rewriting/linalg/utils.py b/pytensor/tensor/rewriting/linalg/utils.py
index 7b5645250a..09e07182a3 100644
--- a/pytensor/tensor/rewriting/linalg/utils.py
+++ b/pytensor/tensor/rewriting/linalg/utils.py
@@ -14,11 +14,14 @@
 from pytensor.scalar.basic import Mul
 from pytensor.tensor.basic import (
     Eye,
+    Join,
     TensorVariable,
     atleast_Nd,
     diagonal,
+    get_underlying_scalar_constant_value,
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg.inverse import MatrixInverse, MatrixPinv
 from pytensor.tensor.math import variadic_mul
 from pytensor.tensor.rewriting.basic import (
@@ -43,6 +46,86 @@
 }
 
 
+def match_2x2_nested_join(var):
+    """Return ``[[A_11, A_12], [A_21, A_22]]`` if ``var`` is a 2x2 nested ``Join``, else ``None``.
+
+    Requires the outer ``Join`` along ``ndim - 2``, both inner ``Join`` ops along
+    ``ndim - 1``, statically-known row heights and column widths that line up, and
+    square diagonal blocks.
+
+    Parameters
+    ----------
+    var : TensorVariable
+        Candidate block matrix.
+
+    Returns
+    -------
+    list of list of TensorVariable or None
+        The four leaves in row-major order, or ``None`` when ``var`` does not match.
+    """
+    if var.owner is None or not isinstance(var.owner.op, Join):
+        return None
+
+    out_ndim = var.type.ndim
+    if out_ndim < 2:
+        return None
+
+    try:
+        outer_axis = int(
+            get_underlying_scalar_constant_value(
+                var.owner.inputs[0], raise_not_constant=True
+            )
+        )
+    except NotScalarConstantError:
+        return None
+    if outer_axis < 0:
+        outer_axis += out_ndim
+    if outer_axis != out_ndim - 2:
+        return None
+
+    rows = var.owner.inputs[1:]
+    if len(rows) != 2:
+        return None
+
+    leaves = []
+    for row in rows:
+        if row.owner is None or not isinstance(row.owner.op, Join):
+            return None
+        try:
+            inner_axis = int(
+                get_underlying_scalar_constant_value(
+                    row.owner.inputs[0], raise_not_constant=True
+                )
+            )
+        except NotScalarConstantError:
+            return None
+        if inner_axis < 0:
+            inner_axis += row.type.ndim
+        if inner_axis != row.type.ndim - 1:
+            return None
+        row_leaves = list(row.owner.inputs[1:])
+        if len(row_leaves) != 2:
+            return None
+        leaves.append(row_leaves)
+
+    [[A_11, A_12], [A_21, A_22]] = leaves
+
+    m1 = A_11.type.shape[-2]
+    m2 = A_22.type.shape[-2]
+    n1 = A_11.type.shape[-1]
+    n2 = A_22.type.shape[-1]
+    if any(s is None for s in (m1, m2, n1, n2)):
+        return None
+    if m1 != n1 or m2 != n2:
+        return None  # diagonal blocks not square
+    if A_12.type.shape[-2] != m1 or A_12.type.shape[-1] != n2:
+        return None
+    if A_21.type.shape[-2] != m2 or A_21.type.shape[-1] != n1:
+        return None
+
+    return leaves
+
+
 def matrix_diagonal_product(x):
     return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1)
 
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index b04892d6ce..18ef586fbc 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -9,7 +9,7 @@
 
 import pytensor.scalar.basic as ps
 import pytensor.scalar.math as ps_math
-from pytensor.assumptions import DIAGONAL, check_assumption
+from pytensor.assumptions import DIAGONAL, SELECTION, check_assumption
 from pytensor.graph.basic import Constant, Variable
 from
pytensor.graph.rewriting.basic import (
     NodeRewriter,
@@ -22,8 +22,10 @@
 from pytensor.graph.rewriting.utils import get_clients_at_depth
 from pytensor.tensor.basic import (
     Alloc,
+    Eye,
     Join,
     MakeVector,
+    Split,
     alloc,
     alloc_diag,
     as_tensor_variable,
@@ -33,10 +35,12 @@
     diagonal,
     expand_dims,
     get_underlying_scalar_constant_value,
+    join,
     moveaxis,
     ones_like,
     register_infer_shape,
     split,
+    stack,
     switch,
     zeros,
     zeros_like,
@@ -44,7 +48,7 @@
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
-from pytensor.tensor.linalg.constructors import BlockDiagonal
+from pytensor.tensor.linalg.constructors import BlockDiagonal, block_diag
 from pytensor.tensor.math import (
     Dot,
     Prod,
@@ -231,6 +235,501 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
         return {client.outputs[0]: new_output}
 
 
+def _join_matmul_axis(var):
+    """Return ``-1`` or ``-2`` if ``var`` is a :class:`Join` along a matmul axis, else ``None``."""
+    owner = var.owner
+    if owner is None or not isinstance(owner.op, Join):
+        return None
+    try:
+        axis = int(
+            get_underlying_scalar_constant_value(
+                owner.inputs[0], raise_not_constant=True
+            )
+        )
+    except NotScalarConstantError:
+        return None
+    ndim = var.type.ndim
+    if axis < 0:
+        axis += ndim
+    if axis == ndim - 1:
+        return -1
+    if axis == ndim - 2:
+        return -2
+    return None
+
+
+def _is_static_zero(var):
+    """``True`` if ``var`` folds to scalar zero (sees through ``Alloc`` and ``DimShuffle``)."""
+    try:
+        return (
+            get_underlying_scalar_constant_value(
+                var, only_process_constants=False, raise_not_constant=True
+            )
+            == 0
+        )
+    except NotScalarConstantError:
+        return False
+
+
+def _is_static_identity(var):
+    """``True`` if ``var`` is statically a square identity ``I_n``."""
+    if var.type.ndim != 2:
+        return False
+    owner = var.owner
+    if owner is not None and isinstance(owner.op, Eye):
+        n, m, k = owner.inputs
+        try:
+            k_val = int(
+                get_underlying_scalar_constant_value(k, raise_not_constant=True)
+            )
+        except NotScalarConstantError:
+            return False
+        if k_val != 0:
+            return False
+        try:
+            n_val = int(
+                get_underlying_scalar_constant_value(n, raise_not_constant=True)
+            )
+            m_val = int(
+                get_underlying_scalar_constant_value(m, raise_not_constant=True)
+            )
+        except NotScalarConstantError:
+            return n is m
+        return n_val == m_val
+    if isinstance(var, Constant):
+        arr = np.asarray(var.data)
+        return (
+            arr.ndim == 2
+            and arr.shape[0] == arr.shape[1]
+            and np.array_equal(arr, np.eye(arr.shape[0], dtype=arr.dtype))
+        )
+    return False
+
+
+def _block_kind(fgraph, var):
+    """Classify a matmul block as ``"zero"``, ``"identity"``, ``"selection"``, or ``"dense"``."""
+    # Identity short-circuits SELECTION so ``I @ x`` folds to ``x`` instead of going
+    # through ``selection_dot_to_indexing``'s scatter.
+    if _is_static_zero(var):
+        return "zero"
+    if _is_static_identity(var):
+        return "identity"
+    if var.type.ndim >= 2 and check_assumption(fgraph, var, SELECTION):
+        return "selection"
+    return "dense"
+
+
+def _decomposition_saves(fgraph, var):
+    """``True`` if some leaf in the nested-join tree at ``var`` is non-dense."""
+    kind = _block_kind(fgraph, var)
+    if kind in ("zero", "identity", "selection"):
+        return True
+    if _join_matmul_axis(var) is not None:
+        return any(_decomposition_saves(fgraph, leaf) for leaf in var.owner.inputs[1:])
+    return False
+
+
+def _decompose_contracted(
+    fgraph, leaves, other, dot_op, join_is_left, out_dtype=None
+):
+    """Decompose ``dot`` over a contracted-axis ``Join`` by splitting ``other`` per leaf.
+
+    Zero leaves are dropped, identity leaves fold to their chunk of ``other``, and
+    selection leaves get their own product. A dense leaf that is itself a ``Join``
+    with structured leaves also gets its own product so the rewrite can recurse into
+    it; the remaining dense leaves are re-joined into one product.
+
+    Parameters
+    ----------
+    out_dtype : str, optional
+        Dtype of the matmul output being replaced. When given, the result is cast to
+        it, because a dropped or folded leaf no longer contributes to the upcast.
+
+    Returns
+    -------
+    TensorVariable or None
+        The replacement, or ``None`` when every leaf is plain dense, in which case
+        re-joining them would only rebuild the original product.
+    """
+    size_axis = -1 if join_is_left else -2
+    sizes = stack([leaf.shape[size_axis] for leaf in leaves])
+    chunks = split(
+        other,
+        splits_size=sizes,
+        n_splits=len(leaves),
+        axis=-2 if join_is_left else -1,
+    )
+
+    def _sided(left, right):
+        return dot_op(left, right) if join_is_left else dot_op(right, left)
+
+    terms = []
+    dense_leaves: list = []
+    dense_chunks: list = []
+    for leaf, chunk in zip(leaves, chunks, strict=True):
+        match _block_kind(fgraph, leaf):
+            case "zero":
+                continue
+            case "identity":
+                terms.append(chunk)
+            case "selection":
+                terms.append(_sided(leaf, chunk))
+            case _ if _decomposition_saves(fgraph, leaf):
+                # Nested ``Join`` with structured leaves: keep it separate so this
+                # rewrite can fire on it next.
+                terms.append(_sided(leaf, chunk))
+            case _:
+                dense_leaves.append(leaf)
+                dense_chunks.append(chunk)
+
+    if len(dense_leaves) == len(leaves):
+        # Nothing was dropped, folded or separated. Re-joining every leaf would
+        # rebuild the parent ``Join`` and its product, so the rewrite would keep
+        # re-firing on its own output.
+        return None
+
+    if dense_leaves:
+        if len(dense_leaves) == 1:
+            terms.append(_sided(dense_leaves[0], dense_chunks[0]))
+        else:
+            # Re-join survivors into one smaller GEMM; the check above guarantees
+            # at least one leaf was left out, so this is smaller than the parent.
+            d_leaf = join(-1 if join_is_left else -2, *dense_leaves)
+            d_chunk = join(-2 if join_is_left else -1, *dense_chunks)
+            terms.append(_sided(d_leaf, d_chunk))
+
+    if not terms:
+        result = zeros_like(_sided(leaves[0], chunks[0]))
+    elif len(terms) == 1:
+        result = terms[0]
+    else:
+        result = add(*terms)
+
+    if out_dtype is not None and result.type.dtype != out_dtype:
+        result = result.astype(out_dtype)
+    return result
+
+
+def _decompose_output(
+    fgraph, leaves, other, dot_op, join_is_left, concat_axis, out_dtype
+):
+    """Decompose ``dot`` over an output-axis ``Join`` by emitting one block per leaf."""
+
+    def _sided(left, right):
+        return dot_op(left, right) if join_is_left else dot_op(right, left)
+
+    blocks = []
+    for leaf in leaves:
+        match _block_kind(fgraph, leaf):
+            case "zero":
+                blocks.append(zeros_like(_sided(leaf, other)))
+            case "identity":
+                blocks.append(
+                    other if other.type.dtype == out_dtype else other.astype(out_dtype)
+                )
+            case _:
+                blocks.append(_sided(leaf, other))
+    # Dense leaves stay separate; re-joining them would reconstruct the parent ``Join``.
+    return concat_with_broadcast(blocks, axis=concat_axis)
+
+
+@register_stabilize
+@node_rewriter([Join])
+def local_dot_of_join(fgraph, node):
+    r"""Push ``dot`` inside a :class:`Join`, gated on the presence of structured leaves.
+
+    Fires per matmul client only when some leaf of the (possibly nested) ``Join`` is a
+    static zero, the identity, or a :data:`SELECTION` matrix -- the categories where the
+    leaf's product reduces below a GEMM (drop, fold to chunk, or downstream
+    ``selection_dot_to_indexing`` collapses to indexing). All-dense block matmuls keep
+    their single GEMM.
+
+    Along the contracted axis surviving dense leaves are re-joined into one smaller GEMM;
+    along an output axis blocks stay separate. Walks through left-``expand_dims``
+    ``DimShuffle`` chains between the Join and the matmul (``Blockwise`` pads this way).
+    """
+    matmul_axis = _join_matmul_axis(node.outputs[0])
+    if matmul_axis is None:
+        return None
+
+    leaves = list(node.inputs[1:])
+    if len(leaves) < 2:
+        return None
+
+    join_out = node.outputs[0]
+    if not _decomposition_saves(fgraph, join_out):
+        return None
+
+    def _walk_to_matmul(var):
+        for client, input_idx in fgraph.clients[var]:
+            if isinstance(client, str):
+                continue
+            if client.op in (_dot, _matmul):
+                yield client, input_idx
+            elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
+                yield from _walk_to_matmul(client.outputs[0])
+
+    replacements: dict = {}
+    for client, client_idx in _walk_to_matmul(join_out):
+        old_out = client.outputs[0]
+        if old_out in replacements:
+            # ``dot(J, J)`` and chained DimShuffles can reach the same matmul twice.
+            continue
+
+        join_is_left = client_idx == 0
+        other = client.inputs[1 - client_idx]
+        dot_op = client.op
+        out_dtype = old_out.type.dtype
+
+        contracted = (join_is_left and matmul_axis == -1) or (
+            not join_is_left and matmul_axis == -2
+        )
+        if contracted:
+            new_output = _decompose_contracted(
+                fgraph, leaves, other, dot_op, join_is_left, out_dtype
+            )
+            if new_output is None:
+                # No leaf could be dropped, folded or separated for this client.
+                continue
+        else:
+            new_output = _decompose_output(
+                fgraph, leaves, other, dot_op, join_is_left, matmul_axis, out_dtype
+            )
+
+        copy_stack_trace(old_out, new_output)
+        replacements[old_out] = new_output
+
+    return replacements or None
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([DimShuffle])
+def local_transpose_of_join(fgraph, node):
+    r"""Rewrite ``Join(axis, *xs).mT`` to ``Join(swapped_axis, *[x.mT for x in xs])``.
+
+    Swap ``axis=-1`` and ``axis=-2`` when axis is one of the matrix axes; leave batch
+    axes unchanged.
+    """
+    if not node.op.is_matrix_transpose:
+        return None
+
+    [src] = node.inputs
+    if src.owner is None or not isinstance(src.owner.op, Join):
+        return None
+
+    try:
+        join_axis = int(
+            get_underlying_scalar_constant_value(
+                src.owner.inputs[0], raise_not_constant=True
+            )
+        )
+    except NotScalarConstantError:
+        return None
+
+    src_ndim = src.type.ndim
+    if join_axis < 0:
+        join_axis += src_ndim
+
+    if join_axis == src_ndim - 1:
+        new_axis = src_ndim - 2
+    elif join_axis == src_ndim - 2:
+        new_axis = src_ndim - 1
+    else:
+        new_axis = join_axis  # batch axis -- mT doesn't touch it
+
+    new_out = join(new_axis, *[inp.mT for inp in src.owner.inputs[1:]])
+    copy_stack_trace(node.outputs[0], new_out)
+    return [new_out]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Join])
+def local_nested_join_to_block_diagonal(fgraph, node):
+    r"""Rewrite a square ``n x n`` nested ``Join`` with zero off-diagonals to :func:`block_diag`.
+
+    Matches ``Join(-2, *Join(-1, ...))`` -- an outer row-concat whose every input is a
+    column-concat -- forming a square grid with statically-zero off-diagonal leaves.
+    Column widths must be statically known and line up across rows, since
+    ``block_diag`` takes each column's width from its diagonal block.
+    """
+    out_ndim = node.outputs[0].type.ndim
+    if out_ndim < 2:
+        return None
+
+    try:
+        outer_axis = int(
+            get_underlying_scalar_constant_value(
+                node.inputs[0], raise_not_constant=True
+            )
+        )
+    except NotScalarConstantError:
+        return None
+    if outer_axis < 0:
+        outer_axis += out_ndim
+    if outer_axis != out_ndim - 2:
+        return None
+
+    rows = node.inputs[1:]
+    n_rows = len(rows)
+    if n_rows < 2:
+        return None
+
+    leaves = []
+    n_cols = None
+    for row in rows:
+        if row.owner is None or not isinstance(row.owner.op, Join):
+            return None
+        try:
+            inner_axis = int(
+                get_underlying_scalar_constant_value(
+                    row.owner.inputs[0], raise_not_constant=True
+                )
+            )
+        except NotScalarConstantError:
+            return None
+        if inner_axis < 0:
+            inner_axis += row.type.ndim
+        if inner_axis != row.type.ndim - 1:
+            return None
+        row_leaves = list(row.owner.inputs[1:])
+        if n_cols is None:
+            n_cols = len(row_leaves)
+        elif len(row_leaves) != n_cols:
+            return None  # ragged
+        leaves.append(row_leaves)
+
+    # Square grid only.
+    if n_rows != n_cols:
+        return None
+
+    # ``block_diag`` sizes each grid column from its diagonal block alone, whereas
+    # ``Join`` only forces whole rows to have the same total width. Every leaf must
+    # therefore have the static width of the diagonal block in its column, otherwise
+    # the diagonal blocks would land at different column offsets after the rewrite.
+    col_widths = [leaves[j][j].type.shape[-1] for j in range(n_cols)]
+    if any(width is None for width in col_widths):
+        return None
+
+    for i, row_leaves in enumerate(leaves):
+        for j, leaf in enumerate(row_leaves):
+            if i == j:
+                continue
+            if leaf.type.shape[-1] != col_widths[j] or not _is_static_zero(leaf):
+                return None
+
+    old_out = node.outputs[0]
+    new_out = block_diag(*[leaves[i][i] for i in range(n_rows)])
+    if new_out.type.dtype != old_out.type.dtype:
+        # The zero leaves took part in the ``Join`` upcast but are gone now.
+        new_out = new_out.astype(old_out.type.dtype)
+    copy_stack_trace(old_out, new_out)
+    return [new_out]
+
+
+def _const_int_vector(var):
+    """Return a ``list`` of ints from a 1-D ``var`` whose entries are statically known, else ``None``."""
+    if isinstance(var, Constant):
+        try:
+            arr = np.asarray(var.data)
+        except Exception:
+            return None
+        if arr.ndim != 1:
+            return None
+        return [int(x) for x in arr]
+    if var.owner is not None and isinstance(var.owner.op, MakeVector):
+        try:
+            return [
+                int(get_underlying_scalar_constant_value(inp, raise_not_constant=True))
+                for inp in var.owner.inputs
+            ]
+        except NotScalarConstantError:
+            return None
+    return None
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Split])
+def local_split_of_join(fgraph, node):
+    r"""Push :class:`Split` through :class:`Join`.
+
+    Same axis with matching sizes: ``Split(Join(a, *X), [|X_i|_a], axis=a)`` returns the
+    join's inputs directly. Different axis: ``Split(Join(a, *X), s, axis=b)`` distributes
+    to ``Join(a, *[Split(X_i, s, b)[k]])`` per cut -- slicing an orthogonal axis commutes
+    with concatenation.
+    """
+    x, axis_var, splits_size_var = node.inputs
+    if x.owner is None or not isinstance(x.owner.op, Join):
+        return None
+
+    try:
+        split_axis = int(
+            get_underlying_scalar_constant_value(axis_var, raise_not_constant=True)
+        )
+        join_axis = int(
+            get_underlying_scalar_constant_value(
+                x.owner.inputs[0], raise_not_constant=True
+            )
+        )
+    except NotScalarConstantError:
+        return None
+
+    out_ndim = x.type.ndim
+    if split_axis < 0:
+        split_axis += out_ndim
+    if join_axis < 0:
+        join_axis += out_ndim
+
+    join_inputs = list(x.owner.inputs[1:])
+    n_splits = len(node.outputs)
+
+    if split_axis == join_axis:
+        # Matching axis: return Join's inputs directly when sizes line up.
+        if len(join_inputs) != n_splits:
+            return None
+        join_sizes = [inp.type.shape[join_axis] for inp in join_inputs]
+        if any(s is None for s in join_sizes):
+            return None
+        split_sizes = _const_int_vector(splits_size_var)
+        if split_sizes is None:
+            return None
+        if join_sizes != split_sizes:
+            return None
+        # A ``Join`` output can be wider than an individual input (mixed dtypes are
+        # upcast), so cast where needed to keep the replacement's dtype unchanged.
+        out_dtype = x.type.dtype
+        new_outputs = [
+            inp if inp.type.dtype == out_dtype else inp.astype(out_dtype)
+            for inp in join_inputs
+        ]
+        for new_out in new_outputs:
+            copy_stack_trace(node.outputs[0], new_out)
+        return new_outputs
+
+    # Different axis: distribute Split through Join.
+    per_input_splits = [
+        split(
+            inp,
+            splits_size=splits_size_var,
+            n_splits=n_splits,
+            axis=split_axis,
+        )
+        for inp in join_inputs
+    ]
+    new_outputs = [
+        join(
+            join_axis,
+            *[per_input_splits[i][k] for i in range(len(join_inputs))],
+        )
+        for k in range(n_splits)
+    ]
+    for new_out in new_outputs:
+        copy_stack_trace(node.outputs[0], new_out)
+    return new_outputs
+
+
 @register_canonicalize
 @node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 9106163c2b..2739c6213f 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -31,8 +31,8 @@
 from pytensor.graph.traversal import ancestors
 from pytensor.printing import debugprint, pprint
 from pytensor.scalar import PolyGamma, Psi, TriGamma
-from pytensor.tensor.basic import Alloc, constant, join, second, switch
-from pytensor.tensor.blas import Dot22, Gemv
+from pytensor.tensor.basic import Alloc, Join, Split, constant, join, second, switch
+from pytensor.tensor.blas import BatchedDot, Dot22, Dot22Scalar, Gemm, Gemv, Ger
 from pytensor.tensor.blas_c import CGemv
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -117,6 +117,7 @@
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -156,6 +157,9 @@
 rewrite_mode = "FAST_RUN"
 rewrite_mode = get_mode(rewrite_mode)
 
+# Paired with ``rewrite_mode`` in rewrite tests as the un-rewritten reference.
+no_opt_mode = Mode(linker="py", optimizer=None) + dimshuffle_lift = out2in(local_dimshuffle_lift) _stabilize_rewrites = RewriteDatabaseQuery(include=["fast_run"]) @@ -5116,3 +5120,393 @@ def test_rewrite_does_not_apply(self): original, include=("canonicalize", "stabilize", "specialize") ) assert_equal_computations([rewritten], [original]) + + +class TestNestedJoinToBlockDiagonal: + @staticmethod + def _has_block_diagonal(fn): + return any( + isinstance(n.op, BlockDiagonal) + or (isinstance(n.op, Blockwise) and isinstance(n.op.core_op, BlockDiagonal)) + for n in fn.maker.fgraph.toposort() + ) + + def test_zeros_off_diagonal_canonicalizes(self): + a = pt.tensor("a", shape=(3, 3)) + d = pt.tensor("d", shape=(4, 4)) + M = pt.block([[a, pt.zeros((3, 4))], [pt.zeros((4, 3)), d]]) + + ref_fn = pytensor.function([a, d], M, mode=no_opt_mode) + rewr_fn = pytensor.function([a, d], M, mode=rewrite_mode) + assert self._has_block_diagonal(rewr_fn) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((3, 3), (4, 4))] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + def test_nonzero_off_diagonal_skips(self): + a = pt.tensor("a", shape=(3, 3)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + fn = pytensor.function( + [a, b, c, d], pt.block([[a, b], [c, d]]), mode=rewrite_mode + ) + assert not self._has_block_diagonal(fn) + + def test_non_square_skips(self): + a = pt.tensor("a", shape=(3, 4)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(3, 4)) + z = pt.zeros((3, 4)) + fn = pytensor.function( + [a, b, c], pt.block([[a, z, z], [b, z, c]]), mode=rewrite_mode + ) + assert not self._has_block_diagonal(fn) + + +_DOT_OPS = (Dot, Dot22, Dot22Scalar, Gemm, Gemv, Ger, CGemv, BatchedDot) + + +class TestDotOfJoin: + @staticmethod + def _n_dots(fn): + # CVM mode fuses ``dot + add`` into Gemm and lowers Dot22 to Dot22Scalar / + # Gemv / Ger; 
count every matmul-like Op (and its Blockwise wrappings) so the + # assertion holds across linkers. + return sum( + isinstance(n.op, _DOT_OPS) + or (isinstance(n.op, Blockwise) and isinstance(n.op.core_op, _DOT_OPS)) + for n in fn.maker.fgraph.toposort() + ) + + @staticmethod + def _has_op(fn, op_type): + return any(isinstance(n.op, op_type) for n in fn.maker.fgraph.toposort()) + + @pytest.mark.parametrize( + "side, axis, a_shape, b_shape, y_shape", + [ + ("lhs", -1, (3, 4), (3, 5), (9, 6)), + ("lhs", -2, (3, 4), (2, 4), (4, 6)), + ("rhs", -1, (4, 3), (4, 5), (7, 4)), + ("rhs", -2, (3, 5), (2, 5), (7, 5)), + ], + ids=["lhs_axis-1", "lhs_axis-2", "rhs_axis-1", "rhs_axis-2"], + ) + def test_all_dense_preserved(self, side, axis, a_shape, b_shape, y_shape): + a = pt.tensor("a", shape=a_shape) + b = pt.tensor("b", shape=b_shape) + y = pt.tensor("y", shape=y_shape) + M = pt.concatenate([a, b], axis=axis) + out = M @ y if side == "lhs" else y @ M + + ref_fn = pytensor.function([a, b, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, y], out, mode=rewrite_mode) + + # Join survival is the proof the rewrite chose not to fire. 
+ assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in (a_shape, b_shape, y_shape)] + np.testing.assert_allclose(rewr_fn(*values), ref_fn(*values), atol=1e-12) + + def test_unknown_shapes_still_preserved(self): + a = pt.matrix("a") + b = pt.matrix("b") + y = pt.matrix("y") + out = pt.concatenate([a, b], axis=-1) @ y + + ref_fn = pytensor.function([a, b, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((3, 4), (3, 5), (9, 6))] + np.testing.assert_allclose(rewr_fn(*values), ref_fn(*values), atol=1e-12) + + def test_single_input_join_skipped(self): + a = pt.tensor("a", shape=(3, 4)) + y = pt.tensor("y", shape=(4, 5)) + out = pt.join(-1, a) @ y + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 4)) + y_v = rng.standard_normal((4, 5)) + np.testing.assert_allclose(rewr_fn(a_v, y_v), ref_fn(a_v, y_v), atol=1e-12) + + @pytest.mark.parametrize( + "side, axis, a_shape, zero_shape, y_shape", + [ + ("lhs", -1, (3, 4), (3, 5), (9, 6)), + ("rhs", -2, (3, 5), (2, 5), (7, 5)), + ], + ids=["lhs_axis-1", "rhs_axis-2"], + ) + def test_zero_leaf_contracted_drops_dot( + self, side, axis, a_shape, zero_shape, y_shape + ): + a = pt.tensor("a", shape=a_shape) + y = pt.tensor("y", shape=y_shape) + M = pt.concatenate([a, pt.zeros(zero_shape)], axis=axis) + out = M @ y if side == "lhs" else y @ M + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + + # Split + no Join is the signature of contracted-axis decomposition. 
+ assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Split) + assert not self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal(a_shape) + y_v = rng.standard_normal(y_shape) + np.testing.assert_allclose( + rewr_fn(a_v, y_v), ref_fn(a_v, y_v), atol=1e-12, rtol=1e-12 + ) + + @pytest.mark.parametrize( + "side, axis, a_shape, zero_shape, y_shape", + [ + ("lhs", -2, (3, 4), (2, 4), (4, 6)), + ("rhs", -1, (4, 3), (4, 5), (7, 4)), + ], + ids=["lhs_axis-2", "rhs_axis-1"], + ) + def test_zero_leaf_output_emits_zero_block( + self, side, axis, a_shape, zero_shape, y_shape + ): + # Output-axis case: rewrite emits ``Join(Dot(a, other), zero_block)``. A Join + # still appears, so only numerical equivalence pins the rewrite's effect. + a = pt.tensor("a", shape=a_shape) + y = pt.tensor("y", shape=y_shape) + M = pt.concatenate([a, pt.zeros(zero_shape)], axis=axis) + out = M @ y if side == "lhs" else y @ M + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 + + rng = np.random.default_rng(0) + a_v = rng.standard_normal(a_shape) + y_v = rng.standard_normal(y_shape) + np.testing.assert_allclose( + rewr_fn(a_v, y_v), ref_fn(a_v, y_v), atol=1e-12, rtol=1e-12 + ) + + def test_identity_leaf_contracted_folds_to_chunk(self): + a = pt.tensor("a", shape=(4, 3)) + y = pt.tensor("y", shape=(3 + 4, 6)) + out = pt.join(-1, a, pt.eye(4)) @ y + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 + assert not self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((4, 3), (7, 6))] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + def test_identity_leaf_output_replaces_with_operand(self): + a = pt.tensor("a", shape=(3, 5)) + y = pt.tensor("y", 
shape=(5, 6)) + out = pt.concatenate([a, pt.eye(5)], axis=-2) @ y + + ref_fn = pytensor.function([a, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, y], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 1 + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((3, 5), (5, 6))] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + def test_selection_leaf_collapses_to_indexing(self): + # SELECTION leaf becomes its own Dot, then ``selection_dot_to_indexing`` turns + # it into an advanced index. ``y @ S`` hits the gather path. + b = pt.tensor("b", shape=(4, 5)) + perm = pt.constant(np.array([2, 0, 3, 1], dtype=np.int64)) + S = pt.eye(4)[:, perm] + y = pt.tensor("y", shape=(7, 4)) + out = y @ pt.concatenate([b, S], axis=-1) + + ref_fn = pytensor.function([b, y], out, mode=no_opt_mode) + rewr_fn = pytensor.function([b, y], out, mode=rewrite_mode) + + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, AdvancedSubtensor | AdvancedSubtensor1) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((4, 5), (7, 4))] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + def test_all_dense_block_preserved(self): + a = pt.tensor("a", shape=(2, 3)) + b = pt.tensor("b", shape=(2, 4)) + c = pt.tensor("c", shape=(5, 3)) + d = pt.tensor("d", shape=(5, 4)) + other = pt.tensor("other", shape=(7, 6)) + out = pt.block([[a, b], [c, d]]) @ other + + ref_fn = pytensor.function([a, b, c, d, other], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, c, d, other], out, mode=rewrite_mode) + + assert self._n_dots(rewr_fn) == 1 + assert self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(t.type.shape) for t in (a, b, c, d, other)] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + def test_block_triangular_decomposes(self): + # Top 
row's zero drops its product; bottom row stays as one GEMM. 2 Dots. + a = pt.tensor("a", shape=(3, 3)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + other = pt.tensor("other", shape=(7, 6)) + out = pt.block([[a, pt.zeros((3, 4))], [c, d]]) @ other + + ref_fn = pytensor.function([a, c, d, other], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, c, d, other], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 2 + + rng = np.random.default_rng(0) + values = [rng.standard_normal(t.type.shape) for t in (a, c, d, other)] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + def test_block_with_identity_decomposes(self): + # Top row's identity folds to operand chunk; bottom row stays as one GEMM. + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + other = pt.tensor("other", shape=(7, 6)) + out = pt.block([[pt.eye(3), b], [c, d]]) @ other + + ref_fn = pytensor.function([b, c, d, other], out, mode=no_opt_mode) + rewr_fn = pytensor.function([b, c, d, other], out, mode=rewrite_mode) + assert self._n_dots(rewr_fn) == 2 + + rng = np.random.default_rng(0) + values = [rng.standard_normal(t.type.shape) for t in (b, c, d, other)] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + @pytest.mark.parametrize( + "lhs_t, rhs_t", + [(False, False), (False, True), (True, False), (True, True)], + ids=["X@X", "X@X.T", "X.T@X", "X.T@X.T"], + ) + def test_dense_block_at_dense_block_preserved(self, lhs_t, rhs_t): + a = pt.tensor("a", shape=(3, 3)) + b = pt.tensor("b", shape=(3, 4)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + X = pt.block([[a, b], [c, d]]) + out = (X.mT if lhs_t else X) @ (X.mT if rhs_t else X) + + ref_fn = pytensor.function([a, b, c, d], out, mode=no_opt_mode) + rewr_fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode) + + assert self._n_dots(rewr_fn) == 1 + 
assert self._has_op(rewr_fn, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((3, 3), (3, 4), (4, 3), (4, 4))] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) + + +class TestSplitOfJoin: + @staticmethod + def _has_split(fn): + return any(isinstance(n.op, Split) for n in fn.maker.fgraph.toposort()) + + def test_matching_axis_matching_sizes_returns_inputs(self): + a = pt.tensor("a", shape=(3, 5)) + b = pt.tensor("b", shape=(4, 5)) + out_a, out_b = pt.split( + pt.concatenate([a, b], axis=-2), splits_size=[3, 4], n_splits=2, axis=-2 + ) + + fn = pytensor.function([a, b], [out_a, out_b], mode=rewrite_mode) + assert not self._has_split(fn) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 5)) + b_v = rng.standard_normal((4, 5)) + ra, rb = fn(a_v, b_v) + np.testing.assert_array_equal(ra, a_v) + np.testing.assert_array_equal(rb, b_v) + + def test_matching_axis_mismatched_sizes_skips(self): + # Split sizes [2, 5] don't align with the Join's [3, 4] boundaries. 
+ a = pt.tensor("a", shape=(3, 5)) + b = pt.tensor("b", shape=(4, 5)) + out_a, out_b = pt.split( + pt.concatenate([a, b], axis=-2), splits_size=[2, 5], n_splits=2, axis=-2 + ) + + fn = pytensor.function([a, b], [out_a, out_b], mode=rewrite_mode) + assert self._has_split(fn) + + def test_different_axis_distributes(self): + a = pt.tensor("a", shape=(3, 5)) + b = pt.tensor("b", shape=(4, 5)) + c0, c1 = pt.split( + pt.concatenate([a, b], axis=-2), splits_size=[2, 3], n_splits=2, axis=-1 + ) + + ref_fn = pytensor.function([a, b], [c0, c1], mode=no_opt_mode) + rewr_fn = pytensor.function([a, b], [c0, c1], mode=rewrite_mode) + + rng = np.random.default_rng(0) + a_v = rng.standard_normal((3, 5)) + b_v = rng.standard_normal((4, 5)) + for r, e in zip(rewr_fn(a_v, b_v), ref_fn(a_v, b_v), strict=True): + np.testing.assert_allclose(r, e) + + def test_x_at_s_at_xT_no_intermediate_materialization(self): + # Block-triangular X gives ``local_dot_of_join`` something to fire on, producing + # Splits of S. Invariant: none of them feed from a Join (otherwise + # ``local_split_of_join`` failed to collapse a rebuilt intermediate). 
+ a = pt.tensor("a", shape=(3, 3)) + c = pt.tensor("c", shape=(4, 3)) + d = pt.tensor("d", shape=(4, 4)) + S = pt.tensor("S", shape=(7, 7)) + X = pt.block([[a, pt.zeros((3, 4))], [c, d]]) + + ref_fn = pytensor.function([a, c, d, S], X @ S @ X.mT, mode=no_opt_mode) + rewr_fn = pytensor.function([a, c, d, S], X @ S @ X.mT, mode=rewrite_mode) + + splits = [n for n in rewr_fn.maker.fgraph.toposort() if isinstance(n.op, Split)] + assert splits, "expected the rewrite chain to introduce at least one Split" + for n in splits: + src = n.inputs[0] + assert src.owner is None or not isinstance(src.owner.op, Join) + + rng = np.random.default_rng(0) + values = [rng.standard_normal(s) for s in ((3, 3), (4, 3), (4, 4), (7, 7))] + np.testing.assert_allclose( + rewr_fn(*values), ref_fn(*values), atol=1e-12, rtol=1e-12 + ) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 2ab01c047d..e96353dbd9 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -41,6 +41,7 @@ arange, as_tensor_variable, atleast_Nd, + block, cast, choose, constant, @@ -3632,6 +3633,90 @@ def test_grad_3d(self, offset, axis1, axis2): utt.verify_grad(diag_fn, [x], rng=rng) +class TestBlock: + """Tests that ``pt.block`` matches ``np.block``.""" + + def _check(self, arrays_pt, arrays_np, *, expected_ndim=None): + out = block(arrays_pt) + if expected_ndim is not None: + assert out.ndim == expected_ndim + np.testing.assert_allclose(out.eval(), np.block(arrays_np)) + + def test_depth_1(self): + a = np.arange(3.0) + b = np.arange(4.0) + self._check([a, b], [a, b], expected_ndim=1) + + def test_depth_2_block_matrix(self): + rng = np.random.default_rng(0) + A = rng.standard_normal((2, 3)) + B = rng.standard_normal((2, 4)) + C = rng.standard_normal((5, 3)) + D = rng.standard_normal((5, 4)) + self._check([[A, B], [C, D]], [[A, B], [C, D]], expected_ndim=2) + + def test_depth_3(self): + rng = np.random.default_rng(0) + x = rng.standard_normal((2, 3, 4)) + self._check([[[x, x], [x, 
x]]], [[[x, x], [x, x]]], expected_ndim=3) + + def test_scalar_promotion(self): + # Depth-2 list of scalars → 2-D output, like np.block. + self._check([[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [3.0, 4.0]], expected_ndim=2) + + def test_mixed_rank_promotion(self): + # 1-D leaf in a depth-2 list gets atleast_2d-promoted on the left. + rng = np.random.default_rng(0) + v = rng.standard_normal(3) + A = rng.standard_normal((2, 3)) + self._check([[v], [A]], [[v], [A]], expected_ndim=2) + + def test_ragged_widths(self): + # Same depth, different number of children per row — np.block allows it. + rng = np.random.default_rng(0) + A = rng.standard_normal((2, 1)) + B = rng.standard_normal((2, 2)) + C = rng.standard_normal((2, 3)) + D = rng.standard_normal((3, 6)) + self._check([[A, B, C], [D]], [[A, B, C], [D]], expected_ndim=2) + + def test_single_leaf_passthrough(self): + # block(arr) returns atleast_1d(arr); np.block(5.0) is 0-d but compares equal. + out = block(5.0) + assert out.ndim == 1 + np.testing.assert_array_equal(out.eval(), np.block(5.0)) + + def test_depth_mismatch_raises(self): + A = pytensor.tensor.matrix("A") + with pytest.raises(ValueError, match="same nesting depth"): + block([[A, A], A]) + + def test_tuple_rejected(self): + A = pytensor.tensor.matrix("A") + with pytest.raises(TypeError, match="tuples are not allowed"): + block((A, A)) + + def test_empty_list_raises(self): + with pytest.raises(ValueError, match="empty list"): + block([]) + + def test_returns_nested_join(self): + # `block` is a pure helper: the output is the result of a Join of Joins, + # not a wrapping Op. + A = pytensor.tensor.matrix("A") + B = pytensor.tensor.matrix("B") + out = block([[A, B], [B, A]]) + # Outer concat along axis -2 (rows). + assert isinstance(out.owner.op, ptb.Join) + outer_axis = int(out.owner.inputs[0].data) + assert outer_axis == out.ndim - 2 + # Each row is itself a Join along axis -1 (columns). 
+ for inner in out.owner.inputs[1:]: + assert isinstance(inner.owner.op, ptb.Join) + inner_axis = int(inner.owner.inputs[0].data) + assert inner_axis == inner.ndim - 1 + + class TestAllocDiag: # TODO: Separate perform, grad and infer_shape tests