@register_specialize
@register_canonicalize
@register_uncanonicalize  # Needed for Min which is formed from Neg(Max(Neg))
@node_rewriter([CAReduce])
def local_careduce_join(fgraph, node):
    r"""Reduce every input of a join separately instead of reducing the join.

    CAReduce(Join(axis, \*tensors), axis=ax) -> op(op(CAReduce(t1), CAReduce(t2)), CAReduce(t3)) ...

    where ``op`` is the ``Elemwise`` built from the reduction's scalar op and
    each ``CAReduce`` is the original reduction op applied to one join input.

    When the reduction axes include the join axis (or the reduction runs over
    all axes), this avoids creating the concatenated intermediate array.

    The partial results are combined with a left fold of binary applications.
    For >2 joined inputs with binary-only ops (e.g. Maximum, Minimum) this
    nesting is required since Elemwise{nin > 2} is not currently supported
    for those ops.

    The rewrite is skipped (``None`` is returned) when:

    * the reduced variable is not the output of a ``Join``;
    * the join has fewer than two tensor inputs;
    * the join axis is not a ``Constant``;
    * the reduction has explicit axes that do not contain the join axis
      (this also covers an empty axis tuple);
    * the dtype of the rewritten output differs from the original one.

    Parameters
    ----------
    fgraph
        The graph being rewritten (not inspected by this rewrite).
    node
        The ``CAReduce`` node to rewrite.

    Returns
    -------
    list of Variable or None
        A single-element list with the replacement output, or ``None`` when
        the rewrite does not apply.

    """
    [joined_out] = node.inputs
    if not isinstance(joined_out.owner_op, Join):
        return None

    join_axis_tensor, *joined_inputs = joined_out.owner.inputs

    # Nothing to split with a single input, and a symbolic join axis cannot
    # be compared against the reduction axes.
    if len(joined_inputs) < 2 or not isinstance(join_axis_tensor, Constant):
        return None

    # Normalize a negative join axis so it is comparable with the reduce axes.
    join_axis = int(join_axis_tensor.data)
    if join_axis < 0:
        join_axis += joined_out.type.ndim

    reduce_op = node.op
    if reduce_op.axis is not None and join_axis not in reduce_op.axis:
        return None

    # NOTE(review): for reductions that have no value on empty input
    # (presumably Max/Min), a join input that is empty along the reduced axes
    # may now fail at runtime even though the concatenation is non-empty --
    # confirm whether a guard is needed.
    partial_results = [reduce_op(inp) for inp in joined_inputs]

    # The combining op is the same for every step of the fold: build it once
    # instead of once per pair.
    combine = Elemwise(reduce_op.scalar_op)
    ret = reduce(combine, partial_results)

    # Do not replace the node if the rewrite would change the output dtype.
    if ret.dtype != node.outputs[0].dtype:
        return None

    copy_stack_trace(node.outputs[0], ret)
    return [ret]
# NOTE: the functions below are test methods of a class whose header lies
# outside this hunk; they take ``self`` and several read ``self.mode``.
# Random inputs are drawn from seeded generators so a failure is reproducible.


def test_careduce_join_list_sum_axis_none(self):
    """Sum over a list of matrices (stacked via Join) is split per input.

    Sum([a, b, c], axis=None) -> Add(Add(Sum(a[None]), Sum(b[None])), Sum(c[None]))
    """
    vx, vy, vz = matrices("xyz")
    out = pt_sum([vx, vy, vz], axis=None)

    # Structural check: apply the rewrite directly to the reduction node.
    fg = FunctionGraph([vx, vy, vz], [out], clone=False)
    [rewritten_out] = local_careduce_join.transform(fg, out.owner)
    expected_out = add(add(pt_sum(vx[None]), pt_sum(vy[None])), pt_sum(vz[None]))
    assert equal_computations([rewritten_out], [expected_out])

    # Numerical check against NumPy.
    x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX)
    y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX)
    z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX)
    f = function([vx, vy, vz], out, mode=get_mode("FAST_COMPILE"))
    np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z]))


def test_careduce_join_sum_2(self):
    """Sum(concat(a, b), axis=None) -> Add(Sum(a), Sum(b)) with 2 inputs."""
    x, y = vectors("xy")
    out = pt_sum(pt.concatenate([x, y]), axis=None)

    fg = FunctionGraph([x, y], [out], clone=False)
    [rewritten_out] = local_careduce_join.transform(fg, out.owner)
    expected_out = add(pt_sum(x), pt_sum(y))
    assert equal_computations([rewritten_out], [expected_out])

    rng = np.random.default_rng(2024)
    xv = rng.random(100).astype(config.floatX)
    yv = rng.random(200).astype(config.floatX)
    f = function([x, y], out, mode=self.mode)
    np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv])))


