-
Notifications
You must be signed in to change notification settings - Fork 186
Optimize CAReduce of Join by pushing reduction through concatenation #2130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
900489a
23220c4
dcfe621
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,7 +31,7 @@ | |
| from pytensor.graph.traversal import ancestors | ||
| from pytensor.printing import debugprint, pprint | ||
| from pytensor.scalar import PolyGamma, Psi, TriGamma | ||
| from pytensor.tensor.basic import Alloc, constant, join, second, switch | ||
| from pytensor.tensor.basic import Alloc, Join, constant, join, second, switch | ||
| from pytensor.tensor.blas import Dot22, Gemv | ||
| from pytensor.tensor.blas_c import CGemv | ||
| from pytensor.tensor.blockwise import Blockwise | ||
|
|
@@ -103,6 +103,7 @@ | |
| from pytensor.tensor.rewriting.math import ( | ||
| compute_mul, | ||
| is_1pexp, | ||
| local_careduce_join, | ||
| local_div_switch_sink, | ||
| local_grad_log_erfc_neg, | ||
| local_greedy_distributor, | ||
|
|
@@ -3500,32 +3501,33 @@ def test_type(self): | |
| topo = f.maker.fgraph.toposort() | ||
| assert not isinstance(topo[-1].op, Elemwise) | ||
|
|
||
| # This case could be rewritten | ||
| # Join axis is included in reduction axes | ||
| A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) | ||
| f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=1), mode=self.mode) | ||
| np.testing.assert_allclose(f(), [2, 4, 6, 8, 10]) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert not isinstance(topo[-1].op, Elemwise) | ||
| assert isinstance(topo[-1].op, Elemwise) | ||
|
|
||
| A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) | ||
| f = function([], pt_sum(pt.concatenate((A, A), axis=1), axis=0), mode=self.mode) | ||
| np.testing.assert_allclose(f(), [15, 15]) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert not isinstance(topo[-1].op, Elemwise) | ||
|
|
||
| def test_not_supported_axis_none(self): | ||
| # Test that the rewrite does not crash in one case where it | ||
| # is not applied. Reported at | ||
| # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion | ||
| vx = matrix() | ||
| vy = matrix() | ||
| vz = matrix() | ||
| def test_careduce_join_list_sum_axis_none(self): | ||
| """Sum([a, b, c], axis=None) -> Add(Sum(ExpandDims(a)), Sum(ExpandDims(b)), Sum(ExpandDims(c))) with list inputs (via Join)""" | ||
| vx, vy, vz = matrices("xyz") | ||
| out = pt_sum([vx, vy, vz], axis=None) | ||
|
|
||
| fg = FunctionGraph([vx, vy, vz], [out], clone=False) | ||
| [rewritten_out] = local_careduce_join.transform(fg, out.owner) | ||
| expected_out = add(add(pt_sum(vx[None]), pt_sum(vy[None])), pt_sum(vz[None])) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| x = np.asarray([[1, 0], [3, 4]], dtype=config.floatX) | ||
| y = np.asarray([[4, 0], [2, 1]], dtype=config.floatX) | ||
| z = np.asarray([[5, 0], [1, 2]], dtype=config.floatX) | ||
|
|
||
| out = pt_sum([vx, vy, vz], axis=None) | ||
| f = function([vx, vy, vz], out, mode=self.mode) | ||
| f = function([vx, vy, vz], out, mode=get_mode("FAST_COMPILE")) | ||
| np.testing.assert_allclose(f(x, y, z), np.sum([x, y, z])) | ||
|
|
||
| def test_not_supported_unequal_shapes(self): | ||
|
|
@@ -3559,6 +3561,111 @@ def test_non_ds_inputs(self): | |
| expected_out = add(exp(x), log(x)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| def test_careduce_join_sum_2(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check the kind of approach to writing rewrite tests we're trying to settle on: #2103 |
||
| """Sum(concat(a, b), axis=None) -> Add(Sum(a), Sum(b)) with 2 inputs""" | ||
| x, y = vectors("xy") | ||
| out = pt_sum(pt.concatenate([x, y]), axis=None) | ||
|
|
||
| fg = FunctionGraph([x, y], [out], clone=False) | ||
| [rewritten_out] = local_careduce_join.transform(fg, out.owner) | ||
| expected_out = add(pt_sum(x), pt_sum(y)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| xv = np.random.rand(100).astype(config.floatX) | ||
| yv = np.random.rand(200).astype(config.floatX) | ||
| f = function([x, y], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]))) | ||
|
|
||
| def test_careduce_join_sum_3(self): | ||
| """Sum(concat(a, b, c), axis=None) with 3 inputs applied via nested Add combine""" | ||
| x, y, z = vectors("xyz") | ||
| out = pt_sum(pt.concatenate([x, y, z]), axis=None) | ||
|
|
||
| fg = FunctionGraph([x, y, z], [out], clone=False) | ||
| [rewritten_out] = local_careduce_join.transform(fg, out.owner) | ||
| expected_out = add(add(pt_sum(x), pt_sum(y)), pt_sum(z)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| xv = np.random.rand(100).astype(config.floatX) | ||
| yv = np.random.rand(150).astype(config.floatX) | ||
| zv = np.random.rand(200).astype(config.floatX) | ||
| f = function([x, y, z], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv, zv), np.sum(np.concatenate([xv, yv, zv]))) | ||
|
|
||
| def test_careduce_join_max_2(self): | ||
| """Max(concat(a, b), axis=None) -> Maximum(Max(a), Max(b)) with 2 inputs (binary combine)""" | ||
| x, y = vectors("xy") | ||
| out = pt_max(pt.concatenate([x, y]), axis=None) | ||
|
|
||
| fg = FunctionGraph([x, y], [out], clone=False) | ||
| [rewritten_out] = local_careduce_join.transform(fg, out.owner) | ||
| expected_out = maximum(pt_max(x), pt_max(y)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| xv = np.random.rand(100).astype(config.floatX) | ||
| yv = np.random.rand(200).astype(config.floatX) | ||
| f = function([x, y], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv), np.max(np.concatenate([xv, yv]))) | ||
|
|
||
| def test_careduce_join_max_3(self): | ||
| """Max(concat(a, b, c), axis=None) applied via nested binary Maximum ops""" | ||
| x, y, z = vectors("xyz") | ||
| out = pt_max(pt.concatenate([x, y, z]), axis=None) | ||
|
|
||
| fg = FunctionGraph([x, y, z], [out], clone=False) | ||
| [rewritten_out] = local_careduce_join.transform(fg, out.owner) | ||
| expected_out = maximum(maximum(pt_max(x), pt_max(y)), pt_max(z)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| xv = np.random.rand(100).astype(config.floatX) | ||
| yv = np.random.rand(150).astype(config.floatX) | ||
| zv = np.random.rand(200).astype(config.floatX) | ||
| f = function([x, y, z], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv, zv), np.max(np.concatenate([xv, yv, zv]))) | ||
|
|
||
| def test_careduce_join_sum_specific_axis(self): | ||
| """Sum(concat(mat_a, mat_b), axis=0) -> Add(Sum(mat_a, axis=0), Sum(mat_b, axis=0)) | ||
|
|
||
| join_axis=0 is included in the reduction axes, so the rewrite applies. | ||
| """ | ||
| x, y = matrices("xy") | ||
| out = pt_sum(pt.concatenate([x, y], axis=0), axis=0) | ||
|
|
||
| fg = FunctionGraph([x, y], [out], clone=False) | ||
| [rewritten_out] = local_careduce_join.transform(fg, out.owner) | ||
| expected_out = add(pt_sum(x, axis=0), pt_sum(y, axis=0)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| xv = np.array([[1, 2], [3, 4]], dtype=config.floatX) | ||
| yv = np.array([[5, 6]], dtype=config.floatX) | ||
| f = function([x, y], out, mode=self.mode) | ||
| np.testing.assert_allclose(f(xv, yv), np.sum(np.concatenate([xv, yv]), axis=0)) | ||
|
|
||
| def test_careduce_join_not_applied_axis_excludes_join(self): | ||
| """Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)""" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for follow up we can still optimize, it's still generally better to reduce before joining even if the join is still needed. Just need to change the axis then. My comment here is to make the docstring not so authoritative that sounds like this would be a problem. Mention it as not currently supported instead |
||
| x, y = matrices("xy") | ||
| out = pt_sum(pt.concatenate([x, y], axis=0), axis=1) | ||
| f = function([x, y], out, mode=self.mode) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_not_applied_empty_axis(self): | ||
| """Sum(concat(a, b), axis=[]) should NOT trigger (empty reduction)""" | ||
| x, y = vectors("xy") | ||
| out = pt_sum(pt.concatenate([x, y], axis=0), axis=[]) | ||
| f = function([x, y], out, mode=self.mode) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_not_applied_dynamic_axis(self): | ||
| """Non-constant join axis should NOT trigger""" | ||
| axis = iscalar("axis") | ||
| x, y = vectors("xy") | ||
| out = pt_sum(join(axis, x, y), axis=None) | ||
| f = function([axis, x, y], out, mode=self.mode) | ||
| topo = f.maker.fgraph.toposort() | ||
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
|
|
||
| def test_local_useless_adds(): | ||
| default_mode = get_default_mode() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are missing CAReduce(Dimshuffle(x)) -> CAReduce(x), when the DimShuffle has no effect due to reduction (or DimShuffle may still be needed but only a subset of its behavior). In this case it isn't needed.
Does not need to be a blocker for this PR, but we should open an issue. I thought there was one already