Skip to content

Commit 300f840

Browse files
committed
compiler: Improve halo-exchange hoisting
1 parent 315131d commit 300f840

3 files changed

Lines changed: 154 additions & 77 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def __hash__(self):
4343
return hash((self.loc_indices, self.loc_dirs, self.halos, self.dims,
4444
self.bundle))
4545

46+
@property
47+
def loc_values(self):
48+
return frozenset(self.loc_indices.values())
49+
4650
def union(self, other):
4751
"""
4852
Return a new HaloSchemeEntry that is the union of this and `other`.
@@ -413,6 +417,37 @@ def loc_values(self):
413417
def arguments(self):
414418
return self.dimensions | set(flatten(self.honored.values()))
415419

420+
def issubset(self, other):
421+
"""
422+
Check if `self` is a subset of `other`.
423+
"""
424+
if not isinstance(other, HaloScheme):
425+
return False
426+
427+
if not all(f in other.fmapper for f in self.fmapper):
428+
return False
429+
430+
for f, hse0 in self.fmapper.items():
431+
hse1 = other.fmapper[f]
432+
433+
# Clearly, `hse0`'s halos must be a subset of `hse1`'s halos...
434+
if not hse0.halos.issubset(hse1.halos) or \
435+
hse0.bundle is not hse1.bundle:
436+
return False
437+
438+
# But now, to be a subset, `hse0`'s must be expecting such halos
439+
# at a time index that is less than or equal to that of `hse1`
440+
if hse0.loc_dirs != hse1.loc_dirs:
441+
return False
442+
443+
loc_dirs = hse0.loc_dirs
444+
loc_indices = {**hse0.loc_indices, **hse1.loc_indices}
445+
projected_loc_indices, _ = process_loc_indices(loc_indices, loc_dirs)
446+
if projected_loc_indices != hse1.loc_indices:
447+
return False
448+
449+
return True
450+
416451
def project(self, functions):
417452
"""
418453
Create a new HaloScheme that only retains the HaloSchemeEntries corresponding

devito/passes/iet/mpi.py

Lines changed: 87 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -56,73 +56,81 @@ def _drop_reduction_halospots(iet):
5656

5757
def _hoist_invariant(iet):
5858
"""
59-
Hoist HaloSpots from inner to outer Iterations where all data dependencies
60-
would be honored. This pass avoids redundant halo exchanges when the same
61-
data is redundantly exchanged within the same Iteration tree level.
59+
Hoist redundant HaloSpots out of Iterations.
6260
6361
Example:
64-
haloupd v[t0]
65-
for time for time
66-
W v[t1]- R v[t0] W v[t1]- R v[t0]
67-
haloupd v[t1] haloupd v[t1]
68-
R v[t1] R v[t1]
69-
haloupd v[t0] R v[t0]
70-
R v[t0]
7162
63+
haloupd v[t0]
64+
haloupd v[t0]'
65+
for time for time
66+
haloupd v[t0]
67+
W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
68+
haloupd v[t1] haloupd v[t1]
69+
R v[t1] R v[t1]
70+
haloupd v[t0]'
71+
R v[t0] R v[t0]
72+
73+
Where `haloupd v[t0]` and `haloupd v[t0]'` are subsets of `haloupd v[t1]`.
7274
"""
73-
74-
# Precompute scopes to save time
75-
scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()}
76-
77-
# Analysis
78-
hsmapper = {}
79-
imapper = defaultdict(list)
80-
8175
cond_mapper = _make_cond_mapper(iet)
8276
iter_mapper = _filter_iter_mapper(iet)
8377

78+
hsmapper = defaultdict(list)
79+
imapper = defaultdict(list)
8480
for it, halo_spots in iter_mapper.items():
81+
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
82+
8583
for hs0, hs1 in combinations(halo_spots, r=2):
8684
if _check_control_flow(hs0, hs1, cond_mapper):
8785
continue
8886

