From 900489aacd98489542fff2ca529b74ed55b22e90 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 13 May 2026 15:12:26 -0400 Subject: [PATCH 1/3] Add local_careduce_join rewrite for CAReduce of Join with axis=None or join_axis in reduce_axis --- pytensor/tensor/rewriting/math.py | 48 +++++++++++++++ tests/tensor/rewriting/test_math.py | 95 ++++++++++++++++++++++++++++- 2 files changed, 140 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 937fb068d5..b85fe63b93 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1964,6 +1964,54 @@ def local_reduce_join(fgraph, node): return [ret] +@register_specialize +@register_canonicalize +@register_uncanonicalize # Needed for Min which is formed from Neg(Max(Neg)) +@node_rewriter([CAReduce]) +def local_careduce_join(fgraph, node): + r"""CAReduce(Join(axis, \*tensors), axis=ax) -> Elemwise{scalar.op}(\*[CAReduce(axis=ax, t) for t in tensors]) + + When the reduction axis includes the join axis (or reduces all elements), + this avoids creating the concatenated intermediate array. + + For >2 joined inputs, only scalar ops with variadic support + (Add, Mul) are rewritten, since Elemwise can't combine >2 + binary-only ops (e.g. Maximum, Minimum) at once. + + """ + [joined_out] = node.inputs + if joined_out.owner is None or not isinstance(joined_out.owner.op, Join): + return None + + join_axis_tensor, *joined_inputs = joined_out.owner.inputs + + if len(joined_inputs) < 2: + return None + + if not isinstance(join_axis_tensor, Constant): + return None + + if len(joined_inputs) > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul): + return None + + join_axis = int(join_axis_tensor.data) + if join_axis < 0: + join_axis += joined_out.type.ndim + + reduce_op = node.op + if reduce_op.axis is not None and join_axis not in reduce_op.axis: + return None + + reduced = [reduce_op.clone(axis=reduce_op.axis)(inp) for inp in joined_inputs] + ret = Elemwise(reduce_op.scalar_op)(*reduced) + + if ret.dtype != node.outputs[0].dtype: + return None + + copy_stack_trace(node.outputs[0], ret) + return [ret] + + @register_infer_shape @register_canonicalize("fast_compile", "local_cut_useless_reduce") @register_useless("local_cut_useless_reduce") diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 106862c641..d6bb3417cc 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -31,7 +31,7 @@ from pytensor.graph.traversal import ancestors from pytensor.printing import debugprint, pprint from pytensor.scalar import PolyGamma, Psi, TriGamma -from pytensor.tensor.basic import Alloc, constant, join, second, switch +from pytensor.tensor.basic import Alloc, Join, constant, join, second, switch from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blockwise import Blockwise @@ -3500,12 +3500,12 @@ def test_type(self): topo = f.maker.fgraph.toposort() assert not isinstance(topo[-1].op, Elemwise) - # This case could be rewritten + # Join axis is included in reduction axes A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode) np.testing.assert_allclose(f(), [2, 4, 6, 8, 10]) topo = f.maker.fgraph.toposort() - assert not isinstance(topo[-1].op, Elemwise) + assert isinstance(topo[-1].op, Elemwise) A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode) @@ -3559,6 +3559,95 @@ def test_non_ds_inputs(self): expected_out = add(exp(x), log(x)) assert equal_computations([rewritten_out], [expected_out]) + def test_careduce_join_sum_2(self): + """Sum(concat(a, b), axis=None) -> Add(Sum(a), Sum(b)) with 2 inputs""" + x, y = vectors("xy") + xv, yv = ( + np.random.rand(100).astype(config.floatX), + np.random.rand(200).astype(config.floatX), + ) + out = pt_sum(pt.concatenate([x, y]), axis=None) + f = function([x, y], out, mode=self.mode) + np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]))) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_sum_3(self): + """Sum(concat(a, b, c), axis=None) -> Add(Sum(a), Sum(b), Sum(c)) with 3 inputs (variadic combine)""" + x, y, z = vectors("xyz") + xv = np.random.rand(100).astype(config.floatX) + yv = np.random.rand(150).astype(config.floatX) + zv = np.random.rand(200).astype(config.floatX) + out = pt_sum(pt.concatenate([x, y, z]), axis=None) + f = function([x, y, z], out, mode=self.mode) + np.testing.assert_allclose(f(xv, yv, zv), np.sum(np.concatenate([xv, yv, zv]))) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_max_2(self): + """Max(concat(a, b), axis=None) -> Maximum(Max(a), Max(b)) with 2 inputs (binary combine)""" + x, y = vectors("xy") + xv, yv = ( + np.random.rand(100).astype(config.floatX), + np.random.rand(200).astype(config.floatX), + ) + out = pt_max(pt.concatenate([x, y]), axis=None) + f = function([x, y], out, mode=self.mode) + np.testing.assert_allclose(f(xv, yv), np.max(np.concatenate([xv, yv]))) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_sum_specific_axis(self): + """Sum(concat(mat_a, mat_b), axis=0) -> Add(Sum(mat_a, axis=0), Sum(mat_b, axis=0)) + + join_axis=0 is included in the reduction axes, so the rewrite applies. + """ + x, y = matrices("xy") + xv = np.array([[1, 2], [3, 4]], dtype=config.floatX) + yv = np.array([[5, 6]], dtype=config.floatX) + out = pt_sum(pt.concatenate([x, y], axis=0), axis=0) + f = function([x, y], out, mode=self.mode) + np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]), axis=0)) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_max_3_not_applied(self): + """Max(concat(a, b, c), axis=None) should NOT trigger (binary-only combine can't take >2) + + Elemwise{maximum} is a binary op with nin=2, so it can't combine + three individual Max results at once like Add/Mul can. + """ + x, y, z = vectors("xyz") + out = pt_max(pt.concatenate([x, y, z]), axis=None) + f = function([x, y, z], out, mode=self.mode) + topo = f.maker.fgraph.toposort() + assert any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_not_applied_axis_excludes_join(self): + """Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)""" + x, y = matrices("xy") + out = pt_sum(pt.concatenate([x, y], axis=0), axis=1) + f = function([x, y], out, mode=self.mode) + topo = f.maker.fgraph.toposort() + assert any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_not_applied_empty_axis(self): + """Sum(concat(a, b), axis=[]) should NOT trigger (empty reduction)""" + x, y = vectors("xy") + out = pt_sum(pt.concatenate([x, y], axis=0), axis=[]) + f = function([x, y], out, mode=self.mode) + topo = f.maker.fgraph.toposort() + assert any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_not_applied_dynamic_axis(self): + """Non-constant join axis should NOT trigger""" + axis = iscalar("axis") + x, y = vectors("xy") + out = pt_sum(join(axis, x, y), axis=None) + f = function([axis, x, y], out, mode=self.mode) + topo = f.maker.fgraph.toposort() + assert any(isinstance(n.op, Join) for n in topo) + def test_local_useless_adds(): default_mode = get_default_mode() From 23220c4d1a0ffac5e80b266b5f3a930dc612382c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 13 May 2026 16:05:00 -0400 Subject: [PATCH 2/3] Address Ricardo's review: use owner_op, functools.reduce, drop clone, update tests --- pytensor/tensor/rewriting/math.py | 21 +++--- tests/tensor/rewriting/test_math.py | 102 +++++++++++++++++----------- 2 files changed, 71 insertions(+), 52 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index b85fe63b93..ec94ed486b 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1969,18 +1969,19 @@ def local_reduce_join(fgraph, node): @register_uncanonicalize # Needed for Min which is formed from Neg(Max(Neg)) @node_rewriter([CAReduce]) def local_careduce_join(fgraph, node): - r"""CAReduce(Join(axis, \*tensors), axis=ax) -> Elemwise{scalar.op}(\*[CAReduce(axis=ax, t) for t in tensors]) + r"""CAReduce(Join(axis, \*tensors), axis=ax) -> CAReduce(..t1), .. combined via scalar.op When the reduction axis includes the join axis (or reduces all elements), - this avoids creating the concatenated intermediate array. + this avoids creating the concatenated intermediate array by reducing each + join input separately and combining results with the scalar op. - For >2 joined inputs, only scalar ops with variadic support - (Add, Mul) are rewritten, since Elemwise can't combine >2 - binary-only ops (e.g. Maximum, Minimum) at once. + For >2 joined inputs with binary-only ops (e.g. Maximum, Minimum), the + combine step uses nested applications since Elemwise{nin > 2} is not + currently supported for those ops. """ [joined_out] = node.inputs - if joined_out.owner is None or not isinstance(joined_out.owner.op, Join): + if not isinstance(joined_out.owner_op, Join): return None join_axis_tensor, *joined_inputs = joined_out.owner.inputs @@ -1991,9 +1992,6 @@ def local_careduce_join(fgraph, node): if not isinstance(join_axis_tensor, Constant): return None - if len(joined_inputs) > 2 and not isinstance(node.op.scalar_op, ps.Add | ps.Mul): - return None - join_axis = int(join_axis_tensor.data) if join_axis < 0: join_axis += joined_out.type.ndim @@ -2002,8 +2000,9 @@ def local_careduce_join(fgraph, node): if reduce_op.axis is not None and join_axis not in reduce_op.axis: return None - reduced = [reduce_op.clone(axis=reduce_op.axis)(inp) for inp in joined_inputs] - ret = Elemwise(reduce_op.scalar_op)(*reduced) + reduced = [reduce_op(inp) for inp in joined_inputs] + scalar_op = reduce_op.scalar_op + ret = reduce(lambda a, b: Elemwise(scalar_op)(a, b), reduced) if ret.dtype != node.outputs[0].dtype: return None diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index d6bb3417cc..f0357fd063 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -103,6 +103,7 @@ from pytensor.tensor.rewriting.math import ( compute_mul, is_1pexp, + local_careduce_join, local_div_switch_sink, local_grad_log_erfc_neg, local_greedy_distributor, @@ -3513,19 +3514,22 @@ def test_type(self): topo = f.maker.fgraph.toposort() assert not isinstance(topo[-1].op, Elemwise) - def test_not_supported_axis_none(self): - # Test that the rewrite does not crash in one case where it - # is not applied. Reported at - # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion - vx = matrix() - vy = matrix() - vz = matrix() + def test_careduce_join_list_sum_axis_none(self): + """Sum([a, b, c], axis=None) -> Add(Sum(ExpandDims(a)), Sum(ExpandDims(b)), Sum(ExpandDims(c))) with list inputs (via Join)""" + vx, vy, vz = matrices("xyz") + out = pt_sum([vx, vy, vz], axis=None) + + fg = FunctionGraph([vx, vy, vz], [out], clone=False) + [rewritten_out] = local_careduce_join.transform(fg, out.owner) + expected_out = add(add(pt_sum(vx[None]), pt_sum(vy[None])), pt_sum(vz[None])) + assert equal_computations([rewritten_out], [expected_out]) + x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX) y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX) z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX) - - out = pt_sum([vx, vy, vz], axis=None) - f = function([vx, vy, vz], out, mode=self.mode) + # Use py linker to avoid a pre-existing Numba bug with Sum(ExpandDims(...)) + py_mode = self.mode.__class__("py", self.mode.optimizer) + f = function([vx, vy, vz], out, mode=py_mode) np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z])) def test_not_supported_unequal_shapes(self): @@ -3562,40 +3566,64 @@ def test_non_ds_inputs(self): def test_careduce_join_sum_2(self): """Sum(concat(a, b), axis=None) -> Add(Sum(a), Sum(b)) with 2 inputs""" x, y = vectors("xy") - xv, yv = ( - np.random.rand(100).astype(config.floatX), - np.random.rand(200).astype(config.floatX), - ) out = pt_sum(pt.concatenate([x, y]), axis=None) + + fg = FunctionGraph([x, y], [out], clone=False) + [rewritten_out] = local_careduce_join.transform(fg, out.owner) + expected_out = add(pt_sum(x), pt_sum(y)) + assert equal_computations([rewritten_out], [expected_out]) + + xv = np.random.rand(100).astype(config.floatX) + yv = np.random.rand(200).astype(config.floatX) f = function([x, y], out, mode=self.mode) np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]))) - topo = f.maker.fgraph.toposort() - assert not any(isinstance(n.op, Join) for n in topo) def test_careduce_join_sum_3(self): - """Sum(concat(a, b, c), axis=None) -> Add(Sum(a), Sum(b), Sum(c)) with 3 inputs (variadic combine)""" + """Sum(concat(a, b, c), axis=None) with 3 inputs applied via nested Add combine""" x, y, z = vectors("xyz") + out = pt_sum(pt.concatenate([x, y, z]), axis=None) + + fg = FunctionGraph([x, y, z], [out], clone=False) + [rewritten_out] = local_careduce_join.transform(fg, out.owner) + expected_out = add(add(pt_sum(x), pt_sum(y)), pt_sum(z)) + assert equal_computations([rewritten_out], [expected_out]) + xv = np.random.rand(100).astype(config.floatX) yv = np.random.rand(150).astype(config.floatX) zv = np.random.rand(200).astype(config.floatX) - out = pt_sum(pt.concatenate([x, y, z]), axis=None) f = function([x, y, z], out, mode=self.mode) np.testing.assert_allclose(f(xv, yv, zv), np.sum(np.concatenate([xv, yv, zv]))) - topo = f.maker.fgraph.toposort() - assert not any(isinstance(n.op, Join) for n in topo) def test_careduce_join_max_2(self): """Max(concat(a, b), axis=None) -> Maximum(Max(a), Max(b)) with 2 inputs (binary combine)""" x, y = vectors("xy") - xv, yv = ( - np.random.rand(100).astype(config.floatX), - np.random.rand(200).astype(config.floatX), - ) out = pt_max(pt.concatenate([x, y]), axis=None) + + fg = FunctionGraph([x, y], [out], clone=False) + [rewritten_out] = local_careduce_join.transform(fg, out.owner) + expected_out = maximum(pt_max(x), pt_max(y)) + assert equal_computations([rewritten_out], [expected_out]) + + xv = np.random.rand(100).astype(config.floatX) + yv = np.random.rand(200).astype(config.floatX) f = function([x, y], out, mode=self.mode) np.testing.assert_allclose(f(xv, yv), np.max(np.concatenate([xv, yv]))) - topo = f.maker.fgraph.toposort() - assert not any(isinstance(n.op, Join) for n in topo) + + def test_careduce_join_max_3(self): + """Max(concat(a, b, c), axis=None) applied via nested binary Maximum ops""" + x, y, z = vectors("xyz") + out = pt_max(pt.concatenate([x, y, z]), axis=None) + + fg = FunctionGraph([x, y, z], [out], clone=False) + [rewritten_out] = local_careduce_join.transform(fg, out.owner) + expected_out = maximum(maximum(pt_max(x), pt_max(y)), pt_max(z)) + assert equal_computations([rewritten_out], [expected_out]) + + xv = np.random.rand(100).astype(config.floatX) + yv = np.random.rand(150).astype(config.floatX) + zv = np.random.rand(200).astype(config.floatX) + f = function([x, y, z], out, mode=self.mode) + np.testing.assert_allclose(f(xv, yv, zv), np.max(np.concatenate([xv, yv, zv]))) def test_careduce_join_sum_specific_axis(self): """Sum(concat(mat_a, mat_b), axis=0) -> Add(Sum(mat_a, axis=0), Sum(mat_b, axis=0)) @@ -3603,25 +3631,17 @@ def test_careduce_join_sum_specific_axis(self): join_axis=0 is included in the reduction axes, so the rewrite applies. """ x, y = matrices("xy") + out = pt_sum(pt.concatenate([x, y], axis=0), axis=0) + + fg = FunctionGraph([x, y], [out], clone=False) + [rewritten_out] = local_careduce_join.transform(fg, out.owner) + expected_out = add(pt_sum(x, axis=0), pt_sum(y, axis=0)) + assert equal_computations([rewritten_out], [expected_out]) + xv = np.array([[1, 2], [3, 4]], dtype=config.floatX) yv = np.array([[5, 6]], dtype=config.floatX) - out = pt_sum(pt.concatenate([x, y], axis=0), axis=0) f = function([x, y], out, mode=self.mode) np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]), axis=0)) - topo = f.maker.fgraph.toposort() - assert not any(isinstance(n.op, Join) for n in topo) - - def test_careduce_join_max_3_not_applied(self): - """Max(concat(a, b, c), axis=None) should NOT trigger (binary-only combine can't take >2) - - Elemwise{maximum} is a binary op with nin=2, so it can't combine - three individual Max results at once like Add/Mul can. - """ - x, y, z = vectors("xyz") - out = pt_max(pt.concatenate([x, y, z]), axis=None) - f = function([x, y, z], out, mode=self.mode) - topo = f.maker.fgraph.toposort() - assert any(isinstance(n.op, Join) for n in topo) def test_careduce_join_not_applied_axis_excludes_join(self): """Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)""" From dcfe621bdb6ffa7252960dcfbd868453f16eb066 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 13 May 2026 16:30:30 -0400 Subject: [PATCH 3/3] Use FAST_COMPILE for numerical check to avoid Numba backend bug --- tests/tensor/rewriting/test_math.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index f0357fd063..f579ca4cbc 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -3527,9 +3527,7 @@ def test_careduce_join_list_sum_axis_none(self): x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX) y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX) z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX) - # Use py linker to avoid a pre-existing Numba bug with Sum(ExpandDims(...)) - py_mode = self.mode.__class__("py", self.mode.optimizer) - f = function([vx, vy, vz], out, mode=py_mode) + f = function([vx, vy, vz], out, mode=get_mode("FAST_COMPILE")) np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z])) def test_not_supported_unequal_shapes(self):