Skip to content

Commit fcdd6a4

Browse files
FabioLuporini authored and mloubout committed
compiler: Cleanup MPI passes
1 parent 09bbb1f commit fcdd6a4

4 files changed

Lines changed: 80 additions & 42 deletions

File tree

devito/ir/iet/nodes.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1516,8 +1516,8 @@ def arguments(self):
15161516
return self.halo_scheme.arguments
15171517

15181518
@property
1519-
def is_empty(self):
1520-
return len(self.halo_scheme) == 0
1519+
def is_void(self):
1520+
return self.halo_scheme.is_void
15211521

15221522
@property
15231523
def body(self):

devito/passes/iet/mpi.py

Lines changed: 56 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
FindWithin, MapNodes, MapHaloSpots, Transformer,
88
retrieve_iteration_tree)
99
from devito.ir.support import PARALLEL, Scope
10-
from devito.ir.support.guards import GuardFactorEq
1110
from devito.mpi.halo_scheme import HaloScheme
1211
from devito.mpi.reduction_scheme import DistReduce
1312
from devito.mpi.routines import HaloExchangeBuilder, ReductionBuilder
@@ -78,8 +77,7 @@ def _hoist_redundant_from_conditionals(iet):
7877
cond_mapper = _make_cond_mapper(iet)
7978
iter_mapper = _filter_iter_mapper(iet)
8079

81-
hsmapper = defaultdict(list)
82-
imapper = defaultdict(list)
80+
mapper = HaloSpotMapper()
8381
for it, halo_spots in iter_mapper.items():
8482
scope = Scope([e.expr for e in FindNodes(Expression).visit(it)])
8583

@@ -107,16 +105,10 @@ def _hoist_redundant_from_conditionals(iet):
107105
# No candidate found, skip
108106
continue
109107

110-
hsmapper[hs0].append(f)
111-
imapper[condition].append(hsf0)
108+
mapper.drop(hs0, f)
109+
mapper.add(condition, hsf0)
112110

113-
# Transform the IET according to the analysis
114-
#TODO: MOVE THIS INTO UTILITY
115-
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(functions))
116-
for hs, functions in hsmapper.items()}
117-
mapper.update({i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
118-
for i, hss in imapper.items()})
119-
iet = Transformer(mapper, nested=True).visit(iet)
111+
iet = mapper.apply(iet)
120112

121113
return iet
122114

@@ -143,7 +135,7 @@ def _merge_halospots(iet):
143135
cond_mapper = _make_cond_mapper(iet)
144136
iter_mapper = _filter_iter_mapper(iet)
145137

146-
mapper = {}
138+
mapper = HaloSpotMapper()
147139
for it, halo_spots in iter_mapper.items():
148140
for hs0, hs1 in combinations(halo_spots, r=2):
149141
if _check_control_flow(hs0, hs1, cond_mapper):
@@ -152,24 +144,21 @@ def _merge_halospots(iet):
152144
scope = _derive_scope(it, hs0, hs1)
153145

154146
for f in hs1.fmapper:
155-
hsf0 = mapper.get(hs0, hs0.halo_scheme)
156-
hsf1 = mapper.get(hs1, hs1.halo_scheme).project(f)
147+
hsf0 = mapper.get(hs0).halo_scheme
148+
hsf1 = mapper.get(hs1).halo_scheme.project(f)
157149
if not _is_mergeable(hsf0, hsf1, scope):
158150
continue
159151

160152
# All good -- `hsf1` can be merged within `hs0`
161-
mapper[hs0] = HaloScheme.union([hsf0, hsf1])
153+
mapper.add(hs0, hsf1)
162154

163155
# If the `loc_indices` differ, we rely on hoisting to optimize
164156
# `hsf1` out of `it`, otherwise we just drop it
165157
if hsf0.loc_values != hsf1.loc_values:
166158
continue
167-
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)
159+
mapper.drop(hs1, f)
168160

169-
# Transform the IET according to the analysis
170-
mapper = {i: i.body if hs.is_void else i._rebuild(halo_scheme=hs)
171-
for i, hs in mapper.items()}
172-
iet = Transformer(mapper, nested=True).visit(iet)
161+
iet = mapper.apply(iet)
173162

174163
return iet
175164

@@ -195,8 +184,7 @@ def _hoist_invariant(iet):
195184
cond_mapper = _make_cond_mapper(iet)
196185
iter_mapper = _filter_iter_mapper(iet)
197186

198-
hsmapper = defaultdict(list)
199-
imapper = defaultdict(list)
187+
mapper = HaloSpotMapper()
200188
for it, halo_spots in iter_mapper.items():
201189
for hs0, hs1 in combinations(halo_spots, r=2):
202190
if _check_control_flow(hs0, hs1, cond_mapper):
@@ -232,16 +220,10 @@ def _hoist_invariant(iet):
232220
loc_indices[d] = v
233221
hhs = hsf.drop(f).add(f, hse._rebuild(loc_indices=loc_indices))
234222

235-
hsmapper[hs].append(f)
236-
imapper[it].append(hhs)
223+
mapper.drop(hs, f)
224+
mapper.add(it, hhs)
237225

238-
# Transform the IET according to the analysis
239-
#TODO: MOVE THIS INTO UTILITY
240-
mapper = {hs: hs._rebuild(halo_scheme=hs.halo_scheme.drop(functions))
241-
for hs, functions in hsmapper.items()}
242-
mapper.update({i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
243-
for i, hss in imapper.items()})
244-
iet = Transformer(mapper, nested=True).visit(iet)
226+
iet = mapper.apply(iet)
245227

246228
return iet
247229

@@ -408,6 +390,46 @@ def mpiize(graph, **kwargs):
408390

409391
# *** Utilities
410392

