Skip to content

Commit 2027e92

Browse files
committed
compiler: Revamp HaloUpdate hoisting and merging
1 parent 5a4df89 commit 2027e92

4 files changed

Lines changed: 340 additions & 84 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828
IndexedData, DeviceMap)
2929

3030

31-
__all__ = ['FindApplications', 'FindNodes', 'FindSections', 'FindSymbols',
32-
'MapExprStmts', 'MapHaloSpots', 'MapNodes', 'IsPerfectIteration',
33-
'printAST', 'CGen', 'CInterface', 'Transformer', 'Uxreplace']
31+
__all__ = ['FindApplications', 'FindNodes', 'FindWithin', 'FindSections',
32+
'FindSymbols', 'MapExprStmts', 'MapHaloSpots', 'MapNodes',
33+
'IsPerfectIteration', 'printAST', 'CGen', 'CInterface', 'Transformer',
34+
'Uxreplace']
3435

3536

3637
class Visitor(GenericVisitor):
@@ -1112,6 +1113,46 @@ def visit_Node(self, o, ret=None):
11121113
return ret
11131114

11141115

1116+
class FindWithin(FindNodes):
1117+
1118+
@classmethod
1119+
def default_retval(cls):
1120+
return [], False
1121+
1122+
"""
1123+
Like FindNodes, but given an additional parameter `within=(start, stop)`,
1124+
it starts collecting matching nodes only after `start` is found, and stops
1125+
collecting matching nodes after `stop` is found.
1126+
"""
1127+
1128+
def __init__(self, match, start, stop=None):
1129+
super().__init__(match)
1130+
self.start = start
1131+
self.stop = stop
1132+
1133+
def visit(self, o, ret=None):
1134+
found, _ = self._visit(o, ret=ret)
1135+
return found
1136+
1137+
def visit_Node(self, o, ret=None):
1138+
if ret is None:
1139+
ret = self.default_retval()
1140+
found, flag = ret
1141+
1142+
if o is self.start:
1143+
flag = True
1144+
1145+
if flag and self.rule(self.match, o):
1146+
found.append(o)
1147+
for i in o.children:
1148+
found, flag = self._visit(i, ret=(found, flag))
1149+
1150+
if o is self.stop:
1151+
flag = False
1152+
1153+
return found, flag
1154+
1155+
11151156
class FindApplications(Visitor):
11161157

11171158
"""

devito/mpi/halo_scheme.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,10 @@ def owned_size(self):
388388
mapper[d] = (max(maxl, s.left), max(maxr, s.right))
389389
return mapper
390390

391+
@cached_property
392+
def functions(self):
393+
return frozenset(self.fmapper)
394+
391395
@cached_property
392396
def dimensions(self):
393397
retval = set()

devito/passes/iet/mpi.py

Lines changed: 166 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from itertools import combinations
55

66
from devito.ir.iet import (Call, Expression, HaloSpot, Iteration, FindNodes,
7-
MapNodes, MapHaloSpots, Transformer,
7+
FindWithin, MapNodes, MapHaloSpots, Transformer,
88
retrieve_iteration_tree)
99
from devito.ir.support import PARALLEL, Scope
1010
from devito.ir.support.guards import GuardFactorEq
@@ -24,6 +24,7 @@ def optimize_halospots(iet, **kwargs):
2424
merged and moved around in order to improve the halo exchange performance.
2525
"""
2626
iet = _drop_reduction_halospots(iet)
27+
iet = _hoist_redundant_from_conditionals(iet)
2728
iet = _merge_halospots(iet)
2829
iet = _hoist_invariant(iet)
2930
iet = _drop_if_unwritten(iet, **kwargs)
@@ -54,6 +55,72 @@ def _drop_reduction_halospots(iet):
5455
return iet
5556

5657

