diff --git a/pytensor/assumptions/specify.py b/pytensor/assumptions/specify.py index ebf2282c89..201e7d6282 100644 --- a/pytensor/assumptions/specify.py +++ b/pytensor/assumptions/specify.py @@ -39,7 +39,7 @@ def make_node(self, x): out = x.type() return Apply(self, [x], [out]) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def pullback( diff --git a/pytensor/breakpoint.py b/pytensor/breakpoint.py index f9c74950dc..c2913cfea1 100644 --- a/pytensor/breakpoint.py +++ b/pytensor/breakpoint.py @@ -144,7 +144,7 @@ def perform(self, node, inputs, output_storage): def pullback(self, inputs, outputs, output_gradients): return [disconnected_type(), *output_gradients] - def infer_shape(self, fgraph, inputs, input_shapes): + def infer_shape(self, inputs, input_shapes): # Return the shape of every input but the condition (first input) return input_shapes[1:] diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 4060c38365..c001d8c7e7 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -7,7 +7,6 @@ from collections.abc import Callable, Sequence from copy import copy from functools import partial -from itertools import chain from typing import cast from pytensor.compile.maker import function @@ -24,26 +23,14 @@ from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import clone_replace, graph_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError +from pytensor.tensor.shape import Shape_i def infer_shape(outs, inputs, input_shapes): - """ - Compute the shape of the outputs given the shape of the inputs of an PyTensor - graph. - - We do it this way to avoid compiling the inner function just to get - the shape. Changes to ShapeFeature could require changes in this function. - - """ - # We use a ShapeFeature because it has all the necessary logic - # inside. We don't use the full ShapeFeature interface, but we - # let it initialize itself with an empty fgraph, otherwise we will - # need to do it manually - - # TODO: ShapeFeature should live elsewhere + """Compute the shape of the outputs given the shape of the inputs of a PyTensor graph.""" from pytensor.tensor.rewriting.shape import ShapeFeature for inp, inp_shp in zip(inputs, input_shapes, strict=True): @@ -51,43 +38,36 @@ def infer_shape(outs, inputs, input_shapes): assert len(inp_shp) == inp.type.ndim shape_feature = ShapeFeature() - fgraph = FunctionGraph([], [], features=[shape_feature]) - for v in chain.from_iterable(s for s in input_shapes if s is not None): - # Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before - if (node := v.owner) is not None: - fgraph.import_node(node, import_missing=True) - - # Initialize shape_of with the input shapes - for inp, inp_shp in zip(inputs, input_shapes, strict=True): - shape_feature.set_shape(inp, inp_shp, override=True) - - def local_traverse(out): - """ - Go back in the graph, from out, adding computable shapes to shape_of. - - """ - if out in shape_feature.shape_of: - # Its shape is already known - return - elif out.owner is None: - # This is an input of the graph - shape_feature.init_r(out) - else: - # Recurse over inputs - for inp in out.owner.inputs: - if inp not in shape_feature.shape_of: - local_traverse(inp) + output_shapes = [shape_feature.shape_tuple(o) for o in outs] - # shape_feature.on_import does not actually use an fgraph - # It will call infer_shape and set_shape appropriately - dummy_fgraph = None - shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy") + # Shape expressions for root inputs are Shape_i(inp, i). + # Replace those with the caller-provided input_shapes. + replacements = {} + for inp, shp in zip(inputs, input_shapes, strict=True): + if shp is None: + continue + per_dim = shape_feature._shape_i_cache.get(inp) + if per_dim is None: + continue + for i, s in enumerate(shp): + cached = per_dim.get(i) + if cached is not None: + replacements[cached] = s + + if replacements: + flat = [s for tup in output_shapes if tup is not None for s in tup] + flat_replaced = graph_replace(flat, replacements, strict=False) + result = [] + idx = 0 + for tup in output_shapes: + if tup is None: + result.append(None) + else: + result.append(tuple(flat_replaced[idx : idx + len(tup)])) + idx += len(tup) + return result - ret = [] - for o in outs: - local_traverse(o) - ret.append(shape_feature.shape_of[o]) - return ret + return output_shapes def construct_nominal_fgraph( @@ -885,30 +865,51 @@ def connection_pattern(self, node): self._connection_pattern = ret return ret - def infer_shape(self, fgraph, node, shapes): - # TODO: Use `fgraph.shape_feature` to do this instead. - out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes) - - # Clone the output shape so that shape are computed from outer inputs. - # Note: - # Here we could do it more simply like: - # `ret = [pytensor.clone_replace(shp, replace=repl) for shp in out_shp]` - # But doing it multiple time could duplicate common subgraph between - # each shape call. PyTensor optimizer will clean this up later, but this - # will make extra work for the optimizer. - - repl = dict(zip(self.inner_inputs, node.inputs, strict=True)) - clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)] - cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl) + def infer_shape(self, node, shapes): + try: + template = self._inner_shape_template + frozen = self._inner_shape_frozen + except AttributeError: + from pytensor.tensor.rewriting.shape import ShapeFeature + + sf = ShapeFeature() + inner_inputs = self.inner_inputs + template = [sf.shape_tuple(o) for o in self.inner_outputs] + flat_shapes = [s for tup in template if tup is not None for s in tup] + + # Express the inner-output shapes as a frozen function of the inner + # inputs plus each input's per-dim size. from_structural_inputs rewires + # every Shape_i(inner_input, dim) occurrence to the matching input, so + # bind can later swap in the caller's shapes. One slot per input dim: + # static or unused dims become dead inputs, keeping the layout positional. + shape_inputs = [ + Shape_i(dim)(inp) + for inp in inner_inputs + for dim in range(getattr(inp.type, "ndim", 0)) + ] + frozen = FrozenFunctionGraph.from_structural_inputs( + [*inner_inputs, *shape_inputs], flat_shapes + ) + self._inner_shape_template = template + self._inner_shape_frozen = frozen + + # frozen.inputs is [*inner_inputs, *per-dim sizes]; mirror that layout. + repl = list(node.inputs) + for shp in shapes: + if shp is not None: + repl.extend(shp) + + bound_shapes = frozen.bind(dict(zip(frozen.inputs, repl, strict=True))) + ret = [] - used = 0 - for i, out_shape in enumerate(out_shapes): - if out_shape is None: + idx = 0 + for tup in template: + if tup is None: ret.append(None) else: - nb = len(out_shape) - ret.append(cloned[used : used + nb]) - used += nb + nb = len(tup) + ret.append(bound_shapes[idx : idx + nb]) + idx += nb return ret diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 57d34bbf25..028c854854 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -90,7 +90,7 @@ class ViewOp(TypeCastingOp): def make_node(self, x): return Apply(self, [x], [x.type()]) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def pullback(self, args, outputs, g_outs): @@ -179,7 +179,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes @@ -251,8 +251,8 @@ def __reduce__(self): ) return load_back, (mod, name) - def _infer_shape(self, fgraph, node, input_shapes): - return self.__infer_shape(fgraph, node, input_shapes) + def _infer_shape(self, node, input_shapes): + return self.__infer_shape(node, input_shapes) def as_op(itypes, otypes, infer_shape=None): @@ -275,7 +275,7 @@ def wrap_py(itypes, otypes, infer_shape=None): It takes an optional infer_shape parameter that should be a callable with this signature: - def infer_shape(fgraph, node, input_shapes): + def infer_shape(node, input_shapes): ... return output_shapes diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 93fc6aff0c..3f82ec0e15 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -1031,6 +1031,37 @@ def _resolve_input(inp, memo=memo): self._variables: frozenset[Variable] | None = None self._clients: dict[Variable, list[ClientType]] | None = None + @classmethod + def from_structural_inputs( + cls, + inputs: Sequence[Variable], + outputs: Sequence[Variable], + ) -> "FrozenFunctionGraph": + """Freeze ``outputs``, allowing ``inputs`` to be *interior* expressions. + + Structural-matching dual of `bind`: where `bind` maps inputs to values, + this lifts chosen sub-expressions up to inputs. An ``input`` produced by + an `Apply` is matched against ``outputs`` by structure (not identity) and + every occurrence is rewired to it; root inputs behave as in the + constructor. Intermediate inputs absent from ``outputs`` become dead + inputs. The signature preserves the order of ``inputs``. ``outputs`` must + be computable from ``inputs`` alone — any root they still depend on + directly must itself appear in ``inputs``, else the rewired graph is + orphaned. + """ + # Discover the true graph roots (as FunctionGraph(inputs=None) does) to + # seed the staged freeze; the caller's `inputs` may be intermediate. + # Freezing inputs and outputs together interns each intermediate input + # onto the same object as its occurrences in the outputs, so the + # re-freeze can rewire them — which requires intermediate inputs to be + # *built*, not blocked, hence only roots seed the freeze. + roots = [ + v for v in graph_inputs([*inputs, *outputs]) if not isinstance(v, Constant) + ] + interned = cls(roots, [*inputs, *outputs]) + n_inputs = len(inputs) + return cls(interned.outputs[:n_inputs], interned.outputs[n_inputs:]) + def __reduce__(self): return FrozenFunctionGraph, (self.inputs, self.outputs) @@ -1094,7 +1125,7 @@ def bind(self, replace: dict[Variable, Variable]) -> list[Variable]: [o.type() for o in node.outputs], ) memo.update(zip(node.outputs, new_node.outputs)) - return [memo[out] for out in self.outputs] + return [out if isinstance(out, Constant) else memo[out] for out in self.outputs] def unfreeze(self) -> "FunctionGraph": """Return a mutable FunctionGraph with fresh mutable Apply nodes.""" diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index ebfa067b71..e048b4188f 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -1,3 +1,4 @@ +import inspect import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Sequence @@ -120,6 +121,24 @@ class Op(MetaObject): as nodes with these Ops must be rebuilt even if the input types haven't changed. """ + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + method = cls.__dict__.get("infer_shape") + if method is None: + return + params = inspect.signature(method).parameters + if len(params) == 4: + warnings.warn( + f"{cls.__module__}.{cls.__qualname__}.infer_shape takes a " + "deprecated `fgraph` parameter; drop it from the signature. " + "The parameter will be passed as None.", + DeprecationWarning, + stacklevel=2, + ) + cls.infer_shape = lambda self, node, input_shapes, _old=method: _old( + self, None, node, input_shapes + ) + def make_node(self, *inputs: Variable) -> Apply: """Construct an `Apply` node that represent the application of this operation to the given inputs. diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index ad309161aa..86ece4aa01 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -212,6 +212,66 @@ def toposort_key( return fg.outputs[0] +def break_aliasing_cycles( + outputs: Sequence[Variable], + destroyers_of, +) -> list[Variable]: + """Break aliasing-induced ordering cycles in ``outputs``. + + When an inplace Op overwrites input ``x`` and a single Apply reads + both ``x`` and a transitive dependent of the destroyer's output, + no valid schedule exists. This re-routes ``x`` on that Apply through + ``deep_copy_op`` to lift the ordering conflict. + """ + from pytensor.compile.ops import deep_copy_op + + EMPTY: frozenset[Variable] = frozenset() + deps: dict[Variable, frozenset[Variable]] = {} + substitutes: dict[Variable, Variable] = {} + replacements: dict[Variable, Variable] = {} + for node in toposort(outputs): + d: set[Variable] = set() + for inp in node.inputs: + d |= deps.get(inp, EMPTY) + if inp.owner is not None and inp.owner.op.destroy_map: + d.add(inp) + out_deps = frozenset(d) + for out in node.outputs: + deps[out] = out_deps + + if node.op.destroy_map: + continue + + new_inputs = list(node.inputs) + changed = False + for i, inp in enumerate(node.inputs): + inp_destroyers = destroyers_of(inp) + if not inp_destroyers: + continue + other_deps: set[Variable] = set() + for j, other_inp in enumerate(node.inputs): + if j == i: + continue + other_deps |= deps.get(other_inp, EMPTY) + if other_inp.owner is not None and other_inp.owner.op.destroy_map: + other_deps.add(other_inp) + if any( + out in other_deps for c_app in inp_destroyers for out in c_app.outputs + ): + if inp not in substitutes: + substitutes[inp] = cast(Variable, deep_copy_op(inp)) + new_inputs[i] = substitutes[inp] + changed = True + if changed: + new_node = node.op.make_node(*new_inputs) + replacements.update(zip(node.outputs, new_node.outputs, strict=True)) + + if not replacements: + return list(outputs) + + return graph_replace(list(outputs), replace=replacements) + + @singledispatch def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]: # Default implementation is provided in pytensor.tensor.blockwise diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 3e7264fcf9..31e0feba97 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -103,7 +103,7 @@ def __str__(self): args.append("inplace") return f"if{{{','.join(args)}}}" - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): # By construction, corresponding then/else pairs have the same number # of dimensions diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index c2962e4d7d..87f52c3c4b 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -137,7 +137,7 @@ def c_code(self, node, name, inames, onames, props): def c_code_cache_version(self): return (2,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def do_constant_folding(self, fgraph, node): diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 3628567b95..f963b8f512 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -2251,7 +2251,7 @@ def perform(self, node, inputs, output_storage): self.t_call = t_call self.t_fn = t_fn - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # input_shapes correspond to the shapes of node.inputs for inp, inp_shp in zip(node.inputs, input_shapes, strict=True): assert inp_shp is None or len(inp_shp) == inp.type.ndim diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 771135b1c6..a6bf82edf7 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -494,7 +494,7 @@ def pullback(self, inputs, outputs, gout): disconnected_type(), ] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): # node.inputs[3] is of length as we only support sparse matrix. return [(node.inputs[3][0], node.inputs[3][1])] @@ -584,7 +584,7 @@ def perform(self, node, inputs, outputs): g_out[0] = gout_data - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -629,7 +629,7 @@ def pullback(self, inputs, outputs, outputs_gradients): else: return [Cast(inputs[0].dtype)(gz)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return ins_shapes def __str__(self): @@ -742,7 +742,7 @@ def pullback(self, inputs, outputs, gout): else: return [SparseFromDense(x.type.format)(gz)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -806,7 +806,7 @@ def pullback(self, inputs, outputs, gout): ) return (gx,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -820,7 +820,7 @@ class GetItemList(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[1][0], shapes[0][1])] def make_node(self, x, index): @@ -865,7 +865,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItemListGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, index, gz): @@ -958,7 +958,7 @@ def pullback(self, inputs, outputs, g_outputs): class GetItem2ListsGrad(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0])] def make_node(self, x, ind1, ind2, gz): @@ -1139,7 +1139,7 @@ class GetItemScalar(Op): __props__ = () - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def make_node(self, x, index): @@ -1239,7 +1239,7 @@ def pullback(self, inputs, outputs, gout): assert _is_sparse_variable(x) and _is_sparse_variable(gz) return (transpose(gz),) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0][::-1]] @@ -1288,7 +1288,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [col_scale(gz, s), sp_sum(x * gz, axis=0)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1339,7 +1339,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [row_scale(gz, s), sp_sum(x * gz, axis=1)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1438,7 +1438,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [square_diagonal(gz)] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): return [(minimum(*shapes[0]),)] @@ -1498,7 +1498,7 @@ def perform(self, node, inputs, outputs): def pullback(self, inputs, outputs, output_grad): return [output_grad[0]] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes def __str__(self): @@ -1614,7 +1614,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[1] for shape in ins_shapes) return [(ins_shapes[0][0], d)] @@ -1711,7 +1711,7 @@ def choose(continuous, derivative): return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): d = sum(shape[0] for shape in ins_shapes) return [(d, ins_shapes[0][1])] @@ -1800,7 +1800,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return [gz] - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): return i0_shapes @@ -1880,7 +1880,7 @@ def perform(self, node, inp, out_): (data, indices, indptr), shape=out_shape, dtype=values.dtype ) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x = node.inputs[0] return [[x[0], x[1]]] diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index 57fe280f88..d59915ed45 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -325,7 +325,7 @@ def pullback(self, inputs, outputs, gout): r = psb.SparseFromDense(o_format)(r) return [r] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): r = None if self.axis is None: r = [()] @@ -404,7 +404,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -465,7 +465,7 @@ def pullback(self, inputs, outputs, gout): derivative = {True: gz, False: None} return [derivative[b] for b in is_continuous] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -507,7 +507,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_dense_variable(gz) return psb.sp_ones_like(x) * gz, gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] @@ -567,7 +567,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return gz, sp_sum(gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -697,7 +697,7 @@ def pullback(self, inputs, outputs, gout): (gz,) = gout return y * gz, x * gz - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -786,7 +786,7 @@ def pullback(self, inputs, outputs, gout): assert psb._is_sparse_variable(gz) return y * gz, psb.dense_from_sparse(x * gz) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -869,7 +869,7 @@ def pullback(self, inputs, outputs, gout): return mul_s_v(gz, y), sp_sum(x * gz, axis=0, sparse_grad=True) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -987,7 +987,7 @@ def perform(self, node, inputs, outputs): self.comparison(x, y).astype("uint8").asformat(node.outputs[0].type.format) ) - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1032,7 +1032,7 @@ def perform(self, node, inputs, outputs): o = np.asarray(o) out[0] = o - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[0]] @@ -1282,7 +1282,7 @@ def pullback(self, inputs, outputs, gout): rval[1] = psb.dense_from_sparse(rval[1]) return rval - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1410,7 +1410,7 @@ def pullback(self, inputs, outputs, gout): (g_out,) = gout return [structured_dot_grad(a, b, g_out), structured_dot(a.T, g_out)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(shapes[0][0], shapes[1][1])] @@ -1594,7 +1594,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1729,7 +1729,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] @@ -1827,7 +1827,7 @@ def pullback(self, inputs, outputs, gout): return rval - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): return [ins_shapes[2]] @@ -1840,7 +1840,7 @@ class Dot(Op): def __str__(self): return "Sparse" + self.__class__.__name__ - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes x, y = node.inputs if x.ndim == 2 and y.ndim == 2: diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index 2866f1aec0..1a2c75a6a8 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -176,7 +176,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ return code - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[3]] def c_code_cache_version(self): diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 7d690a9e9c..48528e7696 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -637,7 +637,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(s) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -698,7 +698,7 @@ def perform(self, node, inputs, output_storage): # not using .item() because that returns a Python scalar, not a numpy scalar output_storage[0][0] = inputs[0][()] - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [()] def pullback(self, inp, outputs, grads): @@ -1379,7 +1379,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.eye(n, m, k, dtype=self.dtype) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): out_shape = [node.inputs[0], node.inputs[1]] return [out_shape] @@ -1708,7 +1708,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (5,) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs[1:]] def connection_pattern(self, node): @@ -1762,14 +1762,32 @@ def do_constant_folding(self, fgraph, node): if not clients: return False + from pytensor.tensor.blas import Gemv, Ger + from pytensor.tensor.blas_c import CGemv, CGer + from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + IncSubtensor, + Subtensor, + ) + for client, idx in clients: client_op = client.op if isinstance(client_op, Output): # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False - # Op's through which Alloc can be lifted - elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join): + # Op's through which Alloc can be lifted. ``Subtensor`` is + # included because ``local_subtensor_of_alloc`` rewrites + # ``alloc(val, *shape)[idx]`` into ``alloc(val[...], *new_shape)``, + # preserving the Alloc structure that downstream rewrites + # (e.g. ``local_blockwise_alloc_inputs``) rely on. Folding the + # Alloc here would short-circuit that lift and produce a + # broadcast-equivalent constant whose batch dim is no longer + # type-broadcastable. + elif isinstance( + client_op, Elemwise | DimShuffle | Alloc | Join | Subtensor + ): return False # Same for Blockwise, unless it has no batch_dims elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client): @@ -1779,13 +1797,13 @@ def do_constant_folding(self, fgraph, node): idx == 0 and isinstance( client_op, - pytensor.tensor.subtensor.IncSubtensor - | pytensor.tensor.subtensor.AdvancedIncSubtensor1 - | pytensor.tensor.subtensor.AdvancedIncSubtensor - | pytensor.tensor.blas.Gemv - | pytensor.tensor.blas_c.CGemv - | pytensor.tensor.blas.Ger - | pytensor.tensor.blas_c.CGer, + IncSubtensor + | AdvancedIncSubtensor1 + | AdvancedIncSubtensor + | Gemv + | CGemv + | Ger + | CGer, ) ): # Ops that will work inplace on the Alloc. So if they @@ -1966,7 +1984,7 @@ def c_code(self, node, name, inp, out_, props): """ return ret - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [(len(ishapes),)] def pullback(self, inputs, outputs, output_gradients): @@ -2254,7 +2272,7 @@ def perform(self, node, inputs, outputs_storage): for out_storage, out in zip(outputs_storage, split_outs, strict=False): out_storage[0] = out - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): axis = node.inputs[1] splits = node.inputs[2] shp_x, _shp_axis, _shp_splits = in_shapes @@ -2710,7 +2728,7 @@ def pullback(self, inputs, outputs, grads): return rval - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, ge # ishapes[0] contains the size of the axis on which we join @@ -3264,7 +3282,7 @@ def make_node(self, start, stop, step): return Apply(self, inputs, outputs) @config.change_flags(warn_float64="ignore") - def infer_shape(self, fgraph, node, i_shapes): + def infer_shape(self, node, i_shapes): from pytensor.tensor.math import ceil, maximum # Note start, stop and step can be float numbers. @@ -3641,7 +3659,7 @@ def perform(self, node, inp, out): self._rec_perform(node, x, y, self.inverse, outs[0], curdim=0) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): from pytensor.tensor.math import maximum shp_x = in_shapes[0] @@ -3893,7 +3911,7 @@ def pullback(self, inputs, outputs, gout): x_grad = moveaxis(x_grad, (0, 1), (axis1, axis2)) return [x_grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): from pytensor.tensor.math import clip, minimum (in_shape,) = shapes @@ -4225,7 +4243,7 @@ def __init__(self, mode): assert mode in ("raise", "wrap", "clip") self.mode = mode - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): a_shape, choices_shape = shapes if choices_shape is None: # choices is a TypedList, not a tensor; no shape to broadcast @@ -4256,9 +4274,7 @@ def make_node(self, a, choices): choice = as_tensor_variable(choices) choice_dtype = choice.dtype - (out_shape,) = self.infer_shape( - None, None, [shape_tuple(a), shape_tuple(choice)] - ) + (out_shape,) = self.infer_shape(None, [shape_tuple(a), shape_tuple(choice)]) static_out_shape = () for s in out_shape: @@ -4361,7 +4377,7 @@ def c_code(self, node, name, inputs, out_, sub): """ return str - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [node.inputs] def c_code_cache_version(self): diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 0b0dcdfc2d..edb70ac0e9 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -243,7 +243,7 @@ def perform(self, node, inputs, out_storage): out += y out_storage[0][0] = np.asarray(out, dtype=y.dtype) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -316,7 +316,7 @@ def perform(self, node, inputs, output_storage): A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) output_storage[0][0] = A - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] @@ -941,7 +941,7 @@ def perform(self, node, inp, out): z += a * np.dot(x, y) zout[0] = z - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): z_shape, _, x_shape, y_shape, _ = input_shapes return [ ( @@ -1146,7 +1146,7 @@ def make_node(self, x, y): def perform(self, node, inputs, output_storage): output_storage[0][0] = np.dot(*inputs) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = """ @@ -1249,7 +1249,7 @@ def perform(self, node, inp, out): e.args = (*e.args, x.shape, y.shape) raise - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [[input_shapes[0][0], input_shapes[1][1]]] setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz @@ -1638,7 +1638,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [xshp[:-1] + yshp[2:]] diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index ecc4ad92d1..2c8a2c99bc 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -6,7 +6,6 @@ from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType -from pytensor.graph import FunctionGraph from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.null_type import NullType from pytensor.graph.op import Op @@ -321,9 +320,7 @@ def make_node(self, *inputs): def batch_ndim(self, node: Apply) -> int: return cast(int, node.outputs[0].type.ndim - len(self.outputs_sig[0])) - def infer_shape( - self, fgraph, node, input_shapes - ) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, input_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor import broadcast_shape from pytensor.tensor.shape import Shape_i @@ -354,13 +351,10 @@ def extract_core_shape_from_infer_shape(): return_dummy_inputs=True, propagate_unbatched_core_inputs=True, ) - dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False) core_input_shapes = [ input_shape[batch_ndims:] for input_shape in input_shapes ] - core_output_shapes = core_op_infer_shape( - dummy_fgraph, dummy_core_node, core_input_shapes - ) + core_output_shapes = core_op_infer_shape(dummy_core_node, core_input_shapes) if not dummy_core_inputs: # All inputs are unbatched, so the core_shape can be used as is diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 1196a5ca77..f9c6c2b2db 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -257,7 +257,7 @@ def perform(self, node, inp, out): new_shape.insert(augm, 1) out[0][0] = res.reshape(new_shape) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishp,) = shapes # transpose rval = [ishp[i] for i in self.shuffle] @@ -755,7 +755,7 @@ def _check_runtime_broadcast(node, inputs): "If broadcasting was intended, use `specify_broadcastable` on the relevant input." ) - def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...]]: + def infer_shape(self, node, i_shapes) -> list[tuple[TensorVariable, ...]]: from pytensor.tensor.extra_ops import broadcast_shape out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True) @@ -1426,7 +1426,7 @@ def perform(self, node, inp, out): output[0] = np.asarray(out, dtype=out_dtype) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes axis = self.axis if axis is None: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 61a8003cfd..34542e9bdb 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -150,7 +150,7 @@ def make_node(self, x, v, sorter=None): raise TypeError("sorter must be an integer vector", sorter.type) return Apply(self, [x, v, sorter], [out_type()]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[1]] def perform(self, node, inputs, output_storage): @@ -340,7 +340,7 @@ def pullback(self, inputs, outputs, output_gradients): f'{type(self).__name__}: unknown gradient for mode "{self.mode}"' ) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes def c_code(self, node, name, inames, onames, sub): @@ -717,7 +717,7 @@ def pullback(self, inputs, outputs, gout): return [gx, disconnected_type()] - def infer_shape(self, fgraph, node, ins_shapes): + def infer_shape(self, node, ins_shapes): i0_shapes = ins_shapes[0] repeats = node.inputs[1] out_shape = list(i0_shapes) @@ -849,7 +849,7 @@ def perform(self, node, inputs, out_): (out,) = out_ out[0] = np.bartlett(M) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): temp = node.inputs[0] M = ptb.switch(lt(temp, 0), ptb.cast(0, temp.dtype), temp) return [[M]] @@ -892,7 +892,7 @@ class FillDiagonal(Op): # See function fill_diagonal for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val): @@ -993,7 +993,7 @@ class FillDiagonalOffset(Op): # See function fill_diagonal_offset for docstring __props__ = () - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [in_shapes[0]] def make_node(self, a, val, offset): @@ -1240,7 +1240,7 @@ def perform(self, node, inputs, output_storage): else: output_storage[0][0] = outs - def infer_shape(self, fgraph, node, i0_shapes): + def infer_shape(self, node, i0_shapes): [x_shape] = i0_shapes shape0_op = Shape_i(0) out_shapes = [(shape0_op(out),) for out in node.outputs] @@ -1310,7 +1310,7 @@ def make_node(self, indices, dims): [out_type() for _i in range(ptb.get_vector_length(dims))], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] * len(node.outputs) def perform(self, node, inp, out): @@ -1387,7 +1387,7 @@ def make_node(self, *inp): [out_type()], ) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [input_shapes[0]] def perform(self, node, inp, out): diff --git a/pytensor/tensor/fourier.py b/pytensor/tensor/fourier.py index 9d35955c6f..260fc2dc09 100644 --- a/pytensor/tensor/fourier.py +++ b/pytensor/tensor/fourier.py @@ -100,7 +100,7 @@ def make_node(self, a, n, axis): ], ) - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): shape_a = in_shapes[0] n = node.inputs[1] axis = node.inputs[2] diff --git a/pytensor/tensor/linalg/constructors.py b/pytensor/tensor/linalg/constructors.py index 96ced60f9e..fbcc6d486a 100644 --- a/pytensor/tensor/linalg/constructors.py +++ b/pytensor/tensor/linalg/constructors.py @@ -34,7 +34,7 @@ def pullback(self, inputs, outputs, gout): ] return [gout[0][slc] for slc in slices] - def infer_shape(self, fgraph, nodes, shapes): + def infer_shape(self, nodes, shapes): first, second = unzip(shapes, n=2, strict=True) return [(pt.add(*first), pt.add(*second))] diff --git a/pytensor/tensor/linalg/decomposition/cholesky.py b/pytensor/tensor/linalg/decomposition/cholesky.py index 0fa7a34b3b..556ead9bf2 100644 --- a/pytensor/tensor/linalg/decomposition/cholesky.py +++ b/pytensor/tensor/linalg/decomposition/cholesky.py @@ -33,7 +33,7 @@ def __init__( if self.overwrite_a: self.destroy_map = {0: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def make_node(self, x): diff --git a/pytensor/tensor/linalg/decomposition/eigen.py b/pytensor/tensor/linalg/decomposition/eigen.py index f2afe7cf39..ba617d11f2 100644 --- a/pytensor/tensor/linalg/decomposition/eigen.py +++ b/pytensor/tensor/linalg/decomposition/eigen.py @@ -65,7 +65,7 @@ def perform(self, node, inputs, outputs): outputs[0][0] = w.astype(dtype, copy=False) outputs[1][0] = v.astype(dtype, copy=False) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shapes,) = shapes n, _ = x_shapes @@ -206,7 +206,7 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": return self return type(self)(**new_props) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n,), (n, n)] @@ -416,7 +416,7 @@ def make_node(self, a, b=None): w = vector(dtype=out_dtype, shape=(N,)) return Apply(self, inputs, [w]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [ (n,), diff --git a/pytensor/tensor/linalg/decomposition/lu.py b/pytensor/tensor/linalg/decomposition/lu.py index 2b4edb621b..db8ac60fe5 100644 --- a/pytensor/tensor/linalg/decomposition/lu.py +++ b/pytensor/tensor/linalg/decomposition/lu.py @@ -42,7 +42,7 @@ def __init__(self, *, permute_l=False, overwrite_a=False, p_indices=False): if self.overwrite_a: self.destroy_map = {0: [0]} if self.permute_l else {1: [0]} - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] if self.permute_l: return [(n, n), (n, n)] @@ -258,7 +258,7 @@ def make_node(self, A): return Apply(self, [A], [LU, pivots]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): n = shapes[0][0] return [(n, n), (n,)] diff --git a/pytensor/tensor/linalg/decomposition/qr.py b/pytensor/tensor/linalg/decomposition/qr.py index 9e9270259d..9d81e8eb7b 100644 --- a/pytensor/tensor/linalg/decomposition/qr.py +++ b/pytensor/tensor/linalg/decomposition/qr.py @@ -103,7 +103,7 @@ def make_node(self, x): return Apply(self, [x], outputs) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape diff --git a/pytensor/tensor/linalg/decomposition/schur.py b/pytensor/tensor/linalg/decomposition/schur.py index 737ca9dc1d..af3e2035ea 100644 --- a/pytensor/tensor/linalg/decomposition/schur.py +++ b/pytensor/tensor/linalg/decomposition/schur.py @@ -146,7 +146,7 @@ def perform(self, node, inputs, outputs): T_out[0] = T Z_out[0] = Z - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0], shapes[0]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": @@ -489,7 +489,7 @@ def perform(self, node, inputs, outputs): alpha_out[0] = alpha beta_out[0] = beta - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): A_shape, B_shape = shapes if self.return_eigenvalues: return [A_shape, B_shape, (A_shape[0],), (A_shape[0],), A_shape, B_shape] diff --git a/pytensor/tensor/linalg/decomposition/svd.py b/pytensor/tensor/linalg/decomposition/svd.py index 584eae1419..361b5cdd34 100644 --- a/pytensor/tensor/linalg/decomposition/svd.py +++ b/pytensor/tensor/linalg/decomposition/svd.py @@ -93,7 +93,7 @@ def perform(self, node, inputs, outputs): (s,) = outputs s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (x_shape,) = shapes M, N = x_shape K = ptm.minimum(M, N) diff --git a/pytensor/tensor/linalg/inverse.py b/pytensor/tensor/linalg/inverse.py index 82878c8777..6c205d6ae7 100644 --- a/pytensor/tensor/linalg/inverse.py +++ b/pytensor/tensor/linalg/inverse.py @@ -61,7 +61,7 @@ def pullback(self, inputs, outputs, g_outputs): ).T return [grad] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [list(reversed(shapes[0]))] @@ -159,7 +159,7 @@ def pushforward(self, inputs, outputs, eval_points): return [-matrix_dot(xi, ev, xi)] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return shapes @@ -187,7 +187,7 @@ def perform(self, node, inputs, outputs): (x,) = outputs x[0] = np.linalg.tensorinv(a, self.ind) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): sp = shapes[0][self.ind :] + shapes[0][: self.ind] return [sp] diff --git a/pytensor/tensor/linalg/products.py b/pytensor/tensor/linalg/products.py index 42d49b10f4..542b174c4c 100644 --- a/pytensor/tensor/linalg/products.py +++ b/pytensor/tensor/linalg/products.py @@ -74,7 +74,7 @@ def pullback(self, inputs, outputs, output_grads): return [expm(aug)[..., :n, n:]] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] diff --git a/pytensor/tensor/linalg/solvers/core.py b/pytensor/tensor/linalg/solvers/core.py index 9805d86c8f..39a46e84a8 100644 --- a/pytensor/tensor/linalg/solvers/core.py +++ b/pytensor/tensor/linalg/solvers/core.py @@ -71,7 +71,7 @@ def make_node(self, A, b): x = tensor(dtype=o_dtype, shape=b.type.shape) return Apply(self, [A, b], [x]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): Ashape, Bshape = shapes rows = Ashape[1] if len(Bshape) == 1: diff --git a/pytensor/tensor/linalg/solvers/linear_control.py b/pytensor/tensor/linalg/solvers/linear_control.py index 14478f8e03..2fb47fac34 100644 --- a/pytensor/tensor/linalg/solvers/linear_control.py +++ b/pytensor/tensor/linalg/solvers/linear_control.py @@ -82,7 +82,7 @@ def perform(self, node, inputs, outputs_storage): Y *= scale X[0] = Y - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[2]] def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": diff --git a/pytensor/tensor/linalg/summary.py b/pytensor/tensor/linalg/summary.py index bba599e17f..c76753885f 100644 --- a/pytensor/tensor/linalg/summary.py +++ b/pytensor/tensor/linalg/summary.py @@ -71,7 +71,7 @@ def pullback(self, inputs, outputs, g_outputs): (x,) = inputs return [gz * self(x) * matrix_inverse(x).T] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [()] def __str__(self): @@ -106,7 +106,7 @@ def perform(self, node, inputs, outputs): except Exception as e: raise ValueError("Failed to compute determinant", x) from e - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [(), ()] def __str__(self): diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 139d22529d..7e93db8c01 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -254,7 +254,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): return (3,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): (ishape,) = shapes if self.axis is None: return [()] @@ -3106,7 +3106,7 @@ def pushforward(self, inputs, outputs, eval_points): else: return [t2] - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshp, yshp = shapes return [[xshp[0], yshp[1]]] diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 34479e142c..7201d631e8 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -306,7 +306,7 @@ def extract_batch_shape(p, ps, n): return shape - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): _, size, *dist_params = node.inputs _, _, *param_shapes = input_shapes diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py index b658cc98cb..ac01e83ef6 100644 --- a/pytensor/tensor/random/rewriting/basic.py +++ b/pytensor/tensor/random/rewriting/basic.py @@ -274,7 +274,7 @@ def is_nd_advanced_idx(idx, dtype) -> bool: # Use shape_feature to facilitate inferring final shape. # Check that neither the RV nor the old Subtensor are in the shape graph. - output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None) + output_shape = shape_feature.shape_tuple(indexed_rv) if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)): return None diff --git a/pytensor/tensor/random/rewriting/numba.py b/pytensor/tensor/random/rewriting/numba.py index 8d128ec698..eb4607d7c3 100644 --- a/pytensor/tensor/random/rewriting/numba.py +++ b/pytensor/tensor/random/rewriting/numba.py @@ -60,7 +60,8 @@ def introduce_explicit_core_shape_rv(fgraph, node): shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) if shape_feature: core_shape = [ - shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp)) + shape_feature.get_shape_no_cycle(rv, -i - 1) + for i in reversed(range(op.ndim_supp)) ] else: core_shape = op._supp_shape_from_params(op.dist_params(node)) diff --git a/pytensor/tensor/reshape.py b/pytensor/tensor/reshape.py index b25308fccc..b43245d31e 100644 --- a/pytensor/tensor/reshape.py +++ b/pytensor/tensor/reshape.py @@ -67,7 +67,7 @@ def make_node(self, x: Variable) -> Apply: # type: ignore[override] return Apply(self, [x], [output_type]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape] = shapes joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int) return [self.output_shapes(input_shape, joined_shape)] @@ -188,7 +188,7 @@ def make_node(self, x, shape): ) return Apply(self, [x, shape], [output]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): [input_shape, _] = shapes _, shape = node.inputs output_shapes = list(input_shape) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 27136c8ef4..e1451d044a 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -124,8 +124,8 @@ def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool: def get_simplified_shape(x: TensorVariable, *, fgraph) -> tuple: """Return a simplified shape tuple for ``x``: shape_feature → static → ``x.shape``.""" try: - return fgraph.shape_feature.shape_of[x] - except (AttributeError, KeyError): + return fgraph.shape_feature.shape_tuple(x) + except AttributeError: pass static_shape = x.type.shape diff --git a/pytensor/tensor/rewriting/numba.py b/pytensor/tensor/rewriting/numba.py index 801878bd11..1df89fef15 100644 --- a/pytensor/tensor/rewriting/numba.py +++ b/pytensor/tensor/rewriting/numba.py @@ -99,14 +99,16 @@ def introduce_explicit_core_shape_blockwise(fgraph, node): shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None) if shape_feature: core_shapes = [ - [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)] + [ + shape_feature.get_shape_no_cycle(out, i) + for i in range(batch_ndim, out.type.ndim) + ] for out in node.outputs ] else: input_shapes = [tuple(inp.shape) for inp in node.inputs] core_shapes = [ - out_shape[batch_ndim:] - for out_shape in op.infer_shape(None, node, input_shapes) + out_shape[batch_ndim:] for out_shape in op.infer_shape(node, input_shapes) ] core_shapes = [ diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 2b2060c3e7..7d54d9af91 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -1,6 +1,3 @@ -import traceback -from io import StringIO -from typing import cast as type_cast from warnings import warn import numpy as np @@ -15,8 +12,6 @@ copy_stack_trace, node_rewriter, ) -from pytensor.graph.traversal import ancestors -from pytensor.graph.utils import InconsistencyError, get_variable_trace_string from pytensor.tensor.basic import ( Alloc, MakeVector, @@ -30,13 +25,12 @@ stack, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.exceptions import NotScalarConstantError, ShapeError +from pytensor.tensor.exceptions import ShapeError from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, register_stabilize, register_useless, - topo_constant_folding, ) from pytensor.tensor.shape import ( Reshape, @@ -50,616 +44,226 @@ AdvancedIncSubtensor1, IncSubtensor, Subtensor, - get_idx_list, ) -from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes +from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.variable import TensorVariable -class ShapeFeature(Feature): - r"""A `Feature` that tracks shape information in a graph. - - This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with - `Shape_i` and `MakeVector` `Op`\s. - - This `Feature` and its associated rewrites have several goals: - - 1. to "lift" `Shape`\s to as close to the inputs as possible, - 2. to infer the shape of every node in the graph in terms of the - input shapes, and - 3. remove fill `Op`\s (e.g. `Second`) from the graph. - - Lifting shapes as close to the inputs as possible is important for - canonicalization because it is very bad form to have to compute - something just to know how big it will be. Firstly, it is a waste - of time to compute such outputs. But it is important to get rid - of these outputs as early as possible in the compilation process - because the extra computations make it appear as if many internal - graph nodes have multiple clients. Many rewrites refuse to - work on nodes with multiple clients. - - Lifting is done by using an `.infer_shape` function if one is - present, or else using a conservative default. An Op that - supports shape-lifting should define a infer_shape(self, fgraph, node, - input_shapes) function. The argument input_shapes is a tuple of - tuples... there is an interior tuple for each input to the node. - The tuple has as many elements as dimensions. The element in - position i of tuple j represents the i'th shape component of the - j'th input. The function should return a tuple of tuples. One - output tuple for each node.output. Again, the i'th element of the - j'th output tuple represents the output[j].shape[i] of the - function. If an output is not a TensorType, then None should be - returned instead of a tuple for that output. - - For example the infer_shape for a matrix-matrix product would accept - input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). - - Inferring the shape of internal nodes in the graph is important - for doing size-driven rewrites. If we know how big various - intermediate results will be, we can estimate the cost of many Ops - accurately, and generate c-code that is specific [e.g. unrolled] - to particular sizes. - - In cases where you cannot figure out the shape, raise a ShapeError. - - Notes - ----- - To use this shape information in rewrites, use the - ``shape_of`` dictionary. +class _ShapeOfProxy: + """Dict-like proxy so ``shape_feature.shape_of[var]`` keeps working.""" - For example: - - .. code-block:: python + def __init__(self, feature): + self._feature = feature - try: - shape_of = fgraph.shape_feature.shape_of - except AttributeError: - # This can happen when the mode doesn't include the ShapeFeature. - return + def __getitem__(self, var): + result = self._feature.shape_tuple(var) + if result is None: + raise KeyError(var) + return result - shape_of_output_zero = shape_of[node.output[0]] + def __contains__(self, var): + return hasattr(var.type, "ndim") - The ``shape_of_output_zero`` symbol will contain a tuple, whose - elements are either integers or symbolic integers. - TODO: check to see if the symbols are necessarily - non-constant... or are integer literals sometimes PyTensor - constants?? That would be confusing. +class ShapeFeature(Feature): + r"""Lazy `Feature` that provides shape information via ``get_shape``. + Shapes are derived on demand by calling ``Op.infer_shape`` on the + current (live) node inputs — never cached across graph mutations. + This prevents stale variable references from being reintroduced + when intermediate nodes are replaced (e.g. xtensor lowering). """ - def get_node_infer_shape(self, node): - try: - shape_infer = node.op.infer_shape - except AttributeError: - shape_infer = self.default_infer_shape - - try: - o_shapes = shape_infer( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - except ShapeError: - o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - except NotImplementedError as e: - raise NotImplementedError( - "Code called by infer_shape failed raising a " - "NotImplementedError. Raising NotImplementedError to " - "indicate that a shape cannot be computed is no longer " - "supported, and one should now use ShapeError " - f"instead. The original exception message is: {e}" - ).with_traceback(e.__traceback__) - except Exception as e: - msg = ( - f"Failed to infer_shape from Op {node.op}.\nInput shapes: " - f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: " - f"{type(e)}\nException message: {e!s}\nTraceback: {traceback.format_exc()}" - ) - if config.on_shape_error == "raise": - raise Exception(msg).with_traceback(e.__traceback__) - else: - warn(msg) - o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - - return o_shapes - - def get_shape(self, var, idx): - """Rewrites can call this to get a `Shape_i`. - - It is better to call this then use directly ``shape_of[var][idx]`` - as this method should update `shape_of` if needed. - - TODO: Up to now, we don't update it in all cases. Update in all cases. - """ - r = self.shape_of[var][idx] - if ( - r.owner - and isinstance(r.owner.op, Shape_i) - and r.owner.inputs[0] not in self.fgraph.variables - ): - assert var.owner - node = var.owner - # recur on inputs - for i in node.inputs: - if getattr(i.type, "ndim", None) > 0: - self.get_shape(i, 0) - o_shapes = self.get_node_infer_shape(node) - assert len(o_shapes) == len(node.outputs) - - # Only change the variables and dimensions that would introduce - # extra computation - for new_shps, out in zip(o_shapes, node.outputs, strict=True): - if not hasattr(out.type, "ndim"): - continue - - merged_shps = list(self.shape_of[out]) - changed = False - for i in range(out.type.ndim): - n_r = merged_shps[i] - if ( - n_r.owner - and isinstance(n_r.owner.op, Shape_i) - and n_r.owner.inputs[0] not in self.fgraph.variables - ): - changed = True - merged_shps[i] = new_shps[i] - if changed: - self.set_shape(out, merged_shps, override=True) - r = self.shape_of[var][idx] - return r - - def shape_ir(self, i, r): - """Return symbolic r.shape[i] for tensor variable r, int i.""" - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - return constant(r.type.shape[i], dtype="int64") + def __init__(self): + self.fgraph: FunctionGraph | None = None + # node -> tuple of (tuple of shape vars) per output, lazily populated + self._cache: dict = {} + # var -> {i: Shape_i(i)(v)}, ensures Apply identity for leaves + self._shape_i_cache: dict = {} + self.lscalar_one = constant(1, dtype="int64") + # Compat: scheduled replacements for local_track_shape_i + self.scheduled: dict = {} + + def _shape_i_var(self, v, i): + per_dim = self._shape_i_cache.get(v) + if per_dim is not None: + cached = per_dim.get(i) + if cached is not None: + return cached else: - s = Shape_i(i)(r) - try: - s = get_scalar_constant_value(s) - except NotScalarConstantError: - pass + per_dim = {} + self._shape_i_cache[v] = per_dim + if hasattr(v.type, "shape") and v.type.shape[i] is not None: + res = constant(v.type.shape[i], dtype="int64") + else: + res = Shape_i(i)(v) + per_dim[i] = res + return res + + def _coerce_shape_el(self, s, node): + """Validate and normalize a single shape element from infer_shape.""" + if isinstance(s, np.ndarray): + if s.ndim != 0: + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"ndarray for shape element: {s!r}" + ) + s = s.item() + if isinstance(s, Variable): + if s.type.dtype not in integer_dtypes: + raise TypeError( + f"infer_shape for {node.op} returned a non-integer " + f"Variable for shape element: {s!r}" + ) + if getattr(s.type, "ndim", 0): + raise TypeError( + f"infer_shape for {node.op} returned a non-scalar " + f"Variable for shape element: {s!r}" + ) + if s.type.dtype != "int64": + if isinstance(s, Constant): + return constant(int(s.data), dtype="int64") + return cast(s, "int64") return s + if isinstance(s, int | np.integer): + if int(s) < 0: + raise ValueError( + f"infer_shape for {node.op} returned a negative shape: {int(s)}" + ) + return constant(int(s), dtype="int64") + raise TypeError( + f"infer_shape for {node.op} returned an unsupported " + f"shape element of type {type(s).__name__}: {s!r}" + ) - def shape_tuple(self, r): - """Return a tuple of symbolic shape vars for tensor variable r.""" - if not hasattr(r.type, "ndim"): - # This happen for NoneConst. - return None - return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) - - def default_infer_shape(self, fgraph, node, i_shapes): - """Return a list of shape tuple or None for the outputs of node. - - This function is used for Ops that don't implement infer_shape. - Ops that do implement infer_shape should use the i_shapes parameter, - but this default implementation ignores it. - - """ - rval = [] - for r in node.outputs: - try: - rval.append(self.shape_tuple(r)) - except AttributeError: - rval.append(None) - return rval - - def unpack(self, s_i, var): - """Return a symbolic integer scalar for the shape element s_i. - - The s_i argument was produced by the infer_shape() of an Op subclass. - - var: the variable that correspond to s_i. This is just for - error reporting. - - """ - assert s_i is not None + def _get_node_shapes(self, node): + """Call infer_shape and return validated per-output shape tuples.""" + cached = self._cache.get(node) + if cached is not None: + return cached + + input_shapes = [] + for inp in node.inputs: + if hasattr(inp.type, "ndim"): + input_shapes.append( + tuple(self.get_shape(inp, j) for j in range(inp.type.ndim)) + ) + else: + input_shapes.append(None) - if s_i == 1: - return self.lscalar_one - if isinstance(s_i, float) and int(s_i) == s_i: - s_i = int(s_i) - if isinstance(s_i, np.integer | int) or ( - isinstance(s_i, np.ndarray) and s_i.ndim == 0 - ): - # this shape is a constant - if s_i < 0: - msg = "There is a negative shape in the graph!" - msg += get_variable_trace_string(var) - # The rest of the pipeline don't handle correctly this - # case. So we have 2 choices, stop compilation or - # consider the shape as unknown. As we have more - # chance to give the stack trace here then later, I - # choose that options as it would give better error - # message. - raise AssertionError(msg) - return constant(s_i, dtype="int64") - if isinstance(s_i, tuple | list): - # this dimension is the same as many of the inputs - # which tells us that if one of the inputs is known, - # the others all become known. - # TODO: should be implemented in Elemwise, and Dot - # - # worst case, we loop over shape_of and replace things - raise NotImplementedError(s_i) - - # s_i is x.shape[i] for some x, we change it to shape_of[x][i] - if ( - s_i.owner - and isinstance(s_i.owner.op, Subtensor) - and s_i.owner.inputs[0].owner - and isinstance(s_i.owner.inputs[0].owner.op, Shape) - ): - assert s_i.type.ndim == 0 - assert len(s_i.owner.op.idx_list) == 1 - - # The current Subtensor always put constant index in the graph. - # This was not True in the past. So call the Subtensor function - # that will return the right index. - idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list) - assert len(idx) == 1 - idx = idx[0] + output_shapes = None + shape_infer = getattr(node.op, "infer_shape", None) + if shape_infer is not None: try: - i = get_scalar_constant_value(idx) - except NotScalarConstantError: + output_shapes = shape_infer(node, input_shapes) + except ShapeError: pass - else: - # Executed only if no exception was raised - x = s_i.owner.inputs[0].owner.inputs[0] - # x should already have been imported, and should be in shape_of. - s_i = self.shape_of[x][i] - - if s_i.type.dtype in integer_dtypes: - if getattr(s_i.type, "ndim", 0): - raise TypeError("Shape element must be scalar", s_i) - return s_i - else: - raise TypeError( - "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None) - ) - - def set_shape(self, r, s, override=False): - """Assign the shape `s` to previously un-shaped variable `r`. - - Parameters - ---------- - r : a variable - s : None or a tuple of symbolic integers - override : If False, it mean r is a new object in the fgraph. - If True, it mean r is already in the fgraph and we want to - override its shape. - - """ - if not override: - assert r not in self.shape_of, "r already in shape_of" - if s is None: - self.shape_of[r] = s - else: - if not isinstance(s, tuple | list): - raise TypeError("shapes must be tuple/list", (r, s)) - - if r.type.ndim != len(s): - sio = StringIO() - pytensor.printing.debugprint(r, file=sio, print_type=True) - raise AssertionError( - f"Something inferred a shape with {len(s)} dimensions " - f"for a variable with {int(r.type.ndim)} dimensions" - f" for the variable:\n{sio.getvalue()}" + except NotImplementedError: + pass + except Exception as exc: + if config.on_shape_error == "raise": + raise + warn( + f"Failed to infer_shape from Op {node.op}: " + f"{type(exc).__name__}: {exc}" ) - shape_vars = [] - for i in range(r.type.ndim): - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - shape_vars.append(constant(r.type.shape[i], dtype="int64")) - else: - shape_vars.append(self.unpack(s[i], r)) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[i] != 1 - or self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals( - get_scalar_constant_value(shape_vars[i], raise_not_constant=False) + result = [] + for k, out in enumerate(node.outputs): + if not hasattr(out.type, "ndim"): + result.append(None) + continue + sh = None + if output_shapes is not None and k < len(output_shapes): + sh = output_shapes[k] + if sh is None or not isinstance(sh, list | tuple): + result.append( + tuple(self._shape_i_var(out, j) for j in range(out.type.ndim)) ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(shape_vars) - for sv in shape_vars: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def update_shape(self, r, other_r): - """Replace shape of r by shape of other_r. - - If, on some dimensions, the shape of other_r is not informative, - keep the shape of r on those dimensions. - - """ - # other_r should already have a shape - assert other_r in self.shape_of, ("other_r not in shape_of", other_r) - other_shape = self.shape_of[other_r] + continue + coerced = [] + for j, s in enumerate(sh): + coerced.append(self._coerce_shape_el(s, node)) + result.append(tuple(coerced)) - # If other_shape has no information, call is pointless. - if other_shape is None: - return + result = tuple(result) + self._cache[node] = result + return result - if r in self.shape_of: - r_shape = self.shape_of[r] - else: - # If no info is known on r's shape, use other_shape - self.set_shape(r, other_shape) - return - if ( - other_r.owner - and r.owner - and other_r.owner.inputs == r.owner.inputs - and other_r.owner.op == r.owner.op - ): - # We are doing a merge, so the two shape graphs will be the - # same. This is only done so that we call `ancestors` less - # frequently. - return - - # Merge other_shape with r_shape, giving the priority to other_shape - merged_shape = [] - for i, ps in enumerate(other_shape): - if r_shape is None and other_shape: - merged_shape.append(other_shape[i]) - elif ( - ps.owner - and isinstance(ps.owner.op, Shape_i) - and ps.owner.op.i == i - and ps.owner.inputs[0] in (r, other_r) - ): - # If other_shape[i] is uninformative, use r_shape[i]. - # For now, we consider 2 cases of uninformative other_shape[i]: - # - Shape_i(i)(other_r); - # - Shape_i(i)(r). - merged_shape.append(r_shape[i]) - elif isinstance(r_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(r_shape[i]) - elif isinstance(other_shape[i], Constant | int): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(other_shape[i]) - elif other_shape[i] == r_shape[i]: - # This mean the shape is equivalent - # We do not want to do the ancestor check in those cases - merged_shape.append(r_shape[i]) - elif any( - ( - r_shape[i] == anc - or ( - anc.owner - and isinstance(anc.owner.op, Shape) - and anc.owner.inputs[0] == r - ) - ) - for anc in ancestors([other_shape[i]]) - ): - # Another case where we want to use r_shape[i] is when - # other_shape[i] actually depends on r_shape[i]. In that case, - # we do not want to substitute an expression with another that - # is strictly more complex. Such a substitution could also lead - # to cycles: if (in the future) r_shape[i] gets replaced by an - # expression of other_shape[i], other_shape[i] may end up - # depending on itself. - merged_shape.append(r_shape[i]) - else: - merged_shape.append(other_shape[i]) - assert all( - ( - not hasattr(r.type, "shape") - or (r.type.shape[i] != 1 and other_r.type.shape[i] != 1) - ) - or self.lscalar_one.equals(merged_shape[i]) - or self.lscalar_one.equals( - get_scalar_constant_value( - merged_shape[i], - only_process_constants=True, - raise_not_constant=False, - ) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(merged_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def set_shape_i(self, r, i, s_i): - """Replace element i of shape_of[r] by s_i""" - assert r in self.shape_of - prev_shape = self.shape_of[r] - # prev_shape is a tuple, so we cannot change it inplace, - # so we build another one. - new_shape = [] - for j, s_j in enumerate(prev_shape): - if j == i: - new_shape.append(self.unpack(s_i, r)) - else: - new_shape.append(s_j) - assert all( - not hasattr(r.type, "shape") - or r.type.shape[idx] != 1 - or self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals( - get_scalar_constant_value(new_shape[idx], raise_not_constant=False) - ) - for idx in range(r.type.ndim) + def get_shape(self, var, idx): + """Return a symbolic expression for ``var.shape[idx]``.""" + if hasattr(var.type, "shape") and var.type.shape[idx] is not None: + return constant(var.type.shape[idx], dtype="int64") + + node = var.owner + if node is None: + return self._shape_i_var(var, idx) + + node_shapes = self._get_node_shapes(node) + out_idx = node.outputs.index(var) + sh = node_shapes[out_idx] + if sh is not None: + return sh[idx] + return self._shape_i_var(var, idx) + + def shape_tuple(self, var): + if not hasattr(var.type, "ndim"): + return None + return tuple(self.get_shape(var, i) for i in range(var.type.ndim)) + + @property + def shape_of(self): + """Deprecated back-compat shim. Use ``shape_tuple(var)`` instead.""" + warn( + "ShapeFeature.shape_of is deprecated; use shape_tuple(var) instead.", + DeprecationWarning, + stacklevel=2, ) - self.shape_of[r] = tuple(new_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) + return _ShapeOfProxy(self) - def init_r(self, r): - """Register r's shape in the shape_of dictionary.""" - if r not in self.shape_of: - self.set_shape(r, self.shape_tuple(r)) + def get_shape_no_cycle(self, var, idx): + """Like ``get_shape`` but breaks destroy-handler aliasing cycles.""" + s = self.get_shape(var, idx) + destroyers = getattr(self.fgraph, "destroyers", None) + if destroyers is not None: + from pytensor.graph.replace import break_aliasing_cycles - def make_vector_shape(self, r): - return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64") + [s] = break_aliasing_cycles([s], destroyers) + return s def on_attach(self, fgraph): if hasattr(fgraph, "shape_feature"): raise AlreadyThere("This FunctionGraph already has a ShapeFeature") - - if hasattr(self, "fgraph") and self.fgraph != fgraph: + if self.fgraph is not None and self.fgraph is not fgraph: raise Exception("This ShapeFeature is already attached to a graph") - self.fgraph = fgraph - fgraph.shape_feature = self - # Must be local to the object as otherwise we reuse the same - # variable for multiple fgraph! - self.lscalar_one = constant(1, dtype="int64") - assert self.lscalar_one.type.dtype == "int64" - - self.fgraph = fgraph - # Variable -> tuple(scalars) or None (All tensor vars map to tuple) - self.shape_of = {} - # Variable -> - self.scheduled = {} - # shape var -> graph v - self.shape_of_reverse_index = {} - - for node in fgraph.toposort(): - self.on_import(fgraph, node, reason="on_attach") def on_detach(self, fgraph): - self.shape_of = {} - self.scheduled = {} - self.shape_of_reverse_index = {} + self._cache.clear() + self._shape_i_cache.clear() + self.scheduled.clear() self.fgraph = None - del fgraph.shape_feature - - def on_import(self, fgraph, node, reason): - if node.outputs[0] in self.shape_of: - # this is a revert, not really an import - for r in node.outputs + node.inputs: - assert r in self.shape_of - return - - for i, r in enumerate(node.inputs): - # make sure we have shapes for the inputs - self.init_r(r) - - o_shapes = self.get_node_infer_shape(node) - - # this is packed information - # an element of o_shapes is either None or a tuple - # elements of the tuple can be either strings, or ints - if len(o_shapes) != len(node.outputs): - raise Exception( - f'The infer_shape method for the Op "{node.op}" returned a list ' - f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} " - f" != len(node.outputs) = {len(node.outputs)}" - ) + if hasattr(fgraph, "shape_feature"): + del fgraph.shape_feature - # Ensure shapes are in 'int64'. This is to make sure the assert - # found in the `local_useless_subtensor` rewrite does not fail. - for sh_idx, sh in enumerate(o_shapes): - if sh is None: - continue - if not isinstance(sh, list | tuple): - raise ValueError( - f"infer_shape of {node} didn't return a list of" - f" list. It returned '{o_shapes}'" - ) - new_shape = [] - for i, d in enumerate(sh): - # Note: we ignore any shape element that is not typed (i.e., - # does not have a 'dtype' attribute). This means there may - # still remain int elements that are int32 on 32-bit platforms, - # but this works with `local_useless_subtensor`, so for now we - # keep it this way. See #266 for a better long-term fix. - if getattr(d, "dtype", "int64") != "int64": - assert d.dtype in discrete_dtypes, (node, d.dtype) - assert str(d.dtype) != "uint64", node - new_shape += sh[len(new_shape) : i + 1] - if isinstance(d, Constant): - casted_d = constant(d.data, dtype="int64") - else: - casted_d = cast(d, "int64") - new_shape[i] = casted_d - if new_shape: - # We replace the shape with wrong dtype by the one with - # 'int64'. - new_shape += sh[len(new_shape) :] - o_shapes[sh_idx] = tuple(new_shape) - - for r, s in zip(node.outputs, o_shapes, strict=True): - self.set_shape(r, s) + def on_prune(self, fgraph, node, reason): + self._cache.pop(node, None) + for out in node.outputs: + self._shape_i_cache.pop(out, None) def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: - # It happen that the fgraph didn't called on_import for some - # new_r. This happen when new_r don't have an - # owner(i.e. it is a constant or an input of the graph) - # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) - - # This tells us that r and new_r must have the same shape if - # we didn't know that the shapes are related, now we do. - self.update_shape(new_r, r) - - # change_input happens in two cases: - # 1) we are trying to get rid of r, or - # 2) we are putting things back after a failed transaction. - - # In case 1, if r has a shape_i client, we will want to - # replace the shape_i of r with the shape of new_r. Say that r is *scheduled*. - # At that point, node is no longer a client of r, but of new_r - # This schedule is processed by `local_track_shape_i`. - for shpnode, idx in fgraph.clients[r] + [(node, i)]: - if isinstance(shpnode.op, Shape_i): - idx = shpnode.op.i - repl = self.shape_of[new_r][idx] - if repl.owner is shpnode: - # This mean the replacement shape object is - # exactly the same as the current shape object. So - # no need for replacement. - continue - if ( - repl.owner - and repl.owner.inputs[0] is shpnode.inputs[0] - and isinstance(repl.owner.op, Shape_i) - and repl.owner.op.i == shpnode.op.i - ): - # The replacement is a shape_i of the same - # input. So no need to do this equivalent - # replacement. - continue + if r is new_r: + return + # Invalidate cached shapes for the node whose input changed + self._cache.pop(node, None) - if shpnode.outputs[0] in ancestors([repl]): - raise InconsistencyError( - "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" - ) - - self.scheduled[shpnode] = new_r - # In case 2, if r is a variable that we've scheduled for shape update, - # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] - for k in unscheduled: - del self.scheduled[k] - - # In either case, r could be in shape_of.values(), that is, r itself - # is the shape of something. In that case, we want to update - # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): - # The reverse index is only approximate. It is not updated on - # deletion of variables, or on change_input so it might be the - # case that there are a few extra `v`'s in it that no longer have - # a shape of r or possibly have been deleted from shape_of - # entirely. The important thing is that it permits to recall - # all variables with r in their shape. - for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() + # Schedule Shape_i(r) replacements for local_track_shape_i + if hasattr(r.type, "ndim"): + for shpnode, _idx in fgraph.clients.get(r, []): + if isinstance(getattr(shpnode, "op", None), Shape_i): + self.scheduled[shpnode] = new_r def same_shape( self, @@ -668,63 +272,27 @@ def same_shape( dim_x: int | None = None, dim_y: int | None = None, ) -> bool: - """Return ``True`` if `x` and `y` have the same shape. - - Parameters - ========== - x - The `Variable` for which its shape is to be compared with `y`'s shape. - y - The `Variable` for which its shape is to be compared with `x`'s shape. - dim_x - If non ``None``, compare only the dimension of `x` equal to - `dim_x`. - dim_y - If non ``None``, compare only the dimension of `y` equal to - `dim_y`. - - """ - sx = self.shape_of[x] - sy = self.shape_of[y] - + """Return True if we can statically prove x and y have the same shape.""" + sx = self.shape_tuple(x) + sy = self.shape_tuple(y) if sx is None or sy is None: return False - if dim_x is not None: - sx = [sx[dim_x]] - + sx = (sx[dim_x],) if dim_y is not None: - sy = [sy[dim_y]] - + sy = (sy[dim_y],) if len(sx) != len(sy): return False - - # Canonicalize the graphs so that comparisons are reasonable - # TODO FIXME: This should *not* need to be performed manually here. - # Instead, the shape information in `self.shape_of` should be operated - # upon alongside all the other elements in a `FunctionGraph` (e.g. as - # if `self.shape_of.values()` were additional outputs). - shapes_fg = FunctionGraph( - outputs=sx + sy, - # features=[self], - clone=True, - # copy_inputs=False, - ) - from pytensor.graph.rewriting.utils import rewrite_graph - - canon_shapes_fg = type_cast( - FunctionGraph, - rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding), - ) - canon_shapes = canon_shapes_fg.outputs - - sx = canon_shapes[: len(sx)] - sy = canon_shapes[len(sx) :] - for dx, dy in zip(sx, sy, strict=True): - if not equal_computations([dx], [dy]): + if dx is dy: + continue + if isinstance(dx, Constant) and isinstance(dy, Constant): + if dx.data == dy.data: + continue return False - + if equal_computations([dx], [dy]): + continue + return False return True def clone(self): @@ -1302,7 +870,7 @@ def local_shape_to_shape_i(fgraph, node): if not hasattr(fgraph, "shape_feature"): return shape_feature = fgraph.shape_feature - ret = shape_feature.make_vector_shape(node.inputs[0]) + ret = as_tensor_variable(shape_feature.shape_tuple(node.inputs[0]), dtype="int64") # We need to copy over stack trace from input to output copy_stack_trace(node.outputs[0], ret) @@ -1314,44 +882,40 @@ def local_shape_to_shape_i(fgraph, node): @register_canonicalize @node_rewriter([Shape_i]) def local_track_shape_i(fgraph, node): - """ - Update `Shape_i` nodes to match `ShapeFeature`'s internal state. - - This rewrite is essential for propagating shape information during graph - transformations (like lowering). When a node is replaced or updated, - `ShapeFeature` calculates the shape of the new node and "schedules" - dependent `Shape_i` nodes for update, so they use the latest inferred graph. - - If we start with an fgraph containing the two nodes below: - >> out = OpWithoutInferShape(a, b) - >> out_shape_i = Shape_i(out) + """Replace ``Shape_i(v, i)`` with the inferred shape expression. - And then rewrite - >> new_out = OpWithInferShape(a, b) - >> fgraph.replace(out, new_out) - - We end up with - >> out_shape_i == Shape_i(new_out) + When ``v.owner.op`` has ``infer_shape``, ``get_shape(v, i)`` returns + a non-``Shape_i`` expression. Rewriting the literal ``Shape_i(v, i)`` + with that expression lets downstream rewrites see the inferred form + and typically lets the original producer of ``v`` be pruned when only + its shape is consumed. + """ + shape_feature = getattr(fgraph, "shape_feature", None) + if shape_feature is None: + return False - If installed, ShapeFeature will do this work in the background - >> new_out_shape = infer_shape(new_out) # Usually some f(a, b) - >> fgraph.shape_feature.scheduled[out_shape_i.owner] = new_out_shape + # Handle scheduled replacements from on_change_input + replacement = shape_feature.scheduled.pop(node, None) + if replacement is not None: + return [shape_feature.get_shape_no_cycle(replacement, node.op.i)] - And this rewrite will ultimately propagate the inference back to the fgraph - >> new_out_shape_i = fgraph.shape_feature.scheduled[out_shape_i.owner][i] - >> fgraph.replace(out_shape_i, new_out_shape_i) + [v] = node.inputs + if v.owner is None: + return False - """ - try: - shape_feature = fgraph.shape_feature - except AttributeError: + i = node.op.i + new_shape = shape_feature.get_shape_no_cycle(v, i) + if new_shape is None: return False - if node not in shape_feature.scheduled: + # Avoid replacing Shape_i(v, i) with itself + if new_shape.owner is node or ( + isinstance(new_shape, Variable) + and new_shape.owner is not None + and isinstance(new_shape.owner.op, Shape_i) + and new_shape.owner.op.i == i + and new_shape.owner.inputs[0] is v + ): return False - # Don't unschedule node as it could be reinserted in the - # fgraph as we don't change it in the shapefeature internal - # structure. - replacement = shape_feature.scheduled[node] - return [shape_feature.shape_of[replacement][node.op.i]] + return [new_shape] diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 6c07ddb807..d677f28386 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -883,12 +883,12 @@ def _local_subtensor_merge_rewrite(fgraph, node, *, merge_integer_index): indices_outer = unflatten_index_variables(outer_index_vars, node.op.idx_list) try: - xshape = fgraph.shape_feature.shape_of[x] + xshape = fgraph.shape_feature.shape_tuple(x) except AttributeError: xshape = tuple(x.shape) try: - ushape = fgraph.shape_feature.shape_of[u] + ushape = fgraph.shape_feature.shape_tuple(u) except AttributeError: ushape = tuple(u.shape) @@ -1201,7 +1201,7 @@ def local_useless_subtensor(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature cdata = get_constant_idx( node.op.idx_list, @@ -1223,7 +1223,7 @@ def local_useless_subtensor(fgraph, node): # is not a useless subtensor return False - length_pos = shape_of[node.inputs[0]][pos] + length_pos = shape_feature.get_shape(node.inputs[0], pos) if isinstance(idx.stop, int | np.integer): length_pos_data = sys.maxsize @@ -1327,12 +1327,12 @@ def local_useless_AdvancedSubtensor1(fgraph, node): if not hasattr(fgraph, "shape_feature"): return - shape_of = fgraph.shape_feature.shape_of + shape_feature = fgraph.shape_feature # get length of the indexed tensor along the first axis try: length = get_scalar_constant_value( - shape_of[node.inputs[0]][0], only_process_constants=True + shape_feature.get_shape(node.inputs[0], 0), only_process_constants=True ) except NotScalarConstantError: return False @@ -2417,7 +2417,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node): # need it for this optimization, so don't continue. return False - shape_of = shape_feature.shape_of same_shape = shape_feature.same_shape # Get the subtensor of `x` indexed by `i` in order to compare @@ -2431,22 +2430,12 @@ def local_useless_inc_subtensor_alloc(fgraph, node): else: raise Exception("Should never happen!") - reason = "local_useless_incsubtensor_alloc" - - # Add `xi` to the shape feature `fgraph`. This is important for - # shape inference later because the variable must be part of the - # function graph in order to call `same_shape` on it. - if xi not in shape_of: - shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") - # `xi` may have more dimensions than `y` since the subtensor ops # do automatic broadcasting of the increment internally. Thus, we # need to make the leading implicitly broadcasted dimensions # explicit for shape comparison later. if xi.ndim > y.ndim: y = shape_padleft(y, xi.ndim - y.ndim) - if y not in shape_of: - shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") # Build `z_broad` explicitly to include extra implicit dimensions. z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable @@ -2479,7 +2468,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): if ( z_broad[k] and not same_shape(xi, y, dim_x=k, dim_y=k) - and shape_of[y][k] != 1 + and shape_feature.get_shape(y, k) != 1 ) ] diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 3a7202acfc..1b43eaec6e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -81,7 +81,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = np.asarray(np.shape(x), dtype="int64") - def infer_shape(self, fgraph, node, in_shapes): + def infer_shape(self, node, in_shapes): return [[len(in_shapes[0])]] def connection_pattern(self, node): @@ -297,7 +297,7 @@ def c_code(self, node, name, inames, onames, sub): # Else, no C code raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return [()] def connection_pattern(self, node): @@ -339,21 +339,7 @@ def shape_i(var, i, fgraph=None): """ if fgraph and hasattr(fgraph, "shape_feature"): - shape_feature = fgraph.shape_feature - shape_of = shape_feature.shape_of - - def recur(node): - if node.outputs[0] not in shape_of: - for inp in node.inputs: - if inp.owner: - recur(inp.owner) - # If the output var isn't marked as being in the graph, - # we need to add it in the ShapeFeature. - shape_feature.on_import(fgraph, node, "graph.ops.shape_i") - - if var not in shape_of: - recur(var.owner) - return shape_of[var][i] + return fgraph.shape_feature.get_shape(var, i) # If we are not able to use the shape feature, we should not put # Shape_i in the graph. Otherwise, the shape feature optimization @@ -452,7 +438,7 @@ def perform(self, node, inp, out_): ) out[0] = x - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): xshape, *_ = shapes shape = node.inputs[1:] # Use x shape if specified dim is None, otherwise the specified shape @@ -727,7 +713,7 @@ def pushforward(self, inputs, outputs, eval_points): return [disconnected_type()] return self(eval_points[0], *inputs[1:], return_list=True) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): from pytensor.tensor.math import eq, maximum, mul # inputs[1] can contain at most one value of '-1', meaning the actual diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py index c3133e3c15..51dd796a52 100644 --- a/pytensor/tensor/signal/conv.py +++ b/pytensor/tensor/signal/conv.py @@ -82,7 +82,7 @@ def make_node(self, in1, in2, full_mode): out = tensor(dtype=dtype, shape=out_shape) return Apply(self, [in1, in2, full_mode], [out]) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): _, _, full_mode = node.inputs in1_shape, in2_shape, _ = shapes out_shape = [ diff --git a/pytensor/tensor/sort.py b/pytensor/tensor/sort.py index af695d9e42..c911be988d 100644 --- a/pytensor/tensor/sort.py +++ b/pytensor/tensor/sort.py @@ -54,7 +54,7 @@ def perform(self, node, inputs, output_storage): z = output_storage[0] z[0] = np.sort(a, axis, self.kind) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] @@ -185,7 +185,7 @@ def perform(self, node, inputs, output_storage): dtype=node.outputs[0].dtype, ) - def infer_shape(self, fgraph, node, inputs_shapes): + def infer_shape(self, node, inputs_shapes): assert node.inputs[0].ndim == node.outputs[0].ndim assert inputs_shapes[1] == () return [inputs_shapes[0]] diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index f180d2e9c6..293ca8850b 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -943,7 +943,7 @@ def perform(self, node, inputs, out_): cdata = unflatten_index_variables(index_variables, self.idx_list) out[0] = np.asarray(x.__getitem__(tuple(cdata))) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): def _is_constant(const, x): return isinstance(const, Constant) and const.data.item() == x @@ -1789,7 +1789,7 @@ def add_to_zview(self, name, x, fail): {fail}; }}""" - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def pushforward(self, inputs, outputs, eval_points): @@ -1967,7 +1967,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, ilist = ishapes return [ilist + x[1:]] @@ -2317,7 +2317,7 @@ def perform(self, node, inputs, output_storage): output_storage[0][0] = x - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): x, _y, _ilist = ishapes return [x] @@ -2492,7 +2492,7 @@ def pushforward(self, inputs, outputs, eval_points): _x, *index_variables = inputs return self.make_node(eval_points[0], *index_variables).outputs - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): def is_bool_index(idx): return ( isinstance(idx, np.bool_ | bool) @@ -2719,7 +2719,7 @@ def perform(self, node, inputs, out_): else: np.add.at(out[0], tuple(full_indices), y) - def infer_shape(self, fgraph, node, ishapes): + def infer_shape(self, node, ishapes): return [ishapes[0]] def connection_pattern(self, node): diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 1a7e681c22..8662faded3 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -83,11 +83,16 @@ def shape_of_variables( shape_feature = fgraph.shape_feature input_dims = [ - dimension for inp in fgraph.inputs for dimension in shape_feature.shape_of[inp] + dimension + for inp in fgraph.inputs + for dimension in shape_feature.shape_tuple(inp) ] output_dims = [ - dimension for shape in shape_feature.shape_of.values() for dimension in shape + dimension + for var in fgraph.variables + if hasattr(var.type, "ndim") + for dimension in shape_feature.shape_tuple(var) ] compute_shapes = pytensor.function(input_dims, output_dims) @@ -105,8 +110,10 @@ def shape_of_variables( sym_to_num_dict = dict(zip(output_dims, numeric_output_dims, strict=True)) l = {} - for var in shape_feature.shape_of: - l[var] = tuple(sym_to_num_dict[sym] for sym in shape_feature.shape_of[var]) + for var in fgraph.variables: + shape = shape_feature.shape_tuple(var) + if shape is not None: + l[var] = tuple(sym_to_num_dict[sym] for sym in shape) return l diff --git a/pytensor/xtensor/basic.py b/pytensor/xtensor/basic.py index 3e02c75ce9..09a8d8fe1f 100644 --- a/pytensor/xtensor/basic.py +++ b/pytensor/xtensor/basic.py @@ -30,7 +30,7 @@ class XTypeCastOp(TypeCastingOp): This is like a `ViewOp` but without the expectation the input and output have identical types. """ - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): return input_shapes def vectorize_node( diff --git a/tests/benchmarks/test_rewriting.py b/tests/benchmarks/test_rewriting.py index cea3c4e676..b7b9f3221f 100644 --- a/tests/benchmarks/test_rewriting.py +++ b/tests/benchmarks/test_rewriting.py @@ -1,9 +1,13 @@ import numpy as np import pytest +import pytensor import pytensor.tensor as pt +import pytensor.xtensor as px from pytensor import config from pytensor.graph import FunctionGraph +from pytensor.graph.rewriting import rewrite_graph +from pytensor.xtensor.shape import stack as xstack def _large_fuseable_graph(n): @@ -66,3 +70,50 @@ def rewrite_func(): assert rewrite_func() == expected_n_repl benchmark.pedantic(rewrite_func, rounds=7, iterations=5) + + +def _xtensor_attention_graph(n_layers): + B, T, E, H, HD = 4, 32, 64, 4, 16 + rng = np.random.default_rng(0) + + def attn(x): + Wqkv = px.as_xtensor( + pytensor.shared(rng.normal(size=(E, 3, H, HD))), + dims=("embd", "qkv", "head", "hd"), + ) + Wproj = px.as_xtensor( + pytensor.shared(rng.normal(size=(E, E))), + dims=("embd", "embd_out"), + ) + qkv = px.dot(x, Wqkv, dim="embd") + q = qkv.isel(qkv=0).rename(time="time_q") + k = qkv.isel(qkv=1).rename(time="time_k") + v = qkv.isel(qkv=2).rename(time="time_k") + s = px.dot(q, k, dim="hd") / np.sqrt(HD) + mask = px.as_xtensor( + pt.tril(pt.ones((T, T), dtype="bool")), + dims=("time_q", "time_k"), + ) + a = px.math.softmax(px.where(mask, s, np.float64(-1e9)), dim="time_k") + o = xstack(px.dot(a, v, dim="time_k"), embd=("head", "hd")) + return px.dot(o, Wproj, dim="embd").rename(time_q="time", embd_out="embd") + + x_t = pt.tensor("x", shape=(B, T, E)) + x = px.as_xtensor(x_t, dims=("batch", "time", "embd")) + for _ in range(n_layers): + x = attn(x) + return x_t, x.values.sum() + + +@pytest.mark.parametrize("n_layers", [2, 3, 4]) +def test_xtensor_attention_rewrite_benchmark(n_layers, benchmark): + x_t, loss = _xtensor_attention_graph(n_layers) + + def rewrite_once(): + lowered = rewrite_graph(loss, include=("lower_xtensor",), clone=True) + grad = pt.grad(lowered, x_t) + return rewrite_graph( + [lowered, grad], include=("fast_run",), exclude=("inplace",), clone=True + ) + + benchmark(rewrite_once) diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index f8180773cc..c6f1f1e74c 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -460,8 +460,8 @@ def test_infer_shape(self): fg = FunctionGraph(outputs=[op_var[1]], clone=False) opt_res = rewrite_graph(fg, custom_rewrite=ShapeOptimizer()) - assert opt_res.shape_feature.shape_of[x] is None - assert opt_res.shape_feature.shape_of[z][0].data == 2 + assert opt_res.shape_feature.shape_tuple(x) is None + assert opt_res.shape_feature.shape_tuple(z)[0].data == 2 def test_make_node_shared(self): """Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`.""" diff --git a/tests/compile/test_ops.py b/tests/compile/test_ops.py index a30ed6475d..1954c6cbb6 100644 --- a/tests/compile/test_ops.py +++ b/tests/compile/test_ops.py @@ -65,7 +65,7 @@ def test_infer_shape(self): x = dmatrix("x") y = dvector("y") - def infer_shape(fgraph, node, shapes): + def infer_shape(node, shapes): _x, y = shapes return [y] diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index ecaa4fee6d..585eb85a8a 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -4,7 +4,7 @@ import pytest from pytensor.configdefaults import config -from pytensor.graph.basic import NominalVariable +from pytensor.graph.basic import NominalVariable, equal_computations from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph, Output from pytensor.graph.utils import MissingInputError from pytensor.printing import debugprint @@ -946,3 +946,77 @@ def test_freeze_unfreeze_round_trip(self): assert ffg == refrozen assert hash(ffg) == hash(refrozen) + + def test_bind_constant_output(self): + """bind must handle constants that appear directly as outputs.""" + x = float64("x") + c = ScalarConstant(float64, 42.0) + ffg = FunctionGraph([x], [add(x, c), c]).freeze() + + y = float64("y") + bound = ffg.bind({ffg.inputs[0]: y}) + assert len(bound) == 2 + assert bound[1] is c + + def test_from_structural_inputs_only_root_inputs(self): + """All inputs are roots: behaves like the plain constructor.""" + x, y = float64("x"), float64("y") + out = add(x, y) + + ffg = FrozenFunctionGraph.from_structural_inputs([x, y], [out]) + assert len(ffg.inputs) == 2 + + a, b = float64("a"), float64("b") + [res] = ffg.bind(dict(zip(ffg.inputs, [a, b], strict=True))) + assert equal_computations([res], [add(a, b)]) + + def test_from_structural_inputs_only_intermediate_inputs(self): + """Inputs may be only intermediate expressions; roots are found automatically.""" + x, y = float64("x"), float64("y") + # out depends on x, y only through the product. + out = add(mul(x, y), mul(x, y)) + + # The passed expression is matched by structure, not identity. + prod = mul(x, y) + assert prod is not out.owner.inputs[0] + + ffg = FrozenFunctionGraph.from_structural_inputs([prod], [out]) + assert len(ffg.inputs) == 1 + + p = float64("p") + [res] = ffg.bind({ffg.inputs[0]: p}) + # Both occurrences rewire to the single input. + assert equal_computations([res], [add(p, p)]) + + def test_from_structural_inputs_mixed_inputs(self): + """A root input and an intermediate input, both live.""" + x, y = float64("x"), float64("y") + out = add(mul(x, y), x) + + ffg = FrozenFunctionGraph.from_structural_inputs([x, mul(x, y)], [out]) + assert len(ffg.inputs) == 2 + + a, p = float64("a"), float64("p") + # x is used directly; root y is dropped (it feeds only the lifted product). + [res] = ffg.bind(dict(zip(ffg.inputs, [a, p], strict=True))) + assert equal_computations([res], [add(p, a)]) + + def test_from_structural_inputs_dead_inputs(self): + """A dead root input and a dead intermediate input are retained but ignored.""" + x, y = float64("x"), float64("y") + out = add(x, x) # uses neither y nor the product + + ffg = FrozenFunctionGraph.from_structural_inputs([x, y, mul(x, y)], [out]) + assert len(ffg.inputs) == 3 + + a, b, p = float64("a"), float64("b"), float64("p") + [res] = ffg.bind(dict(zip(ffg.inputs, [a, b, p], strict=True))) + assert equal_computations([res], [add(a, a)]) + + def test_from_structural_inputs_unreachable_output_raises(self): + """Outputs needing a root absent from the inputs cannot be expressed.""" + x, y = float64("x"), float64("y") + out = add(mul(x, y), x) # needs x directly, not only via the product + + with pytest.raises(ValueError): + FrozenFunctionGraph.from_structural_inputs([mul(x, y)], [out]) diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py index 2605c15b8f..5f3486e5bd 100644 --- a/tests/graph/test_replace.py +++ b/tests/graph/test_replace.py @@ -2,17 +2,26 @@ import pytest import scipy.special +import pytensor.scalar as ps import pytensor.tensor as pt from pytensor import config, function, shared +from pytensor.compile.ops import DeepCopyOp from pytensor.graph.basic import equal_computations +from pytensor.graph.destroyhandler import DestroyHandler, _contains_cycle +from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import ( _vectorize_node, + break_aliasing_cycles, clone_replace, graph_replace, vectorize_graph, ) -from pytensor.graph.traversal import graph_inputs +from pytensor.graph.traversal import applys_between, graph_inputs from pytensor.tensor import dvector, fvector, vector +from pytensor.tensor.basic import alloc +from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.shape import Shape_i +from pytensor.tensor.signal import convolve1d from tests import unittest_tools as utt from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs from tests.unittest_tools import assert_equal_computations @@ -373,3 +382,59 @@ def test_non_variable_raises(self): batch_out.eval({x: 3, y: 4}), np.zeros((2, 3, 4)), ) + + +def test_break_aliasing_cycles(): + """Reading a destroyed scalar and a transitive dependent of the destroyer + in the same Apply creates a destroy-handler cycle. ``break_aliasing_cycles`` + re-routes the destroyed scalar through ``deep_copy_op`` on the offending + Apply, lifting the conflict. When the destroyed scalar feeds the same + downstream Apply only transitively (via another op), no re-routing is needed. + """ + larger = pt.matrix("larger", shape=(8, None)) + smaller = pt.matrix("smaller", shape=(8, None)) + + larger_s1 = Shape_i(1)(larger) + smaller_s1 = Shape_i(1)(smaller) + + sx, sy = ps.int64(), ps.int64() + inplace_comp = Elemwise( + ps.Composite([sx, sy], [ps.sub(ps.add(sx, sy), ps.constant(1, dtype="int64"))]), + inplace_pattern={0: 0}, + ) + new_dim = inplace_comp(larger_s1, smaller_s1) + a = alloc(pt.zeros((1, 1)), 1, new_dim) + out = convolve1d(a, larger[:, ::-1], mode="full") + + fg = FunctionGraph([larger, smaller], [out], clone=False) + fg.attach_feature(DestroyHandler()) + + def imports_with_cycle(shape_vars): + check_fg = FunctionGraph([larger, smaller], [out, *shape_vars], clone=False) + check_fg.attach_feature(DestroyHandler()) + dh = check_fg.destroy_handler + return _contains_cycle(check_fg, dh.orderings(check_fg, ordered=False)) + + def deep_copy_inputs(shape_vars): + return [ + n.inputs[0] + for n in applys_between([larger, smaller], shape_vars) + if isinstance(n.op, DeepCopyOp) + ] + + # Direct conflict: the outer Add reads both larger_s1 (destroyed) and new_dim + # (the destroyer's output). break_aliasing_cycles must re-route larger_s1. + direct = [(new_dim + larger_s1) - 1] + assert imports_with_cycle(direct) + safe_direct = break_aliasing_cycles(direct, fg.destroyers) + assert not imports_with_cycle(safe_direct) + assert deep_copy_inputs(safe_direct) == [larger_s1] + + # Indirect consumer: the outer Add reads new_dim and (larger_s1 - 1), not + # larger_s1 directly. No per-Apply conflict — the destroyer just orders after + # the inner Sub, so break_aliasing_cycles must leave the graph alone. + indirect = [new_dim + (larger_s1 - 1)] + assert not imports_with_cycle(indirect) + safe_indirect = break_aliasing_cycles(indirect, fg.destroyers) + assert safe_indirect == indirect + assert deep_copy_inputs(safe_indirect) == [] diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 741d5f01a5..b3c663b5f0 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -310,7 +310,7 @@ def grad(self, inputs, gout): else: return (gz,) - def infer_shape(self, fgraph, node, shapes): + def infer_shape(self, node, shapes): return [shapes[0]] def test_grad_fail(self): diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 358c95fc66..96cecdc333 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -284,7 +284,7 @@ def test_normal_ShapeFeature(): clone=False, features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt, sd_pt], [s1, s2, d_rv], mode=py_mode, on_unused_input="ignore") s1_val, s2_val, d_rv_val = f(3, np.array(1.0, dtype=config.floatX)) @@ -657,7 +657,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) f = function([M_pt], [s1, s2], mode=py_mode) s1_val, s2_val = f(2) @@ -679,7 +679,7 @@ def test_mvnormal_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2, s3, s4 = fg.shape_feature.shape_of[d_rv] + s1, s2, s3, s4 = fg.shape_feature.shape_tuple(d_rv) mean_val = np.array([[0, 1, 2]], dtype=config.floatX) f = function([mean, cov], [s1, s2, s3, s4], mode=py_mode, on_unused_input="ignore") @@ -810,7 +810,7 @@ def test_dirichlet_ShapeFeature(): features=[ShapeFeature()], ) - s1, s2 = fg.shape_feature.shape_of[d_rv] + s1, s2 = fg.shape_feature.shape_tuple(d_rv) assert M_pt in graph_inputs([s1]) assert N_pt in graph_inputs([s2]) diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index 9ab4e591f1..b3c88879b2 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -162,7 +162,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - # def infer_shape(self, fgraph, node, (xshp,)): + # def infer_shape(self, node, (xshp,)): # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])] identity_noshape = IdentityNoShape() @@ -179,7 +179,7 @@ def perform(self, node, inp, out_): (out,) = out_ out[0] = x.copy() - def infer_shape(self, fgraph, node, xshp_): + def infer_shape(self, node, xshp_): # Could also just return. (xshp,) = xshp_ return (xshp,) @@ -613,6 +613,70 @@ def test_vector_dim_err(self): shape_feature.same_shape(x, o, 0, 1) +def test_get_shape_resolves_through_chain(): + """get_shape should resolve to the deepest input, not intermediate ops.""" + from pytensor.tensor.math import cos + + x = matrix("x") + y = exp(cos(x).T) + + fg = FunctionGraph([x], [y], clone=False) + sf = ShapeFeature() + fg.attach_feature(sf) + + s = sf.get_shape(y, 0) + utt.assert_equal_computations([s], [Shape_i(1)(x)]) + + +def test_shape_materialization_does_not_create_destroy_cycle(): + """Lazy shape materialization can create a destroy-handler cycle. + + When an inplace op destroys a Shape_i scalar and the shape of a + downstream op (e.g. Convolve1d) depends on both the destroyed + scalar and the destroyer's output, a single Apply ends up reading + both — the dual-reference pattern that breaks scheduling. + + ``break_aliasing_cycles`` re-routes the destroyed scalar through + ``deep_copy_op`` on the offending Apply, lifting the conflict. + """ + import pytensor.scalar as ps + from pytensor.graph.destroyhandler import DestroyHandler, _contains_cycle + from pytensor.graph.replace import break_aliasing_cycles + from pytensor.tensor.signal import convolve1d + + larger = pt.matrix("larger", shape=(8, None)) + smaller = pt.matrix("smaller", shape=(8, None)) + + sf = ShapeFeature() + larger_s1 = sf._shape_i_var(larger, 1) + smaller_s1 = sf._shape_i_var(smaller, 1) + + sx, sy = ps.int64(), ps.int64() + inplace_comp = Elemwise( + ps.Composite([sx, sy], [ps.sub(ps.add(sx, sy), ps.constant(1, dtype="int64"))]), + inplace_pattern={0: 0}, + ) + new_dim = inplace_comp(larger_s1, smaller_s1) + a = alloc(pt.zeros((1, 1)), 1, new_dim) + out = convolve1d(a, larger[:, ::-1], mode="full") + + fg = FunctionGraph([larger, smaller], [out], clone=False) + fg.attach_feature(sf) + fg.attach_feature(DestroyHandler()) + + naive_shape = list(sf.shape_tuple(out)) + safe_shape = break_aliasing_cycles(naive_shape, fg.destroyers) + + def imports_with_cycle(shape_vars): + check_fg = FunctionGraph([larger, smaller], [out, *shape_vars], clone=False) + check_fg.attach_feature(DestroyHandler()) + dh = check_fg.destroy_handler + return _contains_cycle(check_fg, dh.orderings(check_fg, ordered=False)) + + assert imports_with_cycle(naive_shape) + assert not imports_with_cycle(safe_shape) + + def test_useless_specify_shape(): x = tensor("x", shape=(None, 5, 3)) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index dfded3fbc3..c4023b3ea5 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -310,7 +310,7 @@ def perform(self, node, inputs, outputs): c[0] = np.arange(a.size + b.size, dtype=config.floatX) d[0] = np.arange(a.sum() + b.sum(), dtype=config.floatX) - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): # First output shape depends only on input_shapes # Second output shape depends on input values a_identity, b_identity = node.inputs @@ -362,7 +362,7 @@ def make_node(self, x): def perform(self, node, inputs, outputs): raise NotImplementedError() - def infer_shape(self, fgraph, node, input_shapes): + def infer_shape(self, node, input_shapes): y = node.outputs[0] # Apparently it's valid to return integers in infer_shape. # DimShuffle does this. Modify test if that is no longer allowed. diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 913a1036ff..6b72864763 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -873,7 +873,7 @@ def test_partial_static_shape_info(self): x_inferred_shape = (ps.constant(1), ps.constant(1)) res_shape = z.owner.op.infer_shape( - None, z.owner, [x_inferred_shape, x_inferred_shape] + z.owner, [x_inferred_shape, x_inferred_shape] ) assert len(res_shape) == 1 @@ -902,7 +902,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), ) in_1_shape = (ps.constant(1), ps.constant(1)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_1_shape]) for out in outs: assert out[0].eval() == 1 assert out[1].eval() == 1 @@ -911,7 +911,7 @@ def make_node(self, *args): as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3)) ) in_2_shape = (ps.constant(3), ps.constant(3)) - outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape]) + outs = z_1.owner.op.infer_shape(z_1.owner, [in_1_shape, in_2_shape]) for out in outs: assert out[0].eval() == 3 assert out[1].eval() == 3 @@ -924,7 +924,7 @@ def test_shape_types(self): assert isinstance(z.owner.op, Elemwise) - (out_shape,) = z.owner.op.infer_shape(None, z.owner, [(lscalar(), 1), (50, 10)]) + (out_shape,) = z.owner.op.infer_shape(z.owner, [(lscalar(), 1), (50, 10)]) assert all(isinstance(v.type, TensorType) for v in out_shape) diff --git a/tests/xtensor/test_rewriting.py b/tests/xtensor/test_rewriting.py index da076b1824..2da37fa919 100644 --- a/tests/xtensor/test_rewriting.py +++ b/tests/xtensor/test_rewriting.py @@ -17,8 +17,21 @@ def test_infer_shape_db_handles_xtensor_lowering(): [rewritten_shape_y] = fgraph.outputs assert_equal_computations([rewritten_shape_y], [(x.values.sum(0)).shape[0]]) - # With ShapeFeature - fgraph = FunctionGraph([x], [shape_y], features=[ShapeFeature()], copy_inputs=False) + # With ShapeFeature — force caching shape of XRV output before lowering + sf = ShapeFeature() + fgraph = FunctionGraph([x], [shape_y], features=[sf], copy_inputs=False) + # Force get_shape on the XRV sum output (y) before any rewriting lowers it. + # This caches a shape expression referencing the XRV variable. + y_in_graph = [ + v + for v in fgraph.variables + if hasattr(v.type, "ndim") and v.type.ndim == 1 and v is not x + ] + for v in y_in_graph: + try: + sf.get_shape(v, 0) + except Exception: + pass infer_shape_db.default_query.rewrite(fgraph) [rewritten_shape_y] = fgraph.outputs assert_equal_computations([rewritten_shape_y], [Shape_i(1)(x)])