Fix local_subtensor_of_batch_dims when encountering stale broadcast types#2167
Fix local_subtensor_of_batch_dims when encountering stale broadcast types#2167jaanerik wants to merge 3 commits into
Conversation
3ad4689 to
edcf930
Compare
70780a2 to
7475ed4
Compare
|
|
||
| if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs): | ||
| # This dim is not broadcasted for any of the inputs, original index can be applied to all inputs | ||
| if not dim_bcast_out and all(dim_bcast_inputs): |
There was a problem hiding this comment.
We need to think a bit better about the patch.
The rewrite could do x[idx] -> alloc(x, idx.shape[dim]) in this case, the issue is it is missing that last step when after correctly figuring out the idx is never needed, we may still need to provide shape with an alloc.
Also there's a branch above that is not great:
if all(inp.type.broadcastable[:batch_ndim] == elem_bcast for inp in elem_inputs):
# No need to worry about implicit broadcasting.
indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]It tries to check if there's any broadcasting going on, so in the cases where we know x is length 1 it just blindly lifts, but that's not good either, because we end up with wasteful x[idx] + x[idx]. expanding size 1 dims for no reason.
Would be better to fix these two scenarios, we never want to end up with a worse graph, but also we don't want the bail out that's currently proposed
There was a problem hiding this comment.
The new commit hopefully fixes those. Also Claude thought it is worth noting that
- alloc uses
*old_out.shaperather than per-dimidx.shape[d] - a fast
inp[idx_tuple]path is kept when no per-dim adjustment is needed, to sidestep_canonical_indexing'sScalarType AttributeError(separate latent bug —idx.type.broadcastableat line 107 doesn't exist onScalarType).
Instead of bailing out when every input is length 1 on a dim but the Elemwise output type stale-claims non-bcast, lift through and recover the original shape with an alloc. Also avoid the prior fast path that wastefully expanded size 1 dims when an advanced index landed on a dim where all inputs were length 1 (would emit x[idx] + x[idx] expanding size 1 to size K for no reason). The per-dim loop now: - skips slice(None) and dims with no broadcast input, - applies basic indices (slice/scalar) as-is on all-bcast dims since they can't expand size 1, - otherwise replaces the index on bcast inputs with a size 1 stand-in. After building the lifted Elemwise, broadcast back to old_out.shape via alloc whenever the bcast pattern differs from the original, which covers both the stale-type case and the collapsed-all-bcast case. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Description
The local_subtensor_of_batch_dims rewrite crashes with a TypeError when it attempts to lift a subtensor through an Elemwise node that is in a "stale" state.
A stale state occurs mid-optimization when upstream rewrites have made an Elemwise node's inputs broadcastable (length-1), but the node's output Type has not yet been updated to reflect this (remaining non-broadcastable).
Checklist
Type of change
Issues