89-
# If there are overlapping loc_indices, skip
90-
hs0_mdims = hs0.halo_scheme.loc_values
91-
hs1_mdims = hs1.halo_scheme.loc_values
92-
if hs0_mdims.intersection(hs1_mdims):
87+
functions = set(hs0.fmapper) & set(hs1.fmapper)
88+
if not functions:
89+
continue
90+
functions = sorted(functions, key=lambda f: f.name) # Determinism
91+
92+
hsf0 = hs0.halo_scheme.project(functions)
93+
hsf1 = hs1.halo_scheme.project(functions)
94+
95+
# NOTE: lexicographic order is important here, as if both `hsf0` and
96+
# `hsf1` are iteration-carried, we must hoist the later one in the
97+
# control flow (that is `hsf1`)
98+
flag0 = hsf1.issubset(hsf0) and _is_iter_carried(hsf1, scope)
99+
flag1 = hsf0.issubset(hsf1) and _is_iter_carried(hsf0, scope)
100+
if flag0 and flag1:
101+
# Special case: `_merge_halospots` will merge `hsf0` and `hsf1`
102+
# anyway, so we can avoid a useless hoisting
103+
continue
104+
elif flag0:
105+
hs, hsf = hs1, hsf1
106+
elif flag1:
107+
hs, hsf = hs0, hsf0
108+
else:
109+
# No hoisting possible, skip
93110
continue
94111

95-
for f, v in hs1.fmapper.items():
96-
if f not in hs0.functions:
97-
continue
112+
# At this point, we must infer valid loc_indices for `found`
113+
hhs = hsf.drop(functions) # hhs -> hoisted HaloScheme
114+
for f, hse in hsf.fmapper.items():
115+
loc_indices = {}
116+
for d, v in hse.loc_indices.items():
117+
if v in it.uindices:
118+
loc_indices[d] = v.symbolic_min.subs(it.dim, it.start)
119+
else:
120+
loc_indices[d] = v
121+
122+
hhs = hhs.add(f, hse._rebuild(loc_indices=loc_indices))
123+
124+
hsmapper[hs].extend(functions)
125+
imapper[it].append(hhs)
98126

99-
for dep in scopes[it].d_flow.project(f):
100-
if not any(r(dep, hs1, v.loc_indices) for r in rules):
101-
break
102-
else:
103-
# `hs1`` can be hoisted out of `it`, but we need to infer valid
104-
# loc_indices
105-
hse = hs1.halo_scheme.fmapper[f]
106-
loc_indices = {}
107-
108-
for d, v in hse.loc_indices.items():
109-
if v in it.uindices:
110-
loc_indices[d] = v.symbolic_min.subs(it.dim, it.start)
111-
else:
112-
loc_indices[d] = v
113-
114-
hse = hse._rebuild(loc_indices=loc_indices)
115-
hs1.halo_scheme.fmapper[f] = hse
116-
117-
hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
118-
imapper[it].append(hs1.halo_scheme.project(f))
119-
120-
mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
121-
for i, hss in imapper.items()}
122-
mapper.update({i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
123-
for i, hs in hsmapper.items()})
124127
# Transform the IET hoisting/dropping HaloSpots as according to the analysis
128+
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(functions))
129+
for hs, functions in hsmapper.items()}
130+
mapper.update({i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
131+
for i, hss in imapper.items()})
125132
iet = Transformer(mapper, nested=True).visit(iet)
133+
126134
return iet
127135

128136

