Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/assumptions/specify.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def make_node(self, x):
out = x.type()
return Apply(self, [x], [out])

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes

def pullback(
Expand Down
2 changes: 1 addition & 1 deletion pytensor/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def perform(self, node, inputs, output_storage):
def pullback(self, inputs, outputs, output_gradients):
return [disconnected_type(), *output_gradients]

def infer_shape(self, fgraph, inputs, input_shapes):
def infer_shape(self, inputs, input_shapes):
# Return the shape of every input but the condition (first input)
return input_shapes[1:]

Expand Down
145 changes: 73 additions & 72 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Callable, Sequence
from copy import copy
from functools import partial
from itertools import chain
from typing import cast

from pytensor.compile.maker import function
Expand All @@ -24,70 +23,51 @@
from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph
from pytensor.graph.null_type import NullType
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
from pytensor.graph.replace import clone_replace
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.traversal import graph_inputs
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.shape import Shape_i


def infer_shape(outs, inputs, input_shapes):
"""
Compute the shape of the outputs given the shape of the inputs of an PyTensor
graph.

We do it this way to avoid compiling the inner function just to get
the shape. Changes to ShapeFeature could require changes in this function.

"""
# We use a ShapeFeature because it has all the necessary logic
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually

# TODO: ShapeFeature should live elsewhere
"""Compute the shape of the outputs given the shape of the inputs of a PyTensor graph."""
from pytensor.tensor.rewriting.shape import ShapeFeature

for inp, inp_shp in zip(inputs, input_shapes, strict=True):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim

shape_feature = ShapeFeature()
fgraph = FunctionGraph([], [], features=[shape_feature])
for v in chain.from_iterable(s for s in input_shapes if s is not None):
# Import input_shape nodes, as for some graphs ShapeFeature assumes these were seen before
if (node := v.owner) is not None:
fgraph.import_node(node, import_missing=True)

# Initialize shape_of with the input shapes
for inp, inp_shp in zip(inputs, input_shapes, strict=True):
shape_feature.set_shape(inp, inp_shp, override=True)

def local_traverse(out):
"""
Go back in the graph, from out, adding computable shapes to shape_of.

"""
if out in shape_feature.shape_of:
# Its shape is already known
return
elif out.owner is None:
# This is an input of the graph
shape_feature.init_r(out)
else:
# Recurse over inputs
for inp in out.owner.inputs:
if inp not in shape_feature.shape_of:
local_traverse(inp)
output_shapes = [shape_feature.shape_tuple(o) for o in outs]

# shape_feature.on_import does not actually use an fgraph
# It will call infer_shape and set_shape appropriately
dummy_fgraph = None
shape_feature.on_import(dummy_fgraph, out.owner, reason="dummy")
# Shape expressions for root inputs are Shape_i(inp, i).
# Replace those with the caller-provided input_shapes.
replacements = {}
for inp, shp in zip(inputs, input_shapes, strict=True):
if shp is None:
continue
per_dim = shape_feature._shape_i_cache.get(inp)
if per_dim is None:
continue
for i, s in enumerate(shp):
cached = per_dim.get(i)
if cached is not None:
replacements[cached] = s

if replacements:
flat = [s for tup in output_shapes if tup is not None for s in tup]
flat_replaced = graph_replace(flat, replacements, strict=False)
result = []
idx = 0
for tup in output_shapes:
if tup is None:
result.append(None)
else:
result.append(tuple(flat_replaced[idx : idx + len(tup)]))
idx += len(tup)
return result

ret = []
for o in outs:
local_traverse(o)
ret.append(shape_feature.shape_of[o])
return ret
return output_shapes


def construct_nominal_fgraph(
Expand Down Expand Up @@ -885,30 +865,51 @@ def connection_pattern(self, node):
self._connection_pattern = ret
return ret

def infer_shape(self, fgraph, node, shapes):
# TODO: Use `fgraph.shape_feature` to do this instead.
out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes)

# Clone the output shape so that shape are computed from outer inputs.
# Note:
# Here we could do it more simply like:
# `ret = [pytensor.clone_replace(shp, replace=repl) for shp in out_shp]`
# But doing it multiple time could duplicate common subgraph between
# each shape call. PyTensor optimizer will clean this up later, but this
# will make extra work for the optimizer.

repl = dict(zip(self.inner_inputs, node.inputs, strict=True))
clone_out_shapes = [s for s in out_shapes if isinstance(s, tuple)]
cloned = clone_replace(sum(clone_out_shapes, ()), replace=repl)
def infer_shape(self, node, shapes):
try:
template = self._inner_shape_template
frozen = self._inner_shape_frozen
except AttributeError:
from pytensor.tensor.rewriting.shape import ShapeFeature

sf = ShapeFeature()
inner_inputs = self.inner_inputs
template = [sf.shape_tuple(o) for o in self.inner_outputs]
flat_shapes = [s for tup in template if tup is not None for s in tup]

# Express the inner-output shapes as a frozen function of the inner
# inputs plus each input's per-dim size. from_structural_inputs rewires
# every Shape_i(inner_input, dim) occurrence to the matching input, so
# bind can later swap in the caller's shapes. One slot per input dim:
# static or unused dims become dead inputs, keeping the layout positional.
shape_inputs = [
Shape_i(dim)(inp)
for inp in inner_inputs
for dim in range(getattr(inp.type, "ndim", 0))
]
frozen = FrozenFunctionGraph.from_structural_inputs(
[*inner_inputs, *shape_inputs], flat_shapes
)
self._inner_shape_template = template
self._inner_shape_frozen = frozen

# frozen.inputs is [*inner_inputs, *per-dim sizes]; mirror that layout.
repl = list(node.inputs)
for shp in shapes:
if shp is not None:
repl.extend(shp)

bound_shapes = frozen.bind(dict(zip(frozen.inputs, repl, strict=True)))

ret = []
used = 0
for i, out_shape in enumerate(out_shapes):
if out_shape is None:
idx = 0
for tup in template:
if tup is None:
ret.append(None)
else:
nb = len(out_shape)
ret.append(cloned[used : used + nb])
used += nb
nb = len(tup)
ret.append(bound_shapes[idx : idx + nb])
idx += nb

return ret

Expand Down
10 changes: 5 additions & 5 deletions pytensor/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ViewOp(TypeCastingOp):
def make_node(self, x):
return Apply(self, [x], [x.type()])

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes

def pullback(self, args, outputs, g_outs):
Expand Down Expand Up @@ -179,7 +179,7 @@ def c_code(self, node, name, inames, onames, sub):
# Else, no C code
raise NotImplementedError()

def infer_shape(self, fgraph, node, input_shapes):
def infer_shape(self, node, input_shapes):
return input_shapes


Expand Down Expand Up @@ -251,8 +251,8 @@ def __reduce__(self):
)
return load_back, (mod, name)

