Rewrite ShapeFeature to not hold live variables#2056

Draft
ricardoV94 wants to merge 6 commits into pymc-devs:main from ricardoV94:shape_feature

Member

ricardoV94 commented Apr 17, 2026

Closes pymc-devs/pymc-extras#673

ShapeFeature reintroducing variables we have lowered/rewritten away is no-bueno

Details

Replace the eager per-variable dicts (shape_of, shape_of_reverse_index, scheduled) with a lazy FrozenFunctionGraph-based shape kernel cache. For each Apply, a kernel built from dummy clones of node.inputs is stored in self._cache[node] and materialized on demand against the current live inputs via a custom frozen-graph walker (graph_replace would mutate globally-interned FrozenApply inputs).

The kernel holds only NominalVariables and Constants, so no live
variable can leak between tests or across rewrites, eliminating by
construction the stale-XRV class of bugs.
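
To make the idea concrete, here is a toy model of a lazy per-node shape kernel cache. It is only a sketch and does not use PyTensor: `DimOf`, `matmul_kernel`, and `LazyShapeCache` are hypothetical names, and shapes are plain tuples of ints. The point it illustrates is that the cached kernel refers to input positions only, so it can be re-materialized against whatever inputs the node has now.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DimOf:
    """Symbolic reference to axis `axis` of a node's `input_index`-th input."""

    input_index: int
    axis: int


def matmul_kernel():
    # Output shape of A @ B, written purely in terms of input positions.
    return (DimOf(0, 0), DimOf(1, 1))


class LazyShapeCache:
    """Hypothetical cache: one kernel per node, built on first request."""

    def __init__(self):
        self._cache = {}  # node key -> kernel (placeholders and constants only)
        self.builds = 0   # counts how many kernels were actually built

    def get_shape(self, node, build_kernel, live_input_shapes):
        kernel = self._cache.get(node)
        if kernel is None:
            kernel = self._cache[node] = build_kernel()
            self.builds += 1
        # Materialize: substitute the shapes of the node's current inputs.
        return tuple(
            live_input_shapes[d.input_index][d.axis] if isinstance(d, DimOf) else d
            for d in kernel
        )


cache = LazyShapeCache()
first = cache.get_shape("dot", matmul_kernel, [(3, 4), (4, 5)])   # (3, 5)
# After a rewrite swaps the node's inputs, the same kernel is reused.
second = cache.get_shape("dot", matmul_kernel, [(7, 4), (4, 2)])  # (7, 2)
```

Because the cache never stores the live inputs themselves, nothing from the first call can resurface in the second; only the kernel is shared.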

Back-compat surface (_LazyShapeTuple, _ShapeOfProxy, update_shape,
shape_ir, init_r) is retained and marked as temporary. A regression
test for the stale-XRV scenario replaces the prior xfail.

shape_of_variables switches to builders.infer_shape so it goes back to
using scalar-dim inputs instead of allocating per-input arrays.

local_track_shape_i no longer depends on the deleted scheduled dict;
it rewrites Shape_i(v, i) to get_shape(v, i) whenever the kernel
produces something other than the trivial fallback.

on_change_input carries r's inferred shape onto new_r as an override
when new_r's Op has no infer_shape, preserving the legacy behavior
where a well-inferred shape survives through a replacement with an
opaque op.

Benchmarks (cxx enabled):

  • radon_repeat 0.78s -> 0.55s (-30%)
  • radon_variants (8) 7.9s -> 7.2s ( -9%)
  • fusion_large 0.22s -> 0.22s (noise)
  • fusion_deep 13ms -> 13ms (noise)

@ricardoV94
Member Author

I added a small patch for OFG to cache the shape graph. With that patch, this PR's lazy ShapeFeature, and #2147, the issue that #2164 is trying to circumvent is already handled.

Before PR:
--------------------------------------------------------------------------------------------- benchmark: 2 tests ---------------------------------------------------------------------------------------------
Name (time in s)                                    Min                Max               Mean            StdDev             Median               IQR            Outliers     OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_xtensor_attention_rewrite_benchmark[2]      4.6451 (1.0)       5.0273 (1.0)       4.7674 (1.0)      0.1500 (1.0)       4.7223 (1.0)      0.1310 (1.0)           1;1  0.2098 (1.0)           5           1
test_xtensor_attention_rewrite_benchmark[3]     70.5293 (15.18)    75.3482 (14.99)    73.3600 (15.39)    2.3004 (15.34)    73.7812 (15.62)    3.7426 (28.56)         1;0  0.0136 (0.06)          4           1
test_xtensor_attention_rewrite_benchmark[4]     HEAT DEATH OF THE UNIVERSE
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After lazy ShapeFeature:
----------------------------------------------------------------------------------------------------- benchmark: 3 tests ----------------------------------------------------------------------------------------------------
Name (time in ms)                                      Min                   Max                  Mean             StdDev                Median                 IQR            Outliers     OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_xtensor_attention_rewrite_benchmark[2]       604.5540 (1.0)        734.8644 (1.0)        678.0195 (1.0)      58.0627 (2.44)       702.4159 (1.0)      101.5709 (2.87)          1;0  1.4749 (1.0)           5           1
test_xtensor_attention_rewrite_benchmark[3]     1,238.5549 (2.05)     1,337.9649 (1.82)     1,266.2889 (1.87)     41.0227 (1.72)     1,248.9357 (1.78)      38.6898 (1.09)          1;0  0.7897 (0.54)          5           1
test_xtensor_attention_rewrite_benchmark[4]     2,128.2901 (3.52)     2,186.8365 (2.98)     2,164.0254 (3.19)     23.8237 (1.0)      2,173.0298 (3.09)      35.4256 (1.0)           1;0  0.4621 (0.31)          5           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After shape graph cache:
------------------------------------------------------------------------------------------------- benchmark: 3 tests ------------------------------------------------------------------------------------------------
Name (time in ms)                                    Min                 Max                Mean             StdDev              Median                 IQR            Outliers     OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_xtensor_attention_rewrite_benchmark[2]     288.4377 (1.0)      422.6271 (1.0)      327.9444 (1.0)      53.7626 (1.0)      310.3342 (1.0)       38.3331 (1.0)           1;1  3.0493 (1.0)           5           1
test_xtensor_attention_rewrite_benchmark[3]     431.9948 (1.50)     601.8571 (1.42)     505.5498 (1.54)     78.9754 (1.47)     457.3631 (1.47)     135.0387 (3.52)          1;0  1.9780 (0.65)          5           1
test_xtensor_attention_rewrite_benchmark[4]     614.3723 (2.13)     783.8840 (1.85)     714.3603 (2.18)     82.6776 (1.54)     756.9681 (2.44)     151.6929 (3.96)          1;0  1.3999 (0.46)          5           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

There's a new benchmark test tracking this. It can still make sense to eagerly rewrite einsum -> core graph for 2 inputs, but in general we can't afford to have inner graph ops be a performance bottleneck, as that's the direction we are moving (see #2110 and #1221).

There were several sources of exponential blow-up that we are addressing here. None of this is a hack; the old code was just dumb.

CC @cetagostini

`Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join`
and batched-`Blockwise` as protected client ops, but not `Subtensor`.
`local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into
`alloc(val[...], *new_shape)` — preserving the Alloc structure that
downstream rewrites like `local_blockwise_alloc_inputs` depend on.
Folding the Alloc here short-circuited that lift and produced
broadcast-equivalent `Constant` matrices whose batch dim was no longer
type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the
surrounding `Blockwise(Reshape)`.

Surfaced by the lazy-kernel `ShapeFeature` (which resolves
`Subtensor(Shape(out), const)` to a scalar `Constant` earlier and
makes more upstream Allocs constant-foldable), but the fix belongs
here: the protection was too narrow.

Breaking API change: the `fgraph` argument was unused by every
in-tree `infer_shape` implementation. Removing it makes
`infer_shape` a pure function of `(node, input_shapes)`, simpler
to call from outside an fgraph context (e.g. ShapeFeature's lazy
kernel build) and tighter as a contract.
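
As an illustration of the new contract, the sketch below shows a toy class whose `infer_shape` depends only on `(node, input_shapes)`. `ToyOp` is hypothetical and is not a real PyTensor `Op` subclass. The `*args` handling is an assumed transition shim for code that has to run against both calling conventions; it is not something this PR provides.

```python
class ToyOp:
    """Stand-in for an external Op (hypothetical, not a pytensor Op subclass)."""

    def infer_shape(self, *args):
        # New convention: (node, input_shapes).
        # Legacy convention: (fgraph, node, input_shapes); fgraph is ignored.
        if len(args) == 3:
            _fgraph, node, input_shapes = args
        else:
            node, input_shapes = args
        # Elementwise-style rule: the single output has the first input's shape.
        return [input_shapes[0]]


op = ToyOp()
new_style = op.infer_shape(None, [(2, 3)])               # no fgraph needed
legacy_style = op.infer_shape(object(), None, [(2, 3)])  # old call still works
```

Once only the new convention has to be supported, the shim collapses to a plain `def infer_shape(self, node, input_shapes)`.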

External Ops with custom `infer_shape(self, fgraph, node, input_shapes)`
must drop the `fgraph` parameter.

Add `break_aliasing_cycles` to `pytensor.graph.replace`. When an inplace
Op overwrites input `x` and a single Apply ends up reading both `x` and
a transitive dependent of the destroyer's output, no valid schedule
exists. The helper re-routes such inputs through `deep_copy_op` to lift
the conflict.

Expose it via a `ShapeFeature.get_shape_no_cycle` convenience method,
and use it from `introduce_explicit_core_shape_rv` and
`introduce_explicit_core_shape_blockwise`, where lazy shape
materialization can otherwise produce that pattern.
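
The scheduling conflict described above can be reproduced with a toy dependency model. This sketch does not call `break_aliasing_cycles` or any PyTensor API: the `schedule` helper and the node names are hypothetical, and a plain `copy` node stands in for `deep_copy_op`. Each node is described as `(inputs, output, destroyed_input)`.

```python
from graphlib import CycleError, TopologicalSorter


def schedule(nodes):
    """Return a valid execution order, or None if aliasing rules one out."""
    producer = {out: name for name, (_, out, _) in nodes.items()}
    preds = {name: set() for name in nodes}
    for name, (inputs, _, _) in nodes.items():
        # Data dependency: run after the producer of every input.
        for var in inputs:
            if var in producer:
                preds[name].add(producer[var])
    for name, (_, _, destroyed) in nodes.items():
        if destroyed is None:
            continue
        # Anti-dependency: every other reader of the overwritten variable
        # must run before the node that destroys it.
        for other, (inputs, _, _) in nodes.items():
            if other != name and destroyed in inputs:
                preds[name].add(other)
    try:
        return list(TopologicalSorter(preds).static_order())
    except CycleError:
        return None


# "reader" needs both x (before it is overwritten) and z (computed after).
conflicting = {
    "destroy": (("x",), "y", "x"),
    "dep": (("y",), "z", None),
    "reader": (("x", "z"), "out", None),
}
# Routing x through a copy removes the contradiction.
fixed = {
    "copy": (("x",), "x_copy", None),
    "destroy": (("x",), "y", "x"),
    "dep": (("y",), "z", None),
    "reader": (("x_copy", "z"), "out", None),
}

no_order = schedule(conflicting)  # None: the constraints form a cycle
order = schedule(fixed)           # ['copy', 'destroy', 'dep', 'reader']
```

In the conflicting graph, `reader` must run before `destroy` (it reads `x`) and after it (it reads `z`), which is the cycle; the copy gives `reader` its own snapshot of `x`.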

Development

Successfully merging this pull request may close these issues.

Prior.create_variable(xdist=True) fails compile_logp for centered priors with nested Prior parameters that have dims
