Skip to content

Commit 8a40da4

Browse files
committed
mpi: fix allreduce iteration space
1 parent 77ffae7 commit 8a40da4

2 files changed

Lines changed: 22 additions & 4 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from devito.exceptions import CompilationError
99
from devito.finite_differences.elementary import Max, Min
1010
from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange,
11-
pull_dims, null_ispace)
11+
pull_dims)
1212
from devito.ir.equations import OpMin, OpMax, identity_mapper
1313
from devito.ir.clusters.analysis import analyze
1414
from devito.ir.clusters.cluster import Cluster, ClusterGroup
@@ -493,9 +493,9 @@ def reduction_comms(clusters):
493493
processed.append(c)
494494

495495
# Leftover reductions are placed at the very end
496-
if fifo:
497-
exprs = [Eq(dr.var, dr) for dr in fifo]
498-
processed.append(Cluster(exprs=exprs, ispace=null_ispace))
496+
for ispace, reds in groupby(fifo, key=lambda r: r.ispace):
497+
exprs = [Eq(dr.var, dr) for dr in reds]
498+
processed.append(Cluster(exprs=exprs, ispace=ispace))
499499

500500
return processed
501501

tests/test_mpi.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2014,6 +2014,24 @@ def test_merge_smart_if_within_conditional(self, mode):
20142014
for n in FindNodes(Conditional).visit(op1):
20152015
assert len(FindNodes(HaloUpdateCall).visit(n)) == 0
20162016

2017+
@pytest.mark.parallel(mode=2)
2018+
def test_allreduce_time(self, mode):
2019+
space_order = 8
2020+
nx, ny = 11, 11
2021+
2022+
grid = Grid(shape=(nx, ny))
2023+
tt = grid.time_dim
2024+
nt = 10
2025+
2026+
ux = TimeFunction(name="ux", grid=grid, time_order=1, space_order=space_order)
2027+
g = TimeFunction(name="g", grid=grid, dimensions=(tt, ), shape=(nt,))
2028+
2029+
op = Operator([Eq(ux.forward, ux + tt), Inc(g, ux)], name="Op")
2030+
assert_structure(op, ['t,x,y', 't'], 'txy')
2031+
2032+
op.apply(time_m=0, time_M=nt-1)
2033+
assert np.isclose(np.max(g.data), 4356.0)
2034+
20172035

20182036
class TestOperatorAdvanced:
20192037

0 commit comments

Comments
 (0)