Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,6 +1964,53 @@ def local_reduce_join(fgraph, node):
return [ret]


@register_specialize
@register_canonicalize
@register_uncanonicalize  # Needed for Min which is formed from Neg(Max(Neg))
@node_rewriter([CAReduce])
def local_careduce_join(fgraph, node):
    r"""Reduce every `Join` input separately and combine the partial results.

    CAReduce(Join(axis, \*tensors), axis=ax) -> scalar_op(CAReduce(t1), CAReduce(t2), ...)

    The rewrite is applied when the reduction covers all axes (``axis=None``)
    or when the reduced axes include the join axis. Its purpose is to avoid
    building the concatenated intermediate array: the same reduction is
    applied to each joined input and the partial results are merged with the
    reduction's scalar op.

    The merge is a left-to-right chain of binary ``Elemwise(scalar_op)``
    applications, e.g. ``op(op(r1, r2), r3)`` for three inputs, so it also
    works for ops such as Maximum/Minimum for which ``Elemwise`` with more
    than two inputs is not currently supported.

    Returns ``None`` (no rewrite) when the reduced variable does not come
    from a `Join`, when fewer than two tensors are joined, when the join axis
    is not a `Constant`, when the reduction axes are given but do not contain
    the join axis, or when the rewritten output dtype would differ from the
    original one.
    """
    (joined_out,) = node.inputs
    if not isinstance(joined_out.owner_op, Join):
        return None

    axis_var, *pieces = joined_out.owner.inputs
    if len(pieces) < 2 or not isinstance(axis_var, Constant):
        return None

    # Normalize a negative join axis to its non-negative equivalent.
    raw_axis = int(axis_var.data)
    join_axis = raw_axis + joined_out.type.ndim if raw_axis < 0 else raw_axis

    careduce_op = node.op
    reduced_axes = careduce_op.axis
    if not (reduced_axes is None or join_axis in reduced_axes):
        # The join axis survives the reduction; this case is not handled here.
        return None

    # Apply the very same reduction to each joined input ...
    partials = [careduce_op(piece) for piece in pieces]

    # ... and fold the partial results pairwise with the scalar op.
    combine_scalar_op = careduce_op.scalar_op
    combined = partials[0]
    for partial in partials[1:]:
        combined = Elemwise(combine_scalar_op)(combined, partial)

    original_out = node.outputs[0]
    if combined.dtype != original_out.dtype:
        # The reduction changed the dtype in a way the combine step did not.
        return None

    copy_stack_trace(original_out, combined)
    return [combined]


@register_infer_shape
@register_canonicalize("fast_compile", "local_cut_useless_reduce")
@register_useless("local_cut_useless_reduce")
Expand Down
133 changes: 120 additions & 13 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from pytensor.graph.traversal import ancestors
from pytensor.printing import debugprint, pprint
from pytensor.scalar import PolyGamma, Psi, TriGamma
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.basic import Alloc, Join, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
Expand Down Expand Up @@ -103,6 +103,7 @@
from pytensor.tensor.rewriting.math import (
compute_mul,
is_1pexp,
local_careduce_join,
local_div_switch_sink,
local_grad_log_erfc_neg,
local_greedy_distributor,
Expand Down Expand Up @@ -3500,32 +3501,33 @@ def test_type(self):
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)

# This case could be rewritten
# Join axis is included in reduction axes
A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode)
np.testing.assert_allclose(f(), [2, 4, 6, 8, 10])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)
assert isinstance(topo[-1].op, Elemwise)

A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1))
f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode)
np.testing.assert_allclose(f(), [15, 15])
topo = f.maker.fgraph.toposort()
assert not isinstance(topo[-1].op, Elemwise)

def test_not_supported_axis_none(self):
# Test that the rewrite does not crash in one case where it
# is not applied. Reported at
# https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion
vx = matrix()
vy = matrix()
vz = matrix()
def test_careduce_join_list_sum_axis_none(self):
"""Sum([a, b, c], axis=None) -> Add(Sum(ExpandDims(a)), Sum(ExpandDims(b)), Sum(ExpandDims(c))) with list inputs (via Join)"""
vx, vy, vz = matrices("xyz")
out = pt_sum([vx, vy, vz], axis=None)