def _infer_shape(self, fgraph, node, input_shapes):
return self.__infer_shape(fgraph, node, input_shapes)
def _infer_shape(self, node, input_shapes):
return self.__infer_shape(node, input_shapes)


def as_op(itypes, otypes, infer_shape=None):
Expand All @@ -275,7 +275,7 @@ def wrap_py(itypes, otypes, infer_shape=None):
It takes an optional infer_shape parameter that should be a callable with
this signature:

def infer_shape(fgraph, node, input_shapes):
def infer_shape(node, input_shapes):
...
return output_shapes

Expand Down
33 changes: 32 additions & 1 deletion pytensor/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,37 @@ def _resolve_input(inp, memo=memo):
self._variables: frozenset[Variable] | None = None
self._clients: dict[Variable, list[ClientType]] | None = None

@classmethod
def from_structural_inputs(
    cls,
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
) -> "FrozenFunctionGraph":
    """Build a frozen graph of ``outputs`` whose ``inputs`` may be interior nodes.

    This is the structural-matching counterpart of `bind`: `bind` substitutes
    values for inputs, whereas this promotes selected sub-expressions to
    inputs. An entry of ``inputs`` that is the result of an `Apply` is located
    inside ``outputs`` by structure rather than by object identity, and each
    occurrence is rewired to that entry. Root entries are treated exactly as
    the regular constructor treats them. An intermediate entry that never
    occurs in ``outputs`` ends up as a dead input.

    The resulting signature keeps ``inputs`` in the order given. ``outputs``
    has to be computable from ``inputs`` alone: a root that ``outputs`` still
    reaches directly must be listed in ``inputs`` too, otherwise the rewired
    graph is orphaned.
    """
    split = len(inputs)
    staged_outputs = [*inputs, *outputs]

    # Seed the first freeze with the real graph roots only (the same discovery
    # FunctionGraph(inputs=None) performs). The requested inputs may be
    # intermediate expressions, and those must be *built* in this pass rather
    # than treated as blockers.
    true_roots = []
    for var in graph_inputs(staged_outputs):
        if isinstance(var, Constant):
            continue
        true_roots.append(var)

    # Freezing inputs and outputs in one graph interns every intermediate
    # input onto the very object used for its occurrences in the outputs ...
    staged = cls(true_roots, staged_outputs)
    lifted_inputs = staged.outputs[:split]
    lifted_outputs = staged.outputs[split:]

    # ... so this second freeze can cut the graph at those shared objects.
    return cls(lifted_inputs, lifted_outputs)

def __reduce__(self):
return FrozenFunctionGraph, (self.inputs, self.outputs)

Expand Down Expand Up @@ -1094,7 +1125,7 @@ def bind(self, replace: dict[Variable, Variable]) -> list[Variable]:
[o.type() for o in node.outputs],
)
memo.update(zip(node.outputs, new_node.outputs))
return [memo[out] for out in self.outputs]
return [out if isinstance(out, Constant) else memo[out] for out in self.outputs]

