77 FindWithin , MapNodes , MapHaloSpots , Transformer ,
88 retrieve_iteration_tree )
99from devito .ir .support import PARALLEL , Scope
10- from devito .ir .support .guards import GuardFactorEq
1110from devito .mpi .halo_scheme import HaloScheme
1211from devito .mpi .reduction_scheme import DistReduce
1312from devito .mpi .routines import HaloExchangeBuilder , ReductionBuilder
@@ -78,8 +77,7 @@ def _hoist_redundant_from_conditionals(iet):
7877 cond_mapper = _make_cond_mapper (iet )
7978 iter_mapper = _filter_iter_mapper (iet )
8079
81- hsmapper = defaultdict (list )
82- imapper = defaultdict (list )
80+ mapper = HaloSpotMapper ()
8381 for it , halo_spots in iter_mapper .items ():
8482 scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
8583
@@ -107,16 +105,10 @@ def _hoist_redundant_from_conditionals(iet):
107105 # No candidate found, skip
108106 continue
109107
110- hsmapper [ hs0 ]. append ( f )
111- imapper [ condition ]. append ( hsf0 )
108+ mapper . drop ( hs0 , f )
109+ mapper . add ( condition , hsf0 )
112110
113- # Transform the IET according to the analysis
114- #TODO: MOVE THIS INTO UTILITY
115- mapper = {hs : hs ._rebuild (halo_scheme = hs .halo_scheme .drop (functions ))
116- for hs , functions in hsmapper .items ()}
117- mapper .update ({i : HaloSpot (i ._rebuild (), HaloScheme .union (hss ))
118- for i , hss in imapper .items ()})
119- iet = Transformer (mapper , nested = True ).visit (iet )
111+ iet = mapper .apply (iet )
120112
121113 return iet
122114
@@ -143,7 +135,7 @@ def _merge_halospots(iet):
143135 cond_mapper = _make_cond_mapper (iet )
144136 iter_mapper = _filter_iter_mapper (iet )
145137
146- mapper = {}
138+ mapper = HaloSpotMapper ()
147139 for it , halo_spots in iter_mapper .items ():
148140 for hs0 , hs1 in combinations (halo_spots , r = 2 ):
149141 if _check_control_flow (hs0 , hs1 , cond_mapper ):
@@ -152,24 +144,21 @@ def _merge_halospots(iet):
152144 scope = _derive_scope (it , hs0 , hs1 )
153145
154146 for f in hs1 .fmapper :
155- hsf0 = mapper .get (hs0 , hs0 .halo_scheme )
156- hsf1 = mapper .get (hs1 , hs1 .halo_scheme ) .project (f )
147+ hsf0 = mapper .get (hs0 ) .halo_scheme
148+ hsf1 = mapper .get (hs1 ) .halo_scheme .project (f )
157149 if not _is_mergeable (hsf0 , hsf1 , scope ):
158150 continue
159151
160152 # All good -- `hsf1` can be merged within `hs0`
161- mapper [ hs0 ] = HaloScheme . union ([ hsf0 , hsf1 ] )
153+ mapper . add ( hs0 , hsf1 )
162154
163155 # If the `loc_indices` differ, we rely on hoisting to optimize
164156 # `hsf1` out of `it`, otherwise we just drop it
165157 if hsf0 .loc_values != hsf1 .loc_values :
166158 continue
167- mapper [ hs1 ] = mapper . get (hs1 , hs1 . halo_scheme ). drop ( f )
159+ mapper . drop (hs1 , f )
168160
169- # Transform the IET according to the analysis
170- mapper = {i : i .body if hs .is_void else i ._rebuild (halo_scheme = hs )
171- for i , hs in mapper .items ()}
172- iet = Transformer (mapper , nested = True ).visit (iet )
161+ iet = mapper .apply (iet )
173162
174163 return iet
175164
@@ -195,8 +184,7 @@ def _hoist_invariant(iet):
195184 cond_mapper = _make_cond_mapper (iet )
196185 iter_mapper = _filter_iter_mapper (iet )
197186
198- hsmapper = defaultdict (list )
199- imapper = defaultdict (list )
187+ mapper = HaloSpotMapper ()
200188 for it , halo_spots in iter_mapper .items ():
201189 for hs0 , hs1 in combinations (halo_spots , r = 2 ):
202190 if _check_control_flow (hs0 , hs1 , cond_mapper ):
@@ -232,16 +220,10 @@ def _hoist_invariant(iet):
232220 loc_indices [d ] = v
233221 hhs = hsf .drop (f ).add (f , hse ._rebuild (loc_indices = loc_indices ))
234222
235- hsmapper [ hs ]. append ( f )
236- imapper [ it ]. append ( hhs )
223+ mapper . drop ( hs , f )
224+ mapper . add ( it , hhs )
237225
238- # Transform the IET according to the analysis
239- #TODO: MOVE THIS INTO UTILITY
240- mapper = {hs : hs ._rebuild (halo_scheme = hs .halo_scheme .drop (functions ))
241- for hs , functions in hsmapper .items ()}
242- mapper .update ({i : HaloSpot (i ._rebuild (), HaloScheme .union (hss ))
243- for i , hss in imapper .items ()})
244- iet = Transformer (mapper , nested = True ).visit (iet )
226+ iet = mapper .apply (iet )
245227
246228 return iet
247229
@@ -408,6 +390,46 @@ def mpiize(graph, **kwargs):
408390
409391# *** Utilities
410392
393+
394+ class HaloSpotMapper (dict ):
395+
396+ def get (self , hs ):
397+ return super ().get (hs , hs )
398+
399+ def drop (self , hs , functions ):
400+ """
401+ Drop `functions` from the HaloSpot `hs`.
402+ """
403+ v = self .get (hs )
404+ hss = v .halo_scheme .drop (functions )
405+ self [hs ] = hs ._rebuild (halo_scheme = hss )
406+
407+ def add (self , node , hss ):
408+ """
409+ Add the HaloScheme `hss` to `node`:
410+
411+ * If `node` is a HaloSpot, then `hss` is added to its
412+ existing HaloSchemes;
413+ * Otherwise, a HaloSpot is created wrapping `node`, and `hss`
414+ is added to it.
415+ """
416+ v = self .get (node )
417+ if isinstance (v , HaloSpot ):
418+ hss = HaloScheme .union ([v .halo_scheme , hss ])
419+ hs = v ._rebuild (halo_scheme = hss )
420+ else :
421+ hs = HaloSpot (v ._rebuild (), hss )
422+ self [node ] = hs
423+
424+ def apply (self , iet ):
425+ """
426+ Transform `iet` using the HaloSpotMapper.
427+ """
428+ mapper = {i : i .body if hs .is_void else hs for i , hs in self .items ()}
429+ iet = Transformer (mapper , nested = True ).visit (iet )
430+ return iet
431+
432+
411433def _filter_iter_mapper (iet ):
412434 """
413435 Given an IET, return a mapper from Iterations to the HaloSpots.
@@ -426,12 +448,8 @@ def _make_cond_mapper(iet):
426448 """
427449 Return a mapper from HaloSpots to the Conditionals that contain them.
428450 """
429- #TODO
430- cond_mapper = {}
431- for hs , v in MapHaloSpots ().visit (iet ).items ():
432- cond_mapper [hs ] = tuple (i for i in v if i .is_Conditional )
433-
434- return cond_mapper
451+ return {hs : tuple (i for i in v if i .is_Conditional )
452+ for hs , v in MapHaloSpots ().visit (iet ).items ()}
435453
436454
437455def _derive_scope (it , hs0 , hs1 ):
0 commit comments