fg = FunctionGraph([vx, vy, vz], [out], clone=False)
[rewritten_out] = local_careduce_join.transform(fg, out.owner)
expected_out = add(add(pt_sum(vx[None]), pt_sum(vy[None])), pt_sum(vz[None]))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are missing a rewrite for CAReduce(DimShuffle(x)) -> CAReduce(x) when the DimShuffle has no effect on the result of the reduction (in other cases the DimShuffle may still be needed, but only for a subset of its behavior). In this case it isn't needed.

This does not need to be a blocker for this PR, but we should open an issue. I thought there was one already.

assert equal_computations([rewritten_out], [expected_out])

x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)

out = pt_sum([vx, vy, vz], axis=None)
f = function([vx, vy, vz], out, mode=self.mode)
f = function([vx, vy, vz], out, mode=get_mode("FAST_COMPILE"))
np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))

def test_not_supported_unequal_shapes(self):
Expand Down Expand Up @@ -3559,6 +3561,111 @@ def test_non_ds_inputs(self):
expected_out = add(exp(x), log(x))
assert equal_computations([rewritten_out], [expected_out])

def test_careduce_join_sum_2(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the approach to writing rewrite tests that we're trying to settle on: #2103

"""Sum(concat(a, b), axis=None) -> Add(Sum(a), Sum(b)) with 2 inputs"""
x, y = vectors("xy")
out = pt_sum(pt.concatenate([x, y]), axis=None)

fg = FunctionGraph([x, y], [out], clone=False)
[rewritten_out] = local_careduce_join.transform(fg, out.owner)
expected_out = add(pt_sum(x), pt_sum(y))
assert equal_computations([rewritten_out], [expected_out])

xv = np.random.rand(100).astype(config.floatX)
yv = np.random.rand(200).astype(config.floatX)
f = function([x, y], out, mode=self.mode)
np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv])))

def test_careduce_join_sum_3(self):
    """Sum(concat(a, b, c), axis=None) is rewritten into nested Adds of the per-input Sums."""
    a, b, c = vectors("xyz")
    total = pt_sum(pt.concatenate([a, b, c]), axis=None)

    # Apply the rewrite directly to the Sum node and compare graph structure.
    fgraph = FunctionGraph([a, b, c], [total], clone=False)
    (rewritten,) = local_careduce_join.transform(fgraph, total.owner)
    expected = add(add(pt_sum(a), pt_sum(b)), pt_sum(c))
    assert equal_computations([rewritten], [expected])

    # Numerical check with inputs of different lengths.
    a_val, b_val, c_val = (
        np.random.rand(size).astype(config.floatX) for size in (100, 150, 200)
    )
    compiled = function([a, b, c], total, mode=self.mode)
    np.testing.assert_allclose(
        compiled(a_val, b_val, c_val),
        np.sum(np.concatenate([a_val, b_val, c_val])),
    )

def test_careduce_join_max_2(self):
    """Max(concat(a, b), axis=None) is rewritten into Maximum(Max(a), Max(b))."""
    a, b = vectors("xy")
    largest = pt_max(pt.concatenate([a, b]), axis=None)

    # Apply the rewrite directly to the Max node and compare graph structure.
    fgraph = FunctionGraph([a, b], [largest], clone=False)
    (rewritten,) = local_careduce_join.transform(fgraph, largest.owner)
    expected = maximum(pt_max(a), pt_max(b))
    assert equal_computations([rewritten], [expected])

    # Numerical check with inputs of different lengths.
    a_val, b_val = (
        np.random.rand(size).astype(config.floatX) for size in (100, 200)
    )
    compiled = function([a, b], largest, mode=self.mode)
    np.testing.assert_allclose(
        compiled(a_val, b_val),
        np.max(np.concatenate([a_val, b_val])),
    )

