Skip to content

Commit 3b80b02

Browse files
committed
compiler: Cleanup MPI passes
1 parent 2027e92 commit 3b80b02

2 files changed

Lines changed: 58 additions & 40 deletions

File tree

devito/ir/iet/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,8 +1516,8 @@ def arguments(self):
15161516
return self.halo_scheme.arguments
15171517

15181518
@property
1519-
def is_empty(self):
1520-
return len(self.halo_scheme) == 0
1519+
def is_void(self):
1520+
return self.halo_scheme.is_void
15211521

15221522
@property
15231523
def body(self):

devito/passes/iet/mpi.py

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
FindWithin, MapNodes, MapHaloSpots, Transformer,
88
retrieve_iteration_tree)
99
from devito.ir.support import PARALLEL, Scope
10-
from devito.ir.support.guards import GuardFactorEq
1110
from devito.mpi.halo_scheme import HaloScheme
1211
from devito.mpi.reduction_scheme import DistReduce
1312
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
@@ -78,8 +77,7 @@ def _hoist_redundant_from_conditionals(iet):
7877
cond_mapper = _make_cond_mapper(iet)
7978
iter_mapper = _filter_iter_mapper(iet)
8079

81-
hsmapper = defaultdict(list)
82-
imapper = defaultdict(list)
80+
mapper = HaloSpotMapper()
8381
for it, halo_spots in iter_mapper.items():
8482
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
8583

@@ -107,16 +105,10 @@ def _hoist_redundant_from_conditionals(iet):
107105
# No candidate found, skip
108106
continue
109107

110-
hsmapper[hs0].append(f)
111-
imapper[condition].append(hsf0)
108+
mapper.drop(hs0, f)
109+
mapper.add(condition, hsf0)
112110

113-
# Transform the IET according to the analysis
114-
#TODO: MOVE THIS INTO UTILITY
115-
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(functions))
116-
for hs, functions in hsmapper.items()}
117-
mapper.update({i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
118-
for i, hss in imapper.items()})
119-
iet = Transformer(mapper, nested=True).visit(iet)
111+
iet = mapper.apply(iet)
120112

121113
return iet
122114

@@ -143,7 +135,7 @@ def _merge_halospots(iet):
143135
cond_mapper = _make_cond_mapper(iet)
144136
iter_mapper = _filter_iter_mapper(iet)
145137

146-
mapper = {}
138+
mapper = HaloSpotMapper()
147139
for it, halo_spots in iter_mapper.items():
148140
for hs0, hs1 in combinations(halo_spots, r=2):
149141
if _check_control_flow(hs0, hs1, cond_mapper):
@@ -152,24 +144,21 @@ def _merge_halospots(iet):
152144
scope = _derive_scope(it, hs0, hs1)
153145

154146
for f in hs1.fmapper:
155-
hsf0 = mapper.get(hs0, hs0.halo_scheme)
156-
hsf1 = mapper.get(hs1, hs1.halo_scheme).project(f)
147+
hsf0 = mapper.get(hs0).halo_scheme
148+
hsf1 = mapper.get(hs1).halo_scheme.project(f)
157149
if not _is_mergeable(hsf0, hsf1, scope):
158150
continue
159151

160152
# All good -- `hsf1` can be merged within `hs0`
161-
mapper[hs0] = HaloScheme.union([hsf0, hsf1])
153+
mapper.add(hs0, hsf1)
162154

163155
# If the `loc_indices` differ, we rely on hoisting to optimize
164156
# `hsf1` out of `it`, otherwise we just drop it
165157
if hsf0.loc_values != hsf1.loc_values:
166158
continue
167-
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)
159+
mapper.drop(hs1, f)
168160

169-
# Transform the IET according to the analysis
170-
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
171-
for i, hs in mapper.items()}
172-
iet = Transformer(mapper, nested=True).visit(iet)
161+
iet = mapper.apply(iet)
173162

174163
return iet
175164

@@ -195,8 +184,7 @@ def _hoist_invariant(iet):
195184
cond_mapper = _make_cond_mapper(iet)
196185
iter_mapper = _filter_iter_mapper(iet)
197186

198-
hsmapper = defaultdict(list)
199-
imapper = defaultdict(list)
187+
mapper = HaloSpotMapper()
200188
for it, halo_spots in iter_mapper.items():
201189
for hs0, hs1 in combinations(halo_spots, r=2):
202190
if _check_control_flow(hs0, hs1, cond_mapper):
@@ -232,16 +220,10 @@ def _hoist_invariant(iet):
232220
loc_indices[d] = v
233221
hhs = hsf.drop(f).add(f, hse._rebuild(loc_indices=loc_indices))
234222

235-
hsmapper[hs].append(f)
236-
imapper[it].append(hhs)
223+
mapper.drop(hs, f)
224+
mapper.add(it, hhs)
237225

238-
# Transform the IET according to the analysis
239-
#TODO: MOVE THIS INTO UTILITY
240-
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(functions))
241-
for hs, functions in hsmapper.items()}
242-
mapper.update({i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
243-
for i, hss in imapper.items()})
244-
iet = Transformer(mapper, nested=True).visit(iet)
226+
iet = mapper.apply(iet)
245227

246228
return iet
247229

@@ -408,6 +390,46 @@ def mpiize(graph, **kwargs):
408390

409391
# *** Utilities
410392

393+
394+
class HaloSpotMapper(dict):
395+
396+
def get(self, hs):
397+
return super().get(hs, hs)
398+
399+
def drop(self, hs, functions):
400+
"""
401+
Drop `functions` from the HaloSpot `hs`.
402+
"""
403+
v = self.get(hs)
404+
hss = v.halo_scheme.drop(functions)
405+
self[hs] = hs._rebuild(halo_scheme=hss)
406+
407+
def add(self, node, hss):
408+
"""
409+
Add the HaloScheme `hss` to `node`:
410+
411+
* If `node` is a HaloSpot, then `hss` is added to its
412+
existing HaloSchemes;
413+
* Otherwise, a HaloSpot is created wrapping `node`, and `hss`
414+
is added to it.
415+
"""
416+
v = self.get(node)
417+
if isinstance(v, HaloSpot):
418+
hss = HaloScheme.union([v.halo_scheme, hss])
419+
hs = v._rebuild(halo_scheme=hss)
420+
else:
421+
hs = HaloSpot(v._rebuild(), hss)
422+
self[node] = hs
423+
424+
def apply(self, iet):
425+
"""
426+
Transform `iet` using the HaloSpotMapper.
427+
"""
428+
mapper = {i: i.body if hs.is_void else hs for i, hs in self.items()}
429+
iet = Transformer(mapper, nested=True).visit(iet)
430+
return iet
431+
432+
411433
def _filter_iter_mapper(iet):
412434
"""
413435
Given an IET, return a mapper from Iterations to the HaloSpots.
@@ -426,12 +448,8 @@ def _make_cond_mapper(iet):
426448
"""
427449
Return a mapper from HaloSpots to the Conditionals that contain them.
428450
"""
429-
#TODO
430-
cond_mapper = {}
431-
for hs, v in MapHaloSpots().visit(iet).items():
432-
cond_mapper[hs] = tuple(i for i in v if i.is_Conditional)
433-
434-
return cond_mapper
451+
return {hs: tuple(i for i in v if i.is_Conditional)
452+
for hs, v in MapHaloSpots().visit(iet).items()}
435453

436454

437455
def _derive_scope(it, hs0, hs1):

0 commit comments

Comments
 (0)