Skip to content

Commit 5a4df89

Browse files
committed
compiler: Rework MPI hoist and merge
1 parent 300f840 commit 5a4df89

3 files changed

Lines changed: 103 additions & 103 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,9 @@ def issubset(self, other):
441441
return False
442442

443443
loc_dirs = hse0.loc_dirs
444-
loc_indices = {**hse0.loc_indices, **hse1.loc_indices}
445-
projected_loc_indices, _ = process_loc_indices(loc_indices, loc_dirs)
444+
raw_loc_indices = {d: (hse0.loc_indices[d], hse1.loc_indices[d])
445+
for d in hse0.loc_indices}
446+
projected_loc_indices, _ = process_loc_indices(raw_loc_indices, loc_dirs)
446447
if projected_loc_indices != hse1.loc_indices:
447448
return False
448449

devito/passes/iet/mpi.py

Lines changed: 75 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def optimize_halospots(iet, **kwargs):
2424
merged and moved around in order to improve the halo exchange performance.
2525
"""
2626
iet = _drop_reduction_halospots(iet)
27-
iet = _hoist_invariant(iet)
2827
iet = _merge_halospots(iet)
28+
iet = _hoist_invariant(iet)
2929
iet = _drop_if_unwritten(iet, **kwargs)
3030
iet = _mark_overlappable(iet)
3131

@@ -34,8 +34,8 @@ def optimize_halospots(iet, **kwargs):
3434

3535
def _drop_reduction_halospots(iet):
3636
"""
37-
Remove HaloSpots that are used to compute Increments
38-
(in which case, a halo exchange is actually unnecessary)
37+
Remove HaloSpots that are used to compute Increments (in which case, a halo
38+
exchange is actually unnecessary)
3939
"""
4040
mapper = defaultdict(set)
4141

@@ -54,32 +54,86 @@ def _drop_reduction_halospots(iet):
5454
return iet
5555

5656

57-
def _hoist_invariant(iet):
57+
def _merge_halospots(iet):
5858
"""
59-
Hoist redundant HaloSpots out of Iterations.
59+
Merge HaloSpots on the same IET level when data dependences allow it. This
60+
has two effects: anticipating communication over computation, and (potentially)
61+
avoiding redundant halo exchanges.
6062
61-
Example:
63+
Examples
64+
--------
65+
In the following example, we have two HaloSpots that both require a halo
66+
exchange for the same Function `v` at `t0`. Since `v[t0]` is not written
67+
to in the Iteration nest, we can merge the second HaloSpot into the first
68+
one, thus avoiding a redundant halo exchange.
6269
63-
haloupd v[t0]
64-
haloupd v[t0]'
6570
for time for time
66-
haloupd v[t0]
67-
W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
68-
haloupd v[t1] haloupd v[t1]
69-
R v[t1] R v[t1]
70-
haloupd v[t0]'
71-
R v[t0] R v[t0]
71+
haloupd v[t0] haloupd v[t0], h
72+
W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
73+
haloupd v[t0], h
74+
W g[t1]- R v[t0], h W g[t1]- R v[t0], h
75+
"""
76+
cond_mapper = _make_cond_mapper(iet)
77+
iter_mapper = _filter_iter_mapper(iet)
78+
79+
mapper = {}
80+
for it, halo_spots in iter_mapper.items():
81+
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
82+
83+
hs0 = halo_spots[0]
84+
85+
for hs1 in halo_spots[1:]:
86+
if _check_control_flow(hs0, hs1, cond_mapper):
87+
continue
88+
89+
for f, v in hs1.fmapper.items():
90+
hsf0 = hs0.halo_scheme.project(f)
91+
hsf1 = hs1.halo_scheme.project(f)
92+
if hsf0.loc_values != hsf1.loc_values:
93+
continue
94+
95+
for dep in scope.d_flow.project(f):
96+
if not any(r(dep, hs1, v.loc_indices) for r in rules):
97+
break
98+
else:
99+
# All good -- `hsf1` can be merged within `hs0`
100+
mapper[hs0] = HaloScheme.union(
101+
[mapper.get(hs0, hs0.halo_scheme), hsf1]
102+
)
103+
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)
104+
105+
# Transform the IET merging/dropping HaloSpots according to the analysis
106+
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
107+
for i, hs in mapper.items()}
108+
iet = Transformer(mapper, nested=True).visit(iet)
109+
110+
return iet
111+
112+
113+
def _hoist_invariant(iet):
114+
"""
115+
Hoist iteration-carried HaloSpots out of Iterations.
116+
117+
Examples
118+
--------
119+
There is one typical case in which hoisting is possible, i.e., when a
120+
HaloSpot is iteration-carried, and it is a subset of another HaloSpot
121+
that is not iteration-carried. In this case, we can hoist the former
122+
out of the Iteration containing the latter, as follows:
72123
73-
Where `haloupd v[t0]` and `haloupd v[t0]'` are subsets of `haloupd v[t1]`.
124+
haloupd v[t0]
125+
for time for time
126+
haloupd v[t0]
127+
W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
128+
haloupd v[t1] haloupd v[t1]
129+
R v[t1] R v[t1]
74130
"""
75131
cond_mapper = _make_cond_mapper(iet)
76132
iter_mapper = _filter_iter_mapper(iet)
77133

78134
hsmapper = defaultdict(list)
79135
imapper = defaultdict(list)
80136
for it, halo_spots in iter_mapper.items():
81-
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
82-
83137
for hs0, hs1 in combinations(halo_spots, r=2):
84138
if _check_control_flow(hs0, hs1, cond_mapper):
85139
continue
@@ -92,24 +146,15 @@ def _hoist_invariant(iet):
92146
hsf0 = hs0.halo_scheme.project(functions)
93147
hsf1 = hs1.halo_scheme.project(functions)
94148

95-
# NOTE: lexicographic order is important here, as if both `hsf0` and
96-
# `hsf1` are iteration-carried, we must hoist the later one in the
97-
# control flow (that is `hsf1`)
98-
flag0 = hsf1.issubset(hsf0) and _is_iter_carried(hsf1, scope)
99-
flag1 = hsf0.issubset(hsf1) and _is_iter_carried(hsf0, scope)
100-
if flag0 and flag1:
101-
# Special case: `_merge_halospots` will merge `hsf0` and `hsf1`
102-
# anyway, so we can avoid a useless hoisting
103-
continue
104-
elif flag0:
149+
if hsf1.issubset(hsf0):
105150
hs, hsf = hs1, hsf1
106-
elif flag1:
151+
elif hsf0.issubset(hsf1):
107152
hs, hsf = hs0, hsf0
108153
else:
109154
# No hoisting possible, skip
110155
continue
111156

112-
# At this point, we must infer valid loc_indices for `found`
157+
# At this point, we must infer valid loc_indices
113158
hhs = hsf.drop(functions) # hhs -> hoisted HaloScheme
114159
for f, hse in hsf.fmapper.items():
115160
loc_indices = {}
@@ -134,53 +179,6 @@ def _hoist_invariant(iet):
134179
return iet
135180

136181

137-
def _merge_halospots(iet):
138-
"""
139-
Using data dependence analysis, merge HaloSpots on the same IET level. This
140-
has two effects: anticipating communication over computation, and (potentially)
141-
avoiding redundant halo exchanges.
142-
143-
Example:
144-
145-
for time for time
146-
haloupd v[t0] haloupd v[t0], h
147-
W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
148-
haloupd v[t0], h
149-
W g[t1]- R v[t0], h W g[t1]- R v[t0], h
150-
"""
151-
cond_mapper = _make_cond_mapper(iet)
152-
iter_mapper = _filter_iter_mapper(iet)
153-
154-
mapper = {}
155-
for it, halo_spots in iter_mapper.items():
156-
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
157-
158-
hs0 = halo_spots[0]
159-
160-
for hs1 in halo_spots[1:]:
161-
if _check_control_flow(hs0, hs1, cond_mapper):
162-
continue
163-
164-
for f, v in hs1.fmapper.items():
165-
for dep in scope.d_flow.project(f):
166-
if not any(r(dep, hs1, v.loc_indices) for r in rules):
167-
break
168-
else:
169-
# All good -- `hsf1` can be merged within `hs0`
170-
hsf1 = hs1.halo_scheme.project(f)
171-
mapper[hs0] = HaloScheme.union(
172-
[mapper.get(hs0, hs0.halo_scheme), hsf1]
173-
)
174-
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)
175-
176-
# Transform the IET merging/dropping HaloSpots according to the analysis
177-
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
178-
for i, hs in mapper.items()}
179-
iet = Transformer(mapper, nested=True).visit(iet)
180-
181-
return iet
182-
183-
184182
def _drop_if_unwritten(iet, options=None, **kwargs):
185183
"""
186184
Drop HaloSpots for unwritten Functions.
@@ -381,23 +379,6 @@ def _check_control_flow(hs0, hs1, cond_mapper):
381379
return cond0 != cond1
382380