@@ -134,18 +142,16 @@ def _merge_halospots(iet):
134142
135143
Example:
136144
137-
for time for time
138-
haloupd v[t0] haloupd v[t0], h
139-
W v[t1]- R v[t0] W v[t1]- R v[t0]
145+
for time for time
146+
haloupd v[t0] haloupd v[t0], h
147+
W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
140148
haloupd v[t0], h
141-
W g[t1]- R v[t0], h W g[t1]- R v[t0], h
149+
W g[t1]- R v[t0], h W g[t1]- R v[t0], h
142150
"""
143-
144-
# Analysis
145-
mapper = {}
146151
cond_mapper = _make_cond_mapper(iet)
147152
iter_mapper = _filter_iter_mapper(iet)
148153

154+
mapper = {}
149155
for it, halo_spots in iter_mapper.items():
150156
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
151157

@@ -157,19 +163,19 @@ def _merge_halospots(iet):
157163

158164
for f, v in hs1.fmapper.items():
159165
for dep in scope.d_flow.project(f):
160-
if not any(rule(dep, hs1, v.loc_indices) for rule in rules):
166+
if not any(r(dep, hs1, v.loc_indices) for r in rules):
161167
break
162168
else:
163-
# All good -- `hs1` can be merged with `hs0`
164-
hs = hs1.halo_scheme.project(f)
165-
mapper[hs0] = HaloScheme.union([mapper.get(hs0, hs0.halo_scheme), hs])
169+
# All good -- `hsf1` can be merged within `hs0`
170+
hsf1 = hs1.halo_scheme.project(f)
171+
mapper[hs0] = HaloScheme.union(
172+
[mapper.get(hs0, hs0.halo_scheme), hsf1]
173+
)
166174
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)
167175

168-
# Post-process analysis
176+
# Transform the IET merging/dropping HaloSpots as according to the analysis
169177
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
170178
for i, hs in mapper.items()}
171-
172-
# Transform the IET merging/dropping HaloSpots as according to the analysis
173179
iet = Transformer(mapper, nested=True).visit(iet)
174180

175181
return iet
@@ -375,6 +381,23 @@ def _check_control_flow(hs0, hs1, cond_mapper):
375381
return cond0 != cond1
376382

377383

384+
def _is_iter_carried(hs, scope):
385+
"""
386+
True if the HaloScheme `hs` is iteration-carried, i.e., if the exchanged halo
387+
*only* uses values computed in a past iteration, False otherwise.
388+
"""
389+
for f, hse in hs.fmapper.items():
390+
try:
391+
writes = scope.writes[f]
392+
except KeyError:
393+
continue
394+
395+
if any(hse.loc_values.intersection(w.aindices) for w in writes):
396+
return False
397+
398+
return True
399+
400+
378401
# Code motion rules -- if the retval is True, then it means the input `dep` is not
379402
# a stopper to moving the HaloSpot `hs` around
380403

tests/test_mpi.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,8 @@ def test_avoid_redundant_haloupdate(self, mode):
11041104

11051105
calls = FindNodes(Call).visit(op)
11061106
assert len(calls) == 1
1107+
calls = FindNodes(Call).visit(get_time_loop(op))
1108+
assert len(calls) == 1
11071109

11081110
@pytest.mark.parallel(mode=1)
11091111
def test_avoid_redundant_haloupdate_cond(self, mode):
@@ -1123,7 +1125,7 @@ def test_avoid_redundant_haloupdate_cond(self, mode):
11231125
# access `f` at `t`, not `t+1` through factor subdim!
11241126
Eq(g, f[t, j] + 1, implicit_dim=t_sub)])
11251127

1126-
calls = FindNodes(Call).visit(op)
1128+
calls = FindNodes(Call).visit(get_time_loop(op))
11271129
assert len(calls) == 1
11281130
assert calls[0].functions[0] is f
11291131

@@ -1243,6 +1245,8 @@ def test_hoist_haloupdate_if_no_flowdep(self, mode):
12431245

12441246
calls = FindNodes(Call).visit(op)
12451247
assert len(calls) == 1
1248+
calls = FindNodes(Call).visit(get_time_loop(op))
1249+
assert len(calls) == 1
12461250

12471251
# Below, there is a flow-dependence along `x`, so a further halo update
12481252
# before the Inc is required
@@ -1405,16 +1409,22 @@ def test_avoid_haloupdate_if_flowdep_along_other_dim(self, mode):
14051409

14061410
calls = FindNodes(Call).visit(op)
14071411
assert len(calls) == 1
1412+
calls = FindNodes(Call).visit(get_time_loop(op))
1413+
assert len(calls) == 1
14081414

14091415
op.apply(time_M=1)
14101416
glb_pos_map = f.grid.distributor.glb_pos_map
14111417
R = 1e-07 # Can't use np.all due to rounding error at the tails
14121418
if LEFT in glb_pos_map[x]:
1413-
assert np.allclose(f.data_ro_domain[0, :5], [5.6, 6.8, 5.8, 4.8, 4.8], rtol=R)
1414-
assert np.allclose(g.data_ro_domain[0, :5], [2., 5.8, 5.8, 4.8, 4.8], rtol=R)
1419+
assert np.allclose(f.data_ro_domain[0, :5], [5.6, 6.8, 5.8, 4.8, 4.8],
1420+
rtol=R)
1421+
assert np.allclose(g.data_ro_domain[0, :5], [2., 5.8, 5.8, 4.8, 4.8],
1422+
rtol=R)
14151423
else:
1416-
assert np.allclose(f.data_ro_domain[0, 5:], [4.8, 4.8, 4.8, 4.8, 3.6], rtol=R)
1417-
assert np.allclose(g.data_ro_domain[0, 5:], [4.8, 4.8, 4.8, 4.8, 2.], rtol=R)
1424+
assert np.allclose(f.data_ro_domain[0, 5:], [4.8, 4.8, 4.8, 4.8, 3.6],
1425+
rtol=R)
1426+
assert np.allclose(g.data_ro_domain[0, 5:], [4.8, 4.8, 4.8, 4.8, 2.],
1427+
rtol=R)
14181428

14191429
@pytest.mark.parallel(mode=2)
14201430
def test_unmerge_haloupdate_if_no_locindices(self, mode):
@@ -1436,7 +1446,7 @@ def test_unmerge_haloupdate_if_no_locindices(self, mode):
14361446
calls = FindNodes(Call).visit(op)
14371447
assert len(calls) == 2
14381448

1439-
titer = op.body.body[-1].body[0]
1449+
titer = get_time_loop(op)
14401450
assert titer.dim is grid.time_dim
14411451
assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1
14421452
assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
@@ -1483,8 +1493,8 @@ def test_merge_and_hoist_haloupdate_if_diff_locindices(self, mode):
14831493
"""
14841494
This test is a revisited, more complex version of
14851495
`test_merge_haloupdate_if_diff_locindices`, also checking hoisting.
1486-
And in addition to checking the generated code,
1487-
it also checks the numerical output.
1496+
And in addition to checking the generated code, it also checks the
1497+
numerical output.
14881498
14891499
In the Operator there are three Eqs:
14901500
@@ -1851,11 +1861,12 @@ def test_haloupdate_buffer1_v2(self, mode):
18511861
op.cfunction
18521862