def test_careduce_join_max_3(self):
    """Max(concat(a, b, c), axis=None) is rewritten into nested binary Maximum ops."""
    a, b, c = vectors("xyz")
    largest = pt_max(pt.concatenate([a, b, c]), axis=None)

    # Apply the rewrite directly to the Max node and compare graph structure.
    fgraph = FunctionGraph([a, b, c], [largest], clone=False)
    (rewritten,) = local_careduce_join.transform(fgraph, largest.owner)
    expected = maximum(maximum(pt_max(a), pt_max(b)), pt_max(c))
    assert equal_computations([rewritten], [expected])

    # Numerical check with inputs of different lengths.
    a_val, b_val, c_val = (
        np.random.rand(size).astype(config.floatX) for size in (100, 150, 200)
    )
    compiled = function([a, b, c], largest, mode=self.mode)
    np.testing.assert_allclose(
        compiled(a_val, b_val, c_val),
        np.max(np.concatenate([a_val, b_val, c_val])),
    )

def test_careduce_join_sum_specific_axis(self):
    """Sum(concat(mat_a, mat_b), axis=0) -> Add(Sum(mat_a, axis=0), Sum(mat_b, axis=0))

    The join axis (0) is one of the reduced axes, so the rewrite applies.
    """
    top, bottom = matrices("xy")
    column_sums = pt_sum(pt.concatenate([top, bottom], axis=0), axis=0)

    # Apply the rewrite directly to the Sum node and compare graph structure.
    fgraph = FunctionGraph([top, bottom], [column_sums], clone=False)
    (rewritten,) = local_careduce_join.transform(fgraph, column_sums.owner)
    expected = add(pt_sum(top, axis=0), pt_sum(bottom, axis=0))
    assert equal_computations([rewritten], [expected])

    # Numerical check: the inputs differ in length along the join axis.
    top_val = np.array([[1, 2], [3, 4]], dtype=config.floatX)
    bottom_val = np.array([[5, 6]], dtype=config.floatX)
    compiled = function([top, bottom], column_sums, mode=self.mode)
    reference = np.sum(np.concatenate([top_val, bottom_val]), axis=0)
    np.testing.assert_allclose(compiled(top_val, bottom_val), reference)

def test_careduce_join_not_applied_axis_excludes_join(self):
"""Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)"""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a follow-up we can still optimize this case: it's generally better to reduce before joining, even if the join is still needed. We would just need to change the axis. My comment here is about making the docstring less authoritative, so that it doesn't sound like rewriting this case would be a problem. Mention it as not currently supported instead.

x, y = matrices("xy")
out = pt_sum(pt.concatenate([x, y], axis=0), axis=1)
f = function([x, y], out, mode=self.mode)
topo = f.maker.fgraph.toposort()
assert any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_not_applied_empty_axis(self):
    """With an empty reduction (axis=[]) the Join must remain in the compiled graph."""
    a, b = vectors("xy")
    unreduced = pt_sum(pt.concatenate([a, b], axis=0), axis=[])
    compiled = function([a, b], unreduced, mode=self.mode)

    # The rewrite must not have fired, so a Join node is still present.
    join_nodes = [
        apply_node
        for apply_node in compiled.maker.fgraph.toposort()
        if isinstance(apply_node.op, Join)
    ]
    assert join_nodes

def test_careduce_join_not_applied_dynamic_axis(self):
    """With a symbolic (non-constant) join axis the Join must remain in the compiled graph."""
    join_axis = iscalar("axis")
    a, b = vectors("xy")
    total = pt_sum(join(join_axis, a, b), axis=None)
    compiled = function([join_axis, a, b], total, mode=self.mode)

    # The rewrite must not have fired, so a Join node is still present.
    join_nodes = [
        apply_node
        for apply_node in compiled.maker.fgraph.toposort()
        if isinstance(apply_node.op, Join)
    ]
    assert join_nodes


def test_local_useless_adds():
default_mode = get_default_mode()
Expand Down
Loading