Skip to content

Fix local_subtensor_of_batch_dims when encountering stale broadcast types#2167

Open
jaanerik wants to merge 3 commits into
pymc-devs:mainfrom
jaanerik:fix-local-subtensor-of-batch-dims-collapse
Open

Fix local_subtensor_of_batch_dims when encountering stale broadcast types#2167
jaanerik wants to merge 3 commits into
pymc-devs:mainfrom
jaanerik:fix-local-subtensor-of-batch-dims-collapse

Conversation

@jaanerik
Copy link
Copy Markdown
Contributor

@jaanerik jaanerik commented May 25, 2026

Description

The local_subtensor_of_batch_dims rewrite crashes with a TypeError when it attempts to lift a subtensor through an Elemwise node that is in a "stale" state.

A stale state occurs mid-optimization when upstream rewrites have made an Elemwise node's inputs broadcastable (length-1), but the node's output Type has not yet been updated to reflect this (remaining non-broadcastable).

Checklist

Type of change

  • Bug fix

Issues

@jaanerik jaanerik force-pushed the fix-local-subtensor-of-batch-dims-collapse branch from 3ad4689 to edcf930 Compare May 25, 2026 13:13
@jaanerik jaanerik force-pushed the fix-local-subtensor-of-batch-dims-collapse branch 2 times, most recently from 70780a2 to 7475ed4 Compare May 25, 2026 14:29

if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs):
# This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
if not dim_bcast_out and all(dim_bcast_inputs):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to think a bit better about the patch.

The rewrite could do x[idx] -> alloc(x, idx.shape[dim]) in this case, the issue is it is missing that last step when after correctly figuring out the idx is never needed, we may still need to provide shape with an alloc.

Also there's a branch above that is not great:

if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
        # No need to worry about implicit broadcasting.
        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]

It tries to check if there's any broadcasting going on, so in the cases where we know x is length 1 it just blindly lifts, but that's not good either, because we end up with wasteful x[idx] + x[idx]. expanding size 1 dims for no reason.

Would be better to fix these two scenarios, we never want to end up with a worse graph, but also we don't want the bail out that's currently proposed

Copy link
Copy Markdown
Contributor Author

@jaanerik jaanerik May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new commit hopefully fixes those. Also Claude thought it is worth noting that

  • alloc uses *old_out.shape rather than per-dim idx.shape[d]
  • a fast inp[idx_tuple] path is kept when no per-dim adjustment is needed, to sidestep _canonical_indexing's ScalarType AttributeError (separate latent bug — idx.type.broadcastable at line 107 doesn't exist on ScalarType).

Instead of bailing out when every input is length 1 on a dim but the
Elemwise output type stale-claims non-bcast, lift through and recover
the original shape with an alloc. Also avoid the prior fast path that
wastefully expanded size 1 dims when an advanced index landed on a dim
where all inputs were length 1 (would emit x[idx] + x[idx] expanding
size 1 to size K for no reason).

The per-dim loop now:
- skips slice(None) and dims with no broadcast input,
- applies basic indices (slice/scalar) as-is on all-bcast dims since
  they can't expand size 1,
- otherwise replaces the index on bcast inputs with a size 1 stand-in.

After building the lifted Elemwise, broadcast back to old_out.shape via
alloc whenever the bcast pattern differs from the original, which
covers both the stale-type case and the collapsed-all-bcast case.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: TypeError in local_subtensor_of_batch_dims when encountering stale broadcast types

2 participants