def test_careduce_join_sum_3(self):
    """Sum(concat(a, b, c), axis=None) with 3 inputs uses a nested Add combine."""
    x, y, z = vectors("xyz")
    out = pt_sum(pt.concatenate([x, y, z]), axis=None)

    fg = FunctionGraph([x, y, z], [out], clone=False)
    [rewritten_out] = local_careduce_join.transform(fg, out.owner)
    expected_out = add(add(pt_sum(x), pt_sum(y)), pt_sum(z))
    assert equal_computations([rewritten_out], [expected_out])

    rng = np.random.default_rng(2025)
    xv = rng.random(100).astype(config.floatX)
    yv = rng.random(150).astype(config.floatX)
    zv = rng.random(200).astype(config.floatX)
    f = function([x, y, z], out, mode=self.mode)
    np.testing.assert_allclose(f(xv, yv, zv), np.sum(np.concatenate([xv, yv, zv])))


def test_careduce_join_max_2(self):
    """Max(concat(a, b), axis=None) -> Maximum(Max(a), Max(b)) with 2 inputs."""
    x, y = vectors("xy")
    out = pt_max(pt.concatenate([x, y]), axis=None)

    fg = FunctionGraph([x, y], [out], clone=False)
    [rewritten_out] = local_careduce_join.transform(fg, out.owner)
    expected_out = maximum(pt_max(x), pt_max(y))
    assert equal_computations([rewritten_out], [expected_out])

    rng = np.random.default_rng(2026)
    xv = rng.random(100).astype(config.floatX)
    yv = rng.random(200).astype(config.floatX)
    f = function([x, y], out, mode=self.mode)
    np.testing.assert_allclose(f(xv, yv), np.max(np.concatenate([xv, yv])))


def test_careduce_join_max_3(self):
    """Max(concat(a, b, c), axis=None) uses nested binary Maximum ops."""
    x, y, z = vectors("xyz")
    out = pt_max(pt.concatenate([x, y, z]), axis=None)

    fg = FunctionGraph([x, y, z], [out], clone=False)
    [rewritten_out] = local_careduce_join.transform(fg, out.owner)
    expected_out = maximum(maximum(pt_max(x), pt_max(y)), pt_max(z))
    assert equal_computations([rewritten_out], [expected_out])

    rng = np.random.default_rng(2027)
    xv = rng.random(100).astype(config.floatX)
    yv = rng.random(150).astype(config.floatX)
    zv = rng.random(200).astype(config.floatX)
    f = function([x, y, z], out, mode=self.mode)
    np.testing.assert_allclose(f(xv, yv, zv), np.max(np.concatenate([xv, yv, zv])))


def test_careduce_join_sum_specific_axis(self):
    """Sum(concat(mat_a, mat_b), axis=0) -> Add(Sum(mat_a, axis=0), Sum(mat_b, axis=0))

    join_axis=0 is included in the reduction axes, so the rewrite applies.
    """
    x, y = matrices("xy")
    out = pt_sum(pt.concatenate([x, y], axis=0), axis=0)

    fg = FunctionGraph([x, y], [out], clone=False)
    [rewritten_out] = local_careduce_join.transform(fg, out.owner)
    expected_out = add(pt_sum(x, axis=0), pt_sum(y, axis=0))
    assert equal_computations([rewritten_out], [expected_out])

    # The inputs deliberately differ in length along the join axis.
    xv = np.array([[1, 2], [3, 4]], dtype=config.floatX)
    yv = np.array([[5, 6]], dtype=config.floatX)
    f = function([x, y], out, mode=self.mode)
    np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]), axis=0))


def test_careduce_join_not_applied_axis_excludes_join(self):
    """Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)."""
    x, y = matrices("xy")
    out = pt_sum(pt.concatenate([x, y], axis=0), axis=1)
    f = function([x, y], out, mode=self.mode)
    topo = f.maker.fgraph.toposort()
    # The Join must survive in the compiled graph.
    assert any(isinstance(n.op, Join) for n in topo)


def test_careduce_join_not_applied_empty_axis(self):
    """Sum(concat(a, b), axis=[]) should NOT trigger (empty reduction)."""
    x, y = vectors("xy")
    out = pt_sum(pt.concatenate([x, y], axis=0), axis=[])
    f = function([x, y], out, mode=self.mode)
    topo = f.maker.fgraph.toposort()
    assert any(isinstance(n.op, Join) for n in topo)


def test_careduce_join_not_applied_dynamic_axis(self):
    """Non-constant join axis should NOT trigger."""
    axis = iscalar("axis")
    x, y = vectors("xy")
    out = pt_sum(join(axis, x, y), axis=None)
    f = function([axis, x, y], out, mode=self.mode)
    topo = f.maker.fgraph.toposort()
    assert any(isinstance(n.op, Join) for n in topo)