Skip to content

Commit 9f74d37

Browse files
committed
ci: add test for cond multi all-reduce
1 parent ecc0a0c commit 9f74d37

1 file changed

Lines changed: 33 additions & 0 deletions

File tree

tests/test_mpi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,39 @@ def test_multi_allreduce_time(self, mode):
20682068
assert np.isclose(np.max(g.data), 4356.0)
20692069
assert np.isclose(np.max(h.data), 4356.0)
20702070

2071+
@pytest.mark.parallel(mode=1)
2072+
def test_multi_allreduce_time_cond(self, mode):
2073+
space_order = 8
2074+
nx, ny = 11, 11
2075+
2076+
grid = Grid(shape=(nx, ny))
2077+
tt = grid.time_dim
2078+
nt = 20
2079+
ct = ConditionalDimension(name="ct", parent=tt, factor=2)
2080+
2081+
ux = TimeFunction(name="ux", grid=grid, time_order=1, space_order=space_order)
2082+
g = TimeFunction(name="g", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2083+
time_dim=ct)
2084+
h = TimeFunction(name="h", grid=grid, dimensions=(ct, ), shape=(int(nt/2),),
2085+
time_dim=ct)
2086+
2087+
op = Operator([Eq(g, 0), Eq(ux.forward, tt), Inc(g, ux), Inc(h, ux)], name="Op")
2088+
assert_structure(op, ['t', 't,x,y', 't,x,y'], 'txyxy')
2089+
2090+
# Make sure the two allreduce calls are in the time the loop
2091+
iters = FindNodes(Iteration).visit(op)
2092+
for i in iters:
2093+
if i.dim.is_Time:
2094+
assert len(FindNodes(Call).visit(i)) == 2 # Two allreduce
2095+
else:
2096+
assert len(FindNodes(Call).visit(i)) == 0
2097+
2098+
op.apply(time_m=0, time_M=nt-1)
2099+
2100+
expected = [nx * ny * max(t-1, 0) for t in range(0, nt, 2)]
2101+
assert np.allclose(g.data, expected)
2102+
assert np.allclose(h.data, expected)
2103+
20712104

20722105
class TestOperatorAdvanced:
20732106

0 commit comments

Comments
 (0)