Skip to content
Open
149 changes: 134 additions & 15 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3020,28 +3020,146 @@ def vertical_stack(*args):
return concatenate(_args, axis=0)


def is_flat(var, ndim=1):
"""
Verifies the dimensionality of the var is equal to
ndim. This method is usually called after flatten method on a
variable, where the first ndim-1 dimension size(s) of the variable
is kept intact, and the last dimension size of the variable is made
equal to the multiplication of its remaining dimension size(s), such that
the variable would end up with as many dimension as ndim.
def _block_check_depths_match(arrays, parent_index=()):
"""Walk a nested block-list and check every leaf sits at the same depth.

Parameters
----------
var : pytensor.tensor.var.TensorVariable
the pytensor var on which the dimensionality is checked.
arrays : list or array_like
Nested block-list to validate.
parent_index : tuple of int, optional
Indices accumulated from the root, used in error messages. Default ``()``.

ndim : int
the expected dimensionality of var.
Returns
-------
structure : nested tuple of None
Tree shape with ``None`` at each leaf position.
leaf_depth : int
Depth at which every leaf sits.
max_leaf_ndim : int
Largest ``ndim`` across all leaves.
"""
if isinstance(arrays, list):
if not arrays:
raise ValueError("Block: empty list is not allowed")
children = []
first_leaf_depth = None
max_ndim = 0
for i, child in enumerate(arrays):
child_struct, child_leaf_depth, child_ndim = _block_check_depths_match(
child, (*parent_index, i)
)
if first_leaf_depth is None:
first_leaf_depth = child_leaf_depth
elif first_leaf_depth != child_leaf_depth:
raise ValueError(
"Block: all leaves must be at the same nesting depth "
f"(got depth {child_leaf_depth} at index {(*parent_index, i)}, "
f"expected {first_leaf_depth})"
)
if child_ndim > max_ndim:
max_ndim = child_ndim
children.append(child_struct)
return tuple(children), first_leaf_depth, max_ndim
elif isinstance(arrays, tuple):
raise TypeError("Block: tuples are not allowed as nested containers; use lists")
else:
leaf = as_tensor_variable(arrays)
return None, len(parent_index), leaf.type.ndim


def block(arrays):
"""Assemble a tensor from nested lists of blocks, like ``numpy.block``.

Parameters
----------
arrays : nested list of array_like
Tensors at the leaves, lists at the interior. Every leaf must sit at
the same nesting depth ``d``; the concatenation spans the last ``d``
axes.

Returns
-------
bool
the comparison result of var's dim
and the expected outdim.
result : TensorVariable
Assembled block tensor. A bare tensor (no list wrapping) returns as
``atleast_1d(arrays)``.

Examples
--------
.. testcode::

import numpy as np
import pytensor.tensor as pt

A = pt.as_tensor_variable(np.array([[1, 2], [3, 4]]))
B = pt.as_tensor_variable(np.array([[5], [6]]))
C = pt.as_tensor_variable(np.array([[7, 8]]))
D = pt.as_tensor_variable(np.array([[9]]))
M = pt.block([[A, B], [C, D]])
print(M.eval())

.. testoutput::

[[1 2 5]
[3 4 6]
[7 8 9]]
"""
structure, _, _ = _block_check_depths_match(arrays)

if structure is None:
return atleast_Nd(arrays, n=1)

flat = []

def _gather(node):
if isinstance(node, list):
for child in node:
_gather(child)
else:
flat.append(as_tensor_variable(node))

_gather(arrays)

def _structure_depth(structure):
if structure is None:
return 0
return 1 + _structure_depth(structure[0])

list_ndim = _structure_depth(structure)
result_ndim = builtins.max(list_ndim, builtins.max(inp.type.ndim for inp in flat))
promoted = [atleast_Nd(inp, n=result_ndim) for inp in flat]

def _unflatten_structure(flat, structure):
"""Rebuild a nested list from ``flat`` consumed in pre-order against ``structure``."""
it = iter(flat)

def _build(s):
if s is None:
return next(it)
return [_build(child) for child in s]

return _build(structure)

nested = _unflatten_structure(promoted, structure)

def _recurse(node, depth):
if depth == list_ndim:
return node
children = [_recurse(child, depth + 1) for child in node]
return concatenate(children, axis=-(list_ndim - depth))

return _recurse(nested, 0)


def is_flat(var, ndim=1):
"""Return ``True`` when ``var`` has exactly ``ndim`` dimensions.

Parameters
----------
var : TensorVariable
Variable to inspect.
ndim : int
Expected number of dimensions. Default 1.
"""
return var.ndim == ndim

Expand Down Expand Up @@ -4558,6 +4676,7 @@ def ix_(*args):
"atleast_2d",
"atleast_3d",
"atleast_Nd",
"block",
"cast",
"choose",
"concatenate",
Expand Down
73 changes: 73 additions & 0 deletions pytensor/tensor/rewriting/linalg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import (
Eye,
Join,
TensorVariable,
atleast_Nd,
diagonal,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.linalg.inverse import MatrixInverse, MatrixPinv
from pytensor.tensor.math import variadic_mul
from pytensor.tensor.rewriting.basic import (
Expand All @@ -43,6 +46,76 @@
}


def match_2x2_nested_join(var):
"""Return ``[[A_11, A_12], [A_21, A_22]]`` if ``var`` is a 2x2 nested ``Join``, else ``None``.

Requires the outer ``Join`` along ``ndim - 2``, both inner ``Join`` ops along
``ndim - 1``, statically-known row heights and column widths that line up, and
square diagonal blocks.
"""
if var.owner is None or not isinstance(var.owner.op, Join):
return None

out_ndim = var.type.ndim
if out_ndim < 2:
return None

try:
outer_axis = int(
get_underlying_scalar_constant_value(
var.owner.inputs[0], raise_not_constant=True
)
)
except NotScalarConstantError:
return None
if outer_axis < 0:
outer_axis += out_ndim
if outer_axis != out_ndim - 2:
return None

rows = var.owner.inputs[1:]
if len(rows) != 2:
return None

leaves = []
for row in rows:
if row.owner is None or not isinstance(row.owner.op, Join):
return None
try:
inner_axis = int(
get_underlying_scalar_constant_value(
row.owner.inputs[0], raise_not_constant=True
)
)
except NotScalarConstantError:
return None
if inner_axis < 0:
inner_axis += row.type.ndim
if inner_axis != row.type.ndim - 1:
return None
row_leaves = list(row.owner.inputs[1:])
if len(row_leaves) != 2:
return None
leaves.append(row_leaves)

[[A_11, A_12], [A_21, A_22]] = leaves

m1 = A_11.type.shape[-2]
m2 = A_22.type.shape[-2]
n1 = A_11.type.shape[-1]
n2 = A_22.type.shape[-1]
if any(s is None for s in (m1, m2, n1, n2)):
return None
if m1 != n1 or m2 != n2:
return None # diagonal blocks not square
if A_12.type.shape[-2] != m1 or A_12.type.shape[-1] != n2:
return None
if A_21.type.shape[-2] != m2 or A_21.type.shape[-1] != n1:
return None

return leaves


def matrix_diagonal_product(x):
return pt.prod(diagonal(x, axis1=-2, axis2=-1), axis=-1)

Expand Down
Loading
Loading