18531863
# Ensure there's a halo exchange over v2 before the rec interpolation
1854-
section1 = op.body.body[-1].body[1].nodes[1]
1855-
assert section1.is_Section
1856-
calls = FindNodes(HaloUpdateCall).visit(section1)
1857-
assert len(calls) == 1
1858-
assert calls[0].arguments[0] is v2
1864+
calls = FindNodes(HaloUpdateCall).visit(op)
1865+
assert len(calls) == 3
1866+
calls = FindNodes(HaloUpdateCall).visit(get_time_loop(op))
1867+
assert len(calls) == 2
1868+
assert calls[0].arguments[0] is v1
1869+
assert calls[1].arguments[0] is v2
18591870

18601871

18611872
class TestOperatorAdvanced:
@@ -3162,6 +3173,14 @@ def test_halo_structure(self, mode):
31623173
assert calls[0].functions[1].name == 'v'
31633174

31643175

3176+
def get_time_loop(op):
3177+
iters = FindNodes(Iteration).visit(op)
3178+
for i in iters:
3179+
if i.dim.is_Time:
3180+
return i
3181+
assert False
3182+
3183+
31653184
if __name__ == "__main__":
31663185
# configuration['mpi'] = 'overlap'
31673186
# TestDecomposition().test_reshape_left_right()

0 commit comments

Comments
 (0)