def unfreeze(self) -> "FunctionGraph":
"""Return a mutable FunctionGraph with fresh mutable Apply nodes."""
Expand Down
19 changes: 19 additions & 0 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
Expand Down Expand Up @@ -120,6 +121,24 @@ class Op(MetaObject):
as nodes with these Ops must be rebuilt even if the input types haven't changed.
"""

def __init_subclass__(cls, **kwargs):
    """Adapt subclasses that still declare the legacy ``infer_shape`` signature.

    Only an ``infer_shape`` defined directly on ``cls`` is inspected; inherited
    implementations are left alone. When that method takes four parameters
    (the legacy ``self, fgraph, node, input_shapes`` form) a
    `DeprecationWarning` is emitted and the method is replaced by a
    three-parameter shim that forwards ``None`` for ``fgraph``.
    """
    super().__init_subclass__(**kwargs)
    method = cls.__dict__.get("infer_shape")
    if method is None:
        return

    try:
        params = inspect.signature(method).parameters
    except (TypeError, ValueError):
        # The attribute cannot be introspected (non-callable descriptor, some
        # builtins). We cannot tell which signature it has, so leave it
        # untouched instead of aborting creation of the subclass.
        return

    if len(params) != 4:
        return

    warnings.warn(
        f"{cls.__module__}.{cls.__qualname__}.infer_shape takes a "
        "deprecated `fgraph` parameter; drop it from the signature. "
        "The parameter will be passed as None.",
        DeprecationWarning,
        stacklevel=2,
    )

    # A closure (rather than a lambda with a ``_old=method`` default) keeps the
    # shim's signature at exactly ``(self, node, input_shapes)``. A caller that
    # still passes ``fgraph`` positionally then gets a clear arity error
    # instead of silently rebinding the captured legacy method.
    def infer_shape(self, node, input_shapes):
        return method(self, None, node, input_shapes)

    # Keep the shim identifiable in tracebacks and in generated documentation.
    infer_shape.__doc__ = getattr(method, "__doc__", None)
    infer_shape.__module__ = getattr(method, "__module__", cls.__module__)
    infer_shape.__qualname__ = getattr(
        method, "__qualname__", f"{cls.__qualname__}.infer_shape"
    )

    cls.infer_shape = infer_shape

def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.

Expand Down
60 changes: 60 additions & 0 deletions pytensor/graph/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,66 @@ def toposort_key(
return fg.outputs[0]


def break_aliasing_cycles(
    outputs: Sequence[Variable],
    destroyers_of,
) -> list[Variable]:
    """Remove ordering cycles caused by inplace aliasing from ``outputs``.

    A single Apply cannot be scheduled when it reads a variable ``x`` that some
    inplace Op overwrites and, through another of its inputs, also depends on
    that Op's output. For every such Apply the conflicting ``x`` input is
    routed through ``deep_copy_op`` and the Apply is rebuilt on the copy.

    ``destroyers_of`` is called with a variable and must return the (possibly
    empty) collection of Apply nodes that overwrite it. The returned list holds
    the original ``outputs`` when nothing had to change, otherwise the result
    of applying the collected replacements with ``graph_replace``.
    """
    from pytensor.compile.ops import deep_copy_op

    NO_DEPS: frozenset[Variable] = frozenset()
    # For each variable: outputs of destroy_map Ops found strictly upstream of it.
    upstream: dict[Variable, frozenset[Variable]] = {}
    # One shared deep copy per re-routed variable.
    copies: dict[Variable, Variable] = {}
    # Old Apply outputs -> outputs of the rebuilt Apply.
    rewired: dict[Variable, Variable] = {}

    def seen_through(var: Variable) -> set[Variable]:
        """Destroyer outputs visible through ``var``, including ``var`` itself."""
        found = set(upstream.get(var, NO_DEPS))
        if var.owner is not None and var.owner.op.destroy_map:
            found.add(var)
        return found

    for apply in toposort(outputs):
        per_input = [seen_through(inp) for inp in apply.inputs]

        merged = frozenset().union(*per_input)
        for out in apply.outputs:
            upstream[out] = merged

        # Destroyers themselves are never re-routed.
        if apply.op.destroy_map:
            continue

        patched = list(apply.inputs)
        touched = False
        for pos, inp in enumerate(apply.inputs):
            destroyers = destroyers_of(inp)
            if not destroyers:
                continue

            # Everything the *other* inputs of this Apply depend on.
            siblings: set[Variable] = set()
            for other_pos, reachable in enumerate(per_input):
                if other_pos != pos:
                    siblings |= reachable

            conflict = any(
                out in siblings
                for destroyer in destroyers
                for out in destroyer.outputs
            )
            if not conflict:
                continue

            if inp not in copies:
                copies[inp] = cast(Variable, deep_copy_op(inp))
            patched[pos] = copies[inp]
            touched = True

        if touched:
            rebuilt = apply.op.make_node(*patched)
            rewired.update(zip(apply.outputs, rebuilt.outputs, strict=True))

    if rewired:
        return graph_replace(list(outputs), replace=rewired)

    return list(outputs)


@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]:
# Default implementation is provided in pytensor.tensor.blockwise
Expand Down
2 changes: 1 addition & 1 deletion pytensor/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __str__(self):
args.append("inplace")
return f"if{{{','.join(args)}}}"

def infer_shape(self, fgraph, node, inputs_shapes):
def infer_shape(self, node, inputs_shapes):
# By construction, corresponding then/else pairs have the same number
# of dimensions

Expand Down
Loading
Loading