Add CAReduce(DimShuffle(x)) -> CAReduce(x) rewrite when DimShuffle is a no-op for the reduction #2132

@williambdean

Description

When a DimShuffle/ExpandDims operation adds dimensions that are immediately reduced away by a CAReduce, the DimShuffle is unnecessary and should be removed.

Motivating example

After the local_careduce_join rewrite (#2130), Sum{axes=None}(Join(0, ExpandDims(a), ExpandDims(b), ExpandDims(c))) becomes:

Add
 ├─ Sum{axes=None}(ExpandDims{axis=0}(a))
 ├─ Sum{axes=None}(ExpandDims{axis=0}(b))
 └─ Sum{axes=None}(ExpandDims{axis=0}(c))

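The ExpandDims{axis=0} adds a dimension at position 0, but Sum{axes=None} reduces every dimension, including the new one, so each term is just Sum{axes=None}(a), Sum{axes=None}(b), and so on. Partial reductions can also be simplified by remapping the axes. If the added axis is among the reduced ones, the ExpandDims can be dropped: for a matrix x, Sum{axes=(0, 1)}(ExpandDims{axis=0}(x)) equals Sum{axes=(0,)}(x), because the remaining axis indices shift down by one. If the added axis is not reduced, e.g. Sum{axes=(1,)}(ExpandDims{axis=0}(x)), the output still carries that length-1 axis, so the ExpandDims cannot be dropped; it can only be moved after the reduction, giving ExpandDims{axis=0}(Sum{axes=(0,)}(x)). The easy case is axes=None.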
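The axis bookkeeping for partial reductions can be sketched in plain Python, independent of PyTensor. `plan_careduce_over_expand_dims` is a hypothetical helper, not an existing API: given the input's ndim, the positions where ExpandDims inserts length-1 axes (numbered in the expanded tensor), and the reduction axes (None for all), it returns the axes to reduce on the original input and the final-output positions where a length-1 axis must be re-inserted. An empty second tuple means the ExpandDims can simply be dropped.

```python
def plan_careduce_over_expand_dims(x_ndim, added_axes, reduce_axes):
    """Plan Reduce(ExpandDims(x, added_axes), reduce_axes) as
    ExpandDims(Reduce(x, x_axes), keep_axes).

    added_axes: positions of the inserted length-1 axes in the expanded tensor.
    reduce_axes: axes reduced on the expanded tensor, or None for all of them.
    Returns (x_axes, keep_axes); an empty keep_axes means no ExpandDims is needed.
    """
    out_ndim = x_ndim + len(added_axes)
    added = set(added_axes)
    if len(added) != len(added_axes) or not added <= set(range(out_ndim)):
        raise ValueError("added_axes must be distinct, non-negative, in range")
    reduced = set(range(out_ndim)) if reduce_axes is None else set(reduce_axes)

    x_axes = []     # axes of x that the moved reduction must cover
    keep_axes = []  # final-output positions needing a re-inserted length-1 axis
    x_pos = 0       # index of the current non-added axis within x
    out_pos = 0     # index of the current surviving axis in the final output
    for ax in range(out_ndim):
        if ax in added:
            if ax not in reduced:  # added axis survives the reduction
                keep_axes.append(out_pos)
                out_pos += 1
        else:
            if ax in reduced:
                x_axes.append(x_pos)
            else:
                out_pos += 1
            x_pos += 1
    return tuple(x_axes), tuple(keep_axes)


# Sum{axes=None}(ExpandDims{axis=0}(x)), x a vector -> full reduction of x
print(plan_careduce_over_expand_dims(1, (0,), None))    # ((0,), ())
# Sum{axes=(0, 1)}(ExpandDims{axis=0}(x)), x a matrix -> Sum{axes=(0,)}(x)
print(plan_careduce_over_expand_dims(2, (0,), (0, 1)))  # ((0,), ())
# Sum{axes=(1,)}(ExpandDims{axis=0}(x)) -> ExpandDims{axis=0}(Sum{axes=(0,)}(x))
print(plan_careduce_over_expand_dims(2, (0,), (1,)))    # ((0,), (0,))
```

For reduce_axes=None the returned x_axes cover every axis of x, so in a real rewrite the new reduction could stay axes=None.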

Proposed rewrite

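from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import CAReduce, DimShuffle


@node_rewriter([CAReduce])
def local_careduce_dimshuffle(fgraph, node):
    """CAReduce(DimShuffle(x), axis=ax) -> CAReduce(x)
    
    When DimShuffle only adds dimensions (no transpose or dropped axes),
    and those dimensions are all reduced away, it can be removed.
    """
    [inp] = node.inputs
    # Printed graphs show an expand-only DimShuffle as ExpandDims{axis=...},
    # so checking for DimShuffle covers both
    if inp.owner is None or not isinstance(inp.owner.op, DimShuffle):
        return None
    # Check if the DimShuffle is "expanding" only (no transpose/drop)
    # and the expansion axes are all in the reduction set
    ...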
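The elided expand-only check could be written against DimShuffle's new_order, which lists an input axis index for each kept axis and "x" for each inserted broadcastable axis. The sketch below is a hypothetical, stdlib-only helper that works on a plain tuple rather than the real Op, so it runs without PyTensor; inside the rewrite it would presumably be called as expand_only_axes(inp.owner.op.new_order, inp.owner.inputs[0].type.ndim).

```python
def expand_only_axes(new_order, input_ndim):
    """Return the positions of the "x" entries in a DimShuffle-style
    new_order if it only inserts broadcastable axes, else None."""
    kept = [d for d in new_order if d != "x"]
    # Expand-only means every input axis is kept, in its original order
    if kept != list(range(input_ndim)):
        return None  # the shuffle transposes or drops at least one axis
    return tuple(i for i, d in enumerate(new_order) if d == "x")


print(expand_only_axes(("x", 0, 1), 2))        # (0,)
print(expand_only_axes((0, "x", 1, "x"), 2))   # (1, 3)
print(expand_only_axes((1, 0), 2))             # None (transpose)
print(expand_only_axes(("x", 1), 2))           # None (drops axis 0)
```

The returned positions use the expanded tensor's axis numbering, which is the numbering the CAReduce axes refer to. For the axes=None case, any non-None result means the DimShuffle can be skipped; for partial reductions the positions still have to be checked against the reduction set.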

Related