Description
When a DimShuffle/ExpandDims operation adds dimensions that are immediately reduced away by a CAReduce, the DimShuffle is unnecessary and should be removed.
Motivating example
After the local_careduce_join rewrite (#2130), Sum{axes=None}(Join(0, ExpandDims(a), ExpandDims(b), ExpandDims(c))) becomes:
Add
├─ Sum{axes=None}(ExpandDims{axis=0}(a))
├─ Sum{axes=None}(ExpandDims{axis=0}(b))
└─ Sum{axes=None}(ExpandDims{axis=0}(c))
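The equivalence behind this output can be checked numerically. The sketch below uses NumPy as a stand-in for the PyTensor graph (np.concatenate for Join, np.expand_dims for ExpandDims, .sum() for Sum{axes=None}). The input values are arbitrary and chosen only for illustration. It compares the original expression, the split form produced by local_careduce_join, and the form the proposed rewrite would produce.

```python
import numpy as np

# Arbitrary small inputs, for illustration only
a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0])
c = np.array([5.0, 6.0])

# Sum{axes=None}(Join(0, ExpandDims(a), ExpandDims(b), ExpandDims(c)))
joined = np.concatenate([np.expand_dims(v, 0) for v in (a, b, c)], axis=0)
total_joined = joined.sum()

# After local_careduce_join: Add of Sum{axes=None}(ExpandDims{axis=0}(v)) terms
total_split = sum(np.expand_dims(v, 0).sum() for v in (a, b, c))

# After the proposed rewrite the ExpandDims are gone: Add(Sum(a), Sum(b), Sum(c))
total_rewritten = a.sum() + b.sum() + c.sum()

print(total_joined, total_split, total_rewritten)  # 21.0 21.0 21.0
```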
The ExpandDims{axis=0} adds a dimension at position 0, but Sum{axes=None} reduces all dimensions, including the new one, so each term is equivalent to Sum{axes=None} applied directly to a, b or c. A related case arises when the CAReduce has explicit axes. In Sum{axis=(1,)}(ExpandDims{axis=0}(x)), axis 0 is not reduced, so the ExpandDims is still needed for the output shape, but it could be moved after the reduction: ExpandDims{axis=0}(Sum{axis=(0,)}(x)). The easy case is axis=None.
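The identities described above can be sanity-checked with NumPy standing in for PyTensor (np.expand_dims for ExpandDims, .sum()/.max() for the CAReduce). This is only a numerical illustration, assuming PyTensor's reductions match NumPy's semantics for these simple float cases. Reducing a length-1 axis leaves the values unchanged, which is why the inserted axis can be dropped when it is reduced and must be re-inserted afterwards when it is not.

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)  # shape (2, 3)
e = np.expand_dims(x, 0)          # ExpandDims{axis=0}(x), shape (1, 2, 3)

# axis=None: the inserted length-1 axis is reduced along with the others
assert e.sum() == x.sum()
assert e.max() == x.max()

# Inserted axis not reduced: the ExpandDims cannot be dropped, but it can be
# applied after the reduction instead, with the reduced axis renumbered 1 -> 0
lhs = e.sum(axis=1)                     # shape (1, 3)
rhs = np.expand_dims(x.sum(axis=0), 0)  # shape (1, 3)
assert lhs.shape == rhs.shape == (1, 3)
assert np.array_equal(lhs, rhs)

# Inserted axis reduced together with a real axis: only the real axis remains
assert np.array_equal(e.sum(axis=(0, 2)), x.sum(axis=1))
```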
Proposed rewrite
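from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import CAReduce, DimShuffle

@node_rewriter([CAReduce])
def local_careduce_dimshuffle(fgraph, node):
    """CAReduce(DimShuffle(x), axis=ax) -> CAReduce(x)
    When DimShuffle only adds dimensions (no transpose/reshape),
    and those dimensions are all reduced away, it can be removed.
    """
    [inp] = node.inputs
    # ExpandDims{...} in debugprint output is assumed to be an expand-only DimShuffle
    if inp.owner is None or not isinstance(inp.owner.op, DimShuffle):
        return None
    [x] = inp.owner.inputs
    # Expand-only: the non-"x" entries of new_order are x's axes, in order (no transpose/drop)
    if [d for d in inp.owner.op.new_order if d != "x"] != list(range(x.type.ndim)):
        return None
    # Easy case: axis=None reduces every axis, including the inserted ones
    if node.op.axis is None:
        return [node.op(x)]
    # Explicit axes: needs a CAReduce with the axes remapped onto x (left open here)
    return None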
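The axis bookkeeping for the explicit-axis case can be sketched independently of PyTensor. The helper below, plan_careduce_dimshuffle, is hypothetical and not part of PyTensor; it only assumes the DimShuffle new_order convention, where "x" marks an inserted broadcastable axis. It returns which axes of x the CAReduce should reduce and the order of a DimShuffle to re-apply to the result (None when every inserted axis was reduced away). It assumes reducing a length-1 axis is a no-op (true for sum, prod, max, min, all, any) and ignores output dtype handling.

```python
def plan_careduce_dimshuffle(new_order, x_ndim, reduce_axes):
    """Hypothetical helper: CAReduce(DimShuffle(x)) -> DimShuffle(CAReduce(x)).

    new_order: DimShuffle order, e.g. ("x", 0, 1) for ExpandDims{axis=0}.
    x_ndim: number of dimensions of x.
    reduce_axes: axes reduced on the DimShuffle output, or None for all axes.
    Returns (x_axes, out_order), or None if the DimShuffle is not expand-only.
    """
    kept = [d for d in new_order if d != "x"]
    if kept != list(range(x_ndim)):
        return None  # transpose or dropped axis: not handled here
    n_out = len(new_order)
    if reduce_axes is None:
        reduced = set(range(n_out))
    else:
        reduced = {ax % n_out for ax in reduce_axes}  # normalize negative axes
    # Axes of x that actually need reducing
    x_axes = tuple(d for pos, d in enumerate(new_order) if d != "x" and pos in reduced)
    # Order for the reduced result: surviving x axes are renumbered
    # consecutively, surviving inserted axes stay as "x"
    out_order = []
    next_axis = 0
    for pos, d in enumerate(new_order):
        if pos in reduced:
            continue
        if d == "x":
            out_order.append("x")
        else:
            out_order.append(next_axis)
            next_axis += 1
    if "x" not in out_order:
        out_order = None  # identity order: no DimShuffle needed afterwards
    return x_axes, out_order


# Sum{axes=None}(ExpandDims{axis=0}(x)) with x 2-D: reduce all of x, no DimShuffle
print(plan_careduce_dimshuffle(("x", 0, 1), 2, None))  # ((0, 1), None)
# Sum{axis=(1,)}(ExpandDims{axis=0}(x)) -> ExpandDims{axis=0}(Sum{axis=(0,)}(x))
print(plan_careduce_dimshuffle(("x", 0, 1), 2, (1,)))  # ((0,), ['x', 0])
```

An empty x_axes (only inserted axes reduced) would mean the whole expression reduces to x itself, possibly up to a dtype cast.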
Related
… (local_careduce_join) exposes this pattern
… (Sum(ExpandDims(...)))