393+
394+
class HaloSpotMapper(dict):

    """
    A dict from IET nodes to the HaloSpots that should replace them.

    Entries are accumulated through `drop` and `add`, and finally turned
    into an actual IET transformation through `apply`.
    """

    def get(self, hs, default=None):
        """
        Return the entry mapped to `hs`.

        If `hs` is unmapped, return `default`; if no `default` is given
        (or it is None), return `hs` itself. The optional `default` keeps
        the signature compatible with `dict.get(key, default)`, so generic
        dict consumers do not fail with a TypeError.
        """
        if default is None:
            default = hs
        return super().get(hs, default)

    def drop(self, hs, functions):
        """
        Drop `functions` from the HaloSpot `hs`.
        """
        # Start from the entry already recorded for `hs`, if any, so that
        # successive drops accumulate instead of overwriting each other
        v = self.get(hs)
        hss = v.halo_scheme.drop(functions)
        self[hs] = hs._rebuild(halo_scheme=hss)

    def add(self, node, hss):
        """
        Add the HaloScheme `hss` to `node`:

            * If `node` is a HaloSpot, then `hss` is added to its
              existing HaloSchemes;
            * Otherwise, a HaloSpot is created wrapping `node`, and `hss`
              is added to it.
        """
        # As in `drop`, operate on the entry already recorded for `node`
        v = self.get(node)
        if isinstance(v, HaloSpot):
            hss = HaloScheme.union([v.halo_scheme, hss])
            hs = v._rebuild(halo_scheme=hss)
        else:
            hs = HaloSpot(v._rebuild(), hss)
        self[node] = hs

    def apply(self, iet):
        """
        Transform `iet` using the HaloSpotMapper.
        """
        # An entry whose `is_void` is true is replaced by the body of the
        # node it was mapped from; any other entry replaces its node as is
        mapper = {i: i.body if hs.is_void else hs for i, hs in self.items()}
        iet = Transformer(mapper, nested=True).visit(iet)
        return iet
431+
432+
411433
def _filter_iter_mapper(iet):
412434
"""
413435
Given an IET, return a mapper from Iterations to the HaloSpots.
@@ -426,12 +448,8 @@ def _make_cond_mapper(iet):
426448
"""
427449
Return a mapper from HaloSpots to the Conditionals that contain them.
428450
"""
429-
#TODO
430-
cond_mapper = {}
431-
for hs, v in MapHaloSpots().visit(iet).items():
432-
cond_mapper[hs] = tuple(i for i in v if i.is_Conditional)
433-
434-
return cond_mapper
451+
return {hs: tuple(i for i in v if i.is_Conditional)
452+
for hs, v in MapHaloSpots().visit(iet).items()}
435453

436454

437455
def _derive_scope(it, hs0, hs1):

tests/test_dle.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -169,9 +169,10 @@ def test_cache_blocking_structure_distributed(mode):
169169
eqns += [Eq(U.forward, U.dx + u.forward)]
170170

171171
op = Operator(eqns)
172+
op.cfunction
172173

173174
bns0, _ = assert_blocking(op._func_table['compute0'].root, {'x0_blk0'})
174-
bns1, _ = assert_blocking(op, {'x1_blk0'})
175+
bns1, _ = assert_blocking(op._func_table['compute2'].root, {'x1_blk0'})
175176

176177
for i in [bns0['x0_blk0'], bns1['x1_blk0']]:
177178
iters = FindNodes(Iteration).visit(i)

tests/test_mpi.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1437,7 +1437,7 @@ def test_avoid_haloupdate_if_flowdep_along_other_dim(self, mode):
14371437
rtol=R)
14381438

14391439
@pytest.mark.parallel(mode=2)
1440-
def test_unmerge_haloupdate_if_no_locindices(self, mode):
1440+
def test_avoid_merging_if_no_locindices(self, mode):
14411441
grid = Grid(shape=(10,))
14421442
x = grid.dimensions[0]
14431443
t = grid.stepping_dim
@@ -1477,6 +1477,25 @@ def test_unmerge_haloupdate_if_no_locindices(self, mode):
14771477
assert np.allclose(f.data_ro_domain[5:], [8., 8., 8., 6., 5.], rtol=R)
14781478
assert np.allclose(g.data_ro_domain[0, 5:], [16., 16., 14., 13., 6.], rtol=R)
14791479

1480+
@pytest.mark.parallel(mode=1)
def test_avoid_merging_if_diff_functions(self, mode):
    """
    Build an Operator that updates `u`, injects `src` into `u.forward`,
    and then updates `U` reading `u.forward`, and check the halo
    exchanges it generates.

    NOTE(review): going by the test name, the intent is that halo
    exchanges for different functions are not merged -- confirm. The
    exact meaning of the two counts is defined by `check_halo_exchanges`.
    """
    grid = Grid(shape=(4, 4, 4))

    u = TimeFunction(name="u", grid=grid, space_order=2)
    U = TimeFunction(name="U", grid=grid, space_order=2)
    src = SparseTimeFunction(name="src", grid=grid, nt=3, npoint=1,
                             coordinates=np.array([(0.5, 0.5, 0.5)]))

    eqns = [Eq(u.forward, u.dx)]
    eqns += src.inject(field=u.forward, expr=src)
    eqns += [Eq(U.forward, U.dx + u.forward)]

    op = Operator(eqns)
    # Presumably forces code generation/compilation before the IET is
    # inspected -- TODO confirm
    op.cfunction

    check_halo_exchanges(op, 2, 2)
1498+
14801499
@pytest.mark.parallel(mode=1)
14811500
def test_merge_haloupdate_if_diff_locindices(self, mode):
14821501
grid = Grid(shape=(101, 101))

0 commit comments

Comments
 (0)