58+
def _hoist_redundant_from_conditionals(iet):
59+
"""
60+
Hoist redundant HaloSpots from Conditionals. The idea is that, by doing so,
61+
the subsequent passes (hoist, merge) will be able to optimize them away.
62+
63+
Examples
64+
--------
65+
66+
for time for time
67+
if(cond) if(cond)
68+
haloupd u[t0] haloupd u[t0]
69+
haloupd v[t0] haloupd v[t0]
70+
W v[t1]- R u[t0],v[t0] ---> W v[t1]- R u[t0],v[t0]
71+
haloupd v[t0] haloupd v[t0]
72+
R v[t0] R v[t0]
73+
74+
Note that in the above example, the `haloupd v[t0]` in the if branch is
75+
redundant, as it also appears later on in the Iteration body, so it is
76+
hoisted out of the Conditional.
77+
"""
78+
cond_mapper = _make_cond_mapper(iet)
79+
iter_mapper = _filter_iter_mapper(iet)
80+
81+
hsmapper = defaultdict(list)
82+
imapper = defaultdict(list)
83+
for it, halo_spots in iter_mapper.items():
84+
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
85+
86+
for hs0 in halo_spots:
87+
conditions = cond_mapper[hs0]
88+
if not conditions:
89+
continue
90+
condition = conditions[-1] # Take the innermost Conditional
91+
92+
for f in hs0.fmapper:
93+
hsf0 = hs0.halo_scheme.project(f)
94+
95+
# Find candidate for subsequent merging
96+
for hs1 in halo_spots:
97+
if hs0 is hs1 or cond_mapper[hs1]:
98+
continue
99+
100+
hsf1 = hs1.halo_scheme.project(f)
101+
if not _is_mergeable(hsf1, hsf0, scope) or \
102+
not hsf1.issubset(hsf0):
103+
continue
104+
105+
break
106+
else:
107+
# No candidate found, skip
108+
continue
109+
110+
hsmapper[hs0].append(f)
111+
imapper[condition].append(hsf0)
112+
113+
# Transform the IET according to the analysis
114+
#TODO: MOVE THIS INTO UTILITY
115+
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(functions))
116+
for hs, functions in hsmapper.items()}
117+
mapper.update({i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
118+
for i, hss in imapper.items()})
119+
iet = Transformer(mapper, nested=True).visit(iet)
120+
121+
return iet
122+
123+
57124
def _merge_halospots(iet):
58125
"""
59126
Merge HaloSpots on the same IET level when data dependences allow it. This
@@ -78,31 +145,28 @@ def _merge_halospots(iet):
78145

79146
mapper = {}
80147
for it, halo_spots in iter_mapper.items():
81-
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
82-
83-
hs0 = halo_spots[0]
84-
85-
for hs1 in halo_spots[1:]:
148+
for hs0, hs1 in combinations(halo_spots, r=2):
86149
if _check_control_flow(hs0, hs1, cond_mapper):
87150
continue
88151

89-
for f, v in hs1.fmapper.items():
90-
hsf0 = hs0.halo_scheme.project(f)
91-
hsf1 = hs1.halo_scheme.project(f)
92-
if hsf0.loc_values != hsf1.loc_values:
152+
scope = _derive_scope(it, hs0, hs1)
153+
154+
for f in hs1.fmapper:
155+
hsf0 = mapper.get(hs0, hs0.halo_scheme)
156+
hsf1 = mapper.get(hs1, hs1.halo_scheme).project(f)
157+
if not _is_mergeable(hsf0, hsf1, scope):
93158
continue
94159

95-
for dep in scope.d_flow.project(f):
96-
if not any(r(dep, hs1, v.loc_indices) for r in rules):
97-
break
98-
else:
99-
# All good -- `hsf1` can be merged within `hs0`
100-
mapper[hs0] = HaloScheme.union(
101-
[mapper.get(hs0, hs0.halo_scheme), hsf1]
102-
)
103-
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)
160+
# All good -- `hsf1` can be merged within `hs0`
161+
mapper[hs0] = HaloScheme.union([hsf0, hsf1])
104162

105-
# Transform the IET merging/dropping HaloSpots as according to the analysis
163+
# If the `loc_indices` differ, we rely on hoisting to optimize
164+
# `hsf1` out of `it`, otherwise we just drop it
165+
if hsf0.loc_values != hsf1.loc_values:
166+
continue
167+
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)
168+
169+
# Transform the IET according to the analysis
106170
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
107171
for i, hs in mapper.items()}
108172
iet = Transformer(mapper, nested=True).visit(iet)
@@ -116,10 +180,10 @@ def _hoist_invariant(iet):
116180
117181
Examples
118182
--------
119-
There is one typical case in which hoisting is possible, i.e., when a
120-
HaloSpot is iteration-carried, and it is a subset of another HaloSpot
121-
that is not iteration-carried. In this case, we can hoist the former
122-
out of the Iteration containing the latter, as follows:
183+
There is one typical case in which hoisting is possible, i.e., when a HaloSpot
184+
is iteration-carried, and it is a subset of another HaloSpot within the same
185+
Iteration. In this case, we can hoist the former out of the Iteration
186+
containing the latter, as follows:
123187
124188
haloupd v[t0]
125189
for time for time
@@ -138,38 +202,41 @@ def _hoist_invariant(iet):
138202
if _check_control_flow(hs0, hs1, cond_mapper):
139203
continue
140204

141-
functions = set(hs0.fmapper) & set(hs1.fmapper)
142-
if not functions:
143-
continue
144-
functions = sorted(functions, key=lambda f: f.name) # Determinism
205+
scope = _derive_scope(it, hs0, hs1)
145206

146-
hsf0 = hs0.halo_scheme.project(functions)
147-
hsf1 = hs1.halo_scheme.project(functions)
207+
for f in hs1.fmapper:
208+
hsf0 = hs0.halo_scheme.project(f)
209+
if hsf0.is_void:
210+
continue
211+
hsf1 = hs1.halo_scheme.project(f)
148212

149-
if hsf1.issubset(hsf0):
150-
hs, hsf = hs1, hsf1
151-
elif hsf0.issubset(hsf1):
152-
hs, hsf = hs0, hsf0
153-
else:
154-
# No hoisting possible, skip
155-
continue
213+
# Ensure there's another HaloScheme that could cover for
214+
# us should we get hoisted while still satisfying the
215+
# data dependences
216+
if hsf1.issubset(hsf0) and _is_iter_carried(hsf1, scope):
217+
hs, hsf = hs1, hsf1
218+
elif hsf0.issubset(hsf1) and hs0 is halo_spots[0]:
219+
# Special case
220+
hs, hsf = hs0, hsf0
221+
else:
222+
# No hoisting possible, skip
223+
continue
156224

157-
# At this point, we must infer valid loc_indices
158-
hhs = hsf.drop(functions) # hhs -> hoisted HaloScheme
159-
for f, hse in hsf.fmapper.items():
225+
# At this point, we must infer valid loc_indices
226+
hse = hsf.fmapper[f]
160227
loc_indices = {}
161228
for d, v in hse.loc_indices.items():
162229
if v in it.uindices:
163230
loc_indices[d] = v.symbolic_min.subs(it.dim, it.start)
164231
else:
165232
loc_indices[d] = v
233+
hhs = hsf.drop(f).add(f, hse._rebuild(loc_indices=loc_indices))
166234

167-
hhs = hhs.add(f, hse._rebuild(loc_indices=loc_indices))
235+
hsmapper[hs].append(f)
236+
imapper[it].append(hhs)
168237

169-
hsmapper[hs].extend(functions)
170-
imapper[it].append(hhs)
171-
172-
# Transform the IET hoisting/dropping HaloSpots as according to the analysis
238+
# Transform the IET according to the analysis
239+
#TODO: MOVE THIS INTO UTILITY
173240
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(functions))
174241
for hs, functions in hsmapper.items()}
175242
mapper.update({i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
@@ -359,15 +426,26 @@ def _make_cond_mapper(iet):
359426
"""
360427
Return a mapper from HaloSpots to the Conditionals that contain them.
361428
"""
429+
#TODO
362430
cond_mapper = {}
363431
for hs, v in MapHaloSpots().visit(iet).items():
364-
conditionals = {i for i in v if i.is_Conditional and
365-
not isinstance(i.condition, GuardFactorEq)}
366-
cond_mapper[hs] = conditionals
432+
cond_mapper[hs] = tuple(i for i in v if i.is_Conditional)
367433

368434
return cond_mapper
369435

370436

437+
def _derive_scope(it, hs0, hs1):
438+
"""
439+
Derive a Scope within the Iteration `it` that starts at the HaloSpot `hs0`
440+
and ends at the HaloSpot `hs1`.
441+
"""
442+
expressions = FindWithin(Expression, hs0, stop=hs1).visit(it)
443+
assert len(expressions) > 0, \
444+
"Expected at least one Expression between %s and %s" % (hs0, hs1)
445+
446+
return Scope([e.expr for e in expressions])
447+
448+
371449
def _check_control_flow(hs0, hs1, cond_mapper):
372450
"""
373451
If there are Conditionals involved, both `hs0` and `hs1` must be
@@ -379,22 +457,48 @@ def _check_control_flow(hs0, hs1, cond_mapper):
379457
return cond0 != cond1
380458