383381

384-
def _is_iter_carried(hs, scope):
385-
"""
386-
True if the HaloScheme `hs` is iteration-carried, i.e., if the exchanged halo
387-
*only* uses values computed in a past iteration, False otherwise.
388-
"""
389-
for f, hse in hs.fmapper.items():
390-
try:
391-
writes = scope.writes[f]
392-
except KeyError:
393-
continue
394-
395-
if any(hse.loc_values.intersection(w.aindices) for w in writes):
396-
return False
397-
398-
return True
399-
400-
401382
# Code motion rules -- if the retval is True, then it means the input `dep` is not
402383
# a stopper to moving the HaloSpot `hs` around
403384

@@ -407,6 +388,7 @@ def _rule0(dep, hs, loc_indices):
407388
for d in dep.cause)
408389

409390

391+
# TODO: maybe avoid passing in `loc_indices` and pull them straight from `hs`
410392
def _rule1(dep, hs, loc_indices):
411393
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>` and `loc_indices={t: t0}` => True
412394
return any(dep.distance_mapper[d] == 0 and

tests/test_mpi.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,30 +1843,47 @@ def test_haloupdate_buffer1(self, mode):
18431843
assert len([i for i in FindSymbols('dimensions').visit(op) if i.is_Modulo]) == 0
18441844

18451845
@pytest.mark.parallel(mode=1)
1846-
def test_haloupdate_buffer1_v2(self, mode):
1846+
@pytest.mark.parametrize('sz,fwd,expr,exp0,exp1,args', [
1847+
(1, True, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')),
1848+
(1, True, 'Eq(v3.forward, v2.laplace + 1)', 1, 1, ('v2',)),
1849+
(1, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1850+
(2, True, 'Eq(v3.forward, v2.forward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1851+
(1, False, 'rec.interpolate(v2)', 3, 2, ('v1', 'v2')),
1852+
(1, False, 'Eq(v3.backward, v2.laplace + 1)', 1, 1, ('v2',)),
1853+
(1, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1854+
(2, False, 'Eq(v3.backward, v2.backward.laplace + 1)', 3, 2, ('v1', 'v2',)),
1855+
])
1856+
def test_haloupdate_buffer_cases(self, sz, fwd, expr, exp0, exp1, args, mode):
18471857
grid = Grid((65, 65, 65), topology=('*', 1, '*'))
18481858

18491859
v1 = TimeFunction(name='v1', grid=grid, space_order=2, time_order=1,
18501860
save=Buffer(1))
18511861
v2 = TimeFunction(name='v2', grid=grid, space_order=2, time_order=1,
18521862
save=Buffer(1))
1863+
v3 = TimeFunction(name='v3', grid=grid, space_order=2, time_order=1, # noqa
1864+
save=Buffer(1))
18531865

18541866
rec = SparseTimeFunction(name='rec', grid=grid, nt=500, npoint=65)
18551867

1856-
eqns = [Eq(v1.forward, v2.dx2 + v1),
1857-
Eq(v2.forward, v1.forward.dx2 + v2),
1858-
rec.interpolate(v2)]
1868+
if fwd:
1869+
eqns = [Eq(v1.forward, v2.laplace + v1),
1870+
Eq(v2.forward, v1.forward.dx2 + v2),
1871+
eval(expr)]
1872+
else:
1873+
eqns = [Eq(v1.backward, v2.laplace + v1),
1874+
Eq(v2.backward, v1.backward.dx2 + v2),
1875+
eval(expr)]
18591876

18601877
op = Operator(eqns)
18611878
op.cfunction
18621879

18631880
# Ensure there's a halo exchange over v2 before the rec interpolation
18641881
calls = FindNodes(HaloUpdateCall).visit(op)
1865-
assert len(calls) == 3
1882+
assert len(calls) == exp0
18661883
calls = FindNodes(HaloUpdateCall).visit(get_time_loop(op))
1867-
assert len(calls) == 2
1868-
assert calls[0].arguments[0] is v1
1869-
assert calls[1].arguments[0] is v2
1884+
assert len(calls) == exp1
1885+
for i, v in enumerate(args):
1886+
assert calls[i].arguments[0] is eval(v)
18701887

18711888

18721889
class TestOperatorAdvanced:

0 commit comments

Comments (0)