381459

382-
# Code motion rules -- if the retval is True, then it means the input `dep` is not
383-
# a stopper to moving the HaloSpot `hs` around
460+
def _is_iter_carried(hsf, scope):
461+
"""
462+
True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces
463+
a halo exchange that requires values from the previous iteration(s); False
464+
otherwise.
465+
"""
466+
467+
def rule0(dep):
468+
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>`, `d=t` => OK
469+
return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause)
470+
471+
def rule1(dep, loc_indices):
472+
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>`, `loc_indices={t: t0}` => OK
473+
return any(dep.distance_mapper[d] == 0 and
474+
dep.source[d] is not v and
475+
dep.sink[d] is not v
476+
for d, v in loc_indices.items())
384477

385-
def _rule0(dep, hs, loc_indices):
386-
# E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>` => True
387-
return not any(d in hs.dimensions or dep.distance_mapper[d] is S.Infinity
388-
for d in dep.cause)
478+
for f, v in hsf.fmapper.items():
479+
for dep in scope.d_flow.project(f):
480+
if not rule0(dep) and not rule1(dep, v.loc_indices):
481+
return False
389482

483+
return True
484+
485+
486+
def _is_mergeable(hsf0, hsf1, scope):
487+
"""
488+
True if `hsf1` can be merged into `hsf0`, i.e., if they are compatible
489+
and the data dependences would be satisfied, False otherwise.
490+
"""
491+
# If `hsf1` is empty there's nothing to merge
492+
if hsf1.is_void:
493+
return False
390494

391-
#TODO: MAYBE AVOID PASSING IN LOC_INDICES AND PULL THEM STRAIGHT FROM HS
392-
def _rule1(dep, hs, loc_indices):
393-
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>` and `loc_indices={t: t0}` => True
394-
return any(dep.distance_mapper[d] == 0 and
395-
dep.source[d] is not v and
396-
dep.sink[d] is not v
397-
for d, v in loc_indices.items())
495+
# Ensure `hsf0` and `hsf1` are compatible
496+
if hsf0.dimensions != hsf1.dimensions or \
497+
not hsf0.functions & hsf1.functions:
498+
return False
398499

500+
# Then, check the data dependences would be satisfied
501+
if not _is_iter_carried(hsf1, scope):
502+
return False
399503

400-
rules = (_rule0, _rule1)
504+
return True

0 commit comments

Comments
 (0)