44from itertools import combinations
55
66from devito .ir .iet import (Call , Expression , HaloSpot , Iteration , FindNodes ,
7- MapNodes , MapHaloSpots , Transformer ,
7+ FindWithin , MapNodes , MapHaloSpots , Transformer ,
88 retrieve_iteration_tree )
99from devito .ir .support import PARALLEL , Scope
1010from devito .ir .support .guards import GuardFactorEq
@@ -24,6 +24,7 @@ def optimize_halospots(iet, **kwargs):
2424 merged and moved around in order to improve the halo exchange performance.
2525 """
2626 iet = _drop_reduction_halospots (iet )
27+ iet = _hoist_redundant_from_conditionals (iet )
2728 iet = _merge_halospots (iet )
2829 iet = _hoist_invariant (iet )
2930 iet = _drop_if_unwritten (iet , ** kwargs )
@@ -54,6 +55,72 @@ def _drop_reduction_halospots(iet):
5455 return iet
5556
5657
58+ def _hoist_redundant_from_conditionals (iet ):
59+ """
60+ Hoist redundant HaloSpots from Conditionals. The idea is that, by doing so,
61+ the subsequent passes (hoist, merge) will be able to optimize them away.
62+
63+ Examples
64+ --------
65+
66+ for time for time
67+ if(cond) if(cond)
68+ haloupd u[t0] haloupd u[t0]
69+ haloupd v[t0] haloupd v[t0]
70+ W v[t1]- R u[t0],v[t0] ---> W v[t1]- R u[t0],v[t0]
71+ haloupd v[t0] haloupd v[t0]
72+ R v[t0] R v[t0]
73+
74+ Note that in the above example, the `haloupd v[t0]` in the if branch is
75+ redundant, as it also appears later on in the Iteration body, so it is
76+ hoisted out of the Conditional.
77+ """
78+ cond_mapper = _make_cond_mapper (iet )
79+ iter_mapper = _filter_iter_mapper (iet )
80+
81+ hsmapper = defaultdict (list )
82+ imapper = defaultdict (list )
83+ for it , halo_spots in iter_mapper .items ():
84+ scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
85+
86+ for hs0 in halo_spots :
87+ conditions = cond_mapper [hs0 ]
88+ if not conditions :
89+ continue
90+ condition = conditions [- 1 ] # Take the innermost Conditional
91+
92+ for f in hs0 .fmapper :
93+ hsf0 = hs0 .halo_scheme .project (f )
94+
95+ # Find candidate for subsequent merging
96+ for hs1 in halo_spots :
97+ if hs0 is hs1 or cond_mapper [hs1 ]:
98+ continue
99+
100+ hsf1 = hs1 .halo_scheme .project (f )
101+ if not _is_mergeable (hsf1 , hsf0 , scope ) or \
102+ not hsf1 .issubset (hsf0 ):
103+ continue
104+
105+ break
106+ else :
107+ # No candidate found, skip
108+ continue
109+
110+ hsmapper [hs0 ].append (f )
111+ imapper [condition ].append (hsf0 )
112+
113+ # Transform the IET according to the analysis
114+ #TODO: MOVE THIS INTO UTILITY
115+ mapper = {hs : hs ._rebuild (halo_scheme = hs .halo_scheme .drop (functions ))
116+ for hs , functions in hsmapper .items ()}
117+ mapper .update ({i : HaloSpot (i ._rebuild (), HaloScheme .union (hss ))
118+ for i , hss in imapper .items ()})
119+ iet = Transformer (mapper , nested = True ).visit (iet )
120+
121+ return iet
122+
123+
57124def _merge_halospots (iet ):
58125 """
59126 Merge HaloSpots on the same IET level when data dependences allow it. This
@@ -78,31 +145,28 @@ def _merge_halospots(iet):
78145
79146 mapper = {}
80147 for it , halo_spots in iter_mapper .items ():
81- scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
82-
83- hs0 = halo_spots [0 ]
84-
85- for hs1 in halo_spots [1 :]:
148+ for hs0 , hs1 in combinations (halo_spots , r = 2 ):
86149 if _check_control_flow (hs0 , hs1 , cond_mapper ):
87150 continue
88151
89- for f , v in hs1 .fmapper .items ():
90- hsf0 = hs0 .halo_scheme .project (f )
91- hsf1 = hs1 .halo_scheme .project (f )
92- if hsf0 .loc_values != hsf1 .loc_values :
152+ scope = _derive_scope (it , hs0 , hs1 )
153+
154+ for f in hs1 .fmapper :
155+ hsf0 = mapper .get (hs0 , hs0 .halo_scheme )
156+ hsf1 = mapper .get (hs1 , hs1 .halo_scheme ).project (f )
157+ if not _is_mergeable (hsf0 , hsf1 , scope ):
93158 continue
94159
95- for dep in scope .d_flow .project (f ):
96- if not any (r (dep , hs1 , v .loc_indices ) for r in rules ):
97- break
98- else :
99- # All good -- `hsf1` can be merged within `hs0`
100- mapper [hs0 ] = HaloScheme .union (
101- [mapper .get (hs0 , hs0 .halo_scheme ), hsf1 ]
102- )
103- mapper [hs1 ] = mapper .get (hs1 , hs1 .halo_scheme ).drop (f )
160+ # All good -- `hsf1` can be merged within `hs0`
161+ mapper [hs0 ] = HaloScheme .union ([hsf0 , hsf1 ])
104162
105- # Transform the IET merging/dropping HaloSpots as according to the analysis
163+ # If the `loc_indices` differ, we rely on hoisting to optimize
164+ # `hsf1` out of `it`, otherwise we just drop it
165+ if hsf0 .loc_values != hsf1 .loc_values :
166+ continue
167+ mapper [hs1 ] = mapper .get (hs1 , hs1 .halo_scheme ).drop (f )
168+
169+ # Transform the IET according to the analysis
106170 mapper = {i : i .body if hs .is_void else i ._rebuild (halo_scheme = hs )
107171 for i , hs in mapper .items ()}
108172 iet = Transformer (mapper , nested = True ).visit (iet )
@@ -116,10 +180,10 @@ def _hoist_invariant(iet):
116180
117181 Examples
118182 --------
119- There is one typical case in which hoisting is possible, i.e., when a
120- HaloSpot is iteration-carried, and it is a subset of another HaloSpot
121- that is not iteration-carried . In this case, we can hoist the former
122- out of the Iteration containing the latter, as follows:
183+ There is one typical case in which hoisting is possible, i.e., when a HaloSpot
184+ is iteration-carried, and it is a subset of another HaloSpot within the same
185+ Iteration . In this case, we can hoist the former out of the Iteration
186+ containing the latter, as follows:
123187
124188 haloupd v[t0]
125189 for time for time
@@ -138,38 +202,41 @@ def _hoist_invariant(iet):
138202 if _check_control_flow (hs0 , hs1 , cond_mapper ):
139203 continue
140204
141- functions = set (hs0 .fmapper ) & set (hs1 .fmapper )
142- if not functions :
143- continue
144- functions = sorted (functions , key = lambda f : f .name ) # Determinism
205+ scope = _derive_scope (it , hs0 , hs1 )
145206
146- hsf0 = hs0 .halo_scheme .project (functions )
147- hsf1 = hs1 .halo_scheme .project (functions )
207+ for f in hs1 .fmapper :
208+ hsf0 = hs0 .halo_scheme .project (f )
209+ if hsf0 .is_void :
210+ continue
211+ hsf1 = hs1 .halo_scheme .project (f )
148212
149- if hsf1 .issubset (hsf0 ):
150- hs , hsf = hs1 , hsf1
151- elif hsf0 .issubset (hsf1 ):
152- hs , hsf = hs0 , hsf0
153- else :
154- # No hoisting possible, skip
155- continue
213+ # Ensure there's another HaloScheme that could cover for
214+ # us should we get hoisted while still satisfying the
215+ # data dependences
216+ if hsf1 .issubset (hsf0 ) and _is_iter_carried (hsf1 , scope ):
217+ hs , hsf = hs1 , hsf1
218+ elif hsf0 .issubset (hsf1 ) and hs0 is halo_spots [0 ]:
219+ # Special case
220+ hs , hsf = hs0 , hsf0
221+ else :
222+ # No hoisting possible, skip
223+ continue
156224
157- # At this point, we must infer valid loc_indices
158- hhs = hsf .drop (functions ) # hhs -> hoisted HaloScheme
159- for f , hse in hsf .fmapper .items ():
225+ # At this point, we must infer valid loc_indices
226+ hse = hsf .fmapper [f ]
160227 loc_indices = {}
161228 for d , v in hse .loc_indices .items ():
162229 if v in it .uindices :
163230 loc_indices [d ] = v .symbolic_min .subs (it .dim , it .start )
164231 else :
165232 loc_indices [d ] = v
233+ hhs = hsf .drop (f ).add (f , hse ._rebuild (loc_indices = loc_indices ))
166234
167- hhs = hhs .add (f , hse ._rebuild (loc_indices = loc_indices ))
235+ hsmapper [hs ].append (f )
236+ imapper [it ].append (hhs )
168237
169- hsmapper [hs ].extend (functions )
170- imapper [it ].append (hhs )
171-
172- # Transform the IET hoisting/dropping HaloSpots as according to the analysis
238+ # Transform the IET according to the analysis
239+ #TODO: MOVE THIS INTO UTILITY
173240 mapper = {hs : hs ._rebuild (halo_scheme = hs .halo_scheme .drop (functions ))
174241 for hs , functions in hsmapper .items ()}
175242 mapper .update ({i : HaloSpot (i ._rebuild (), HaloScheme .union (hss ))
@@ -359,15 +426,26 @@ def _make_cond_mapper(iet):
359426 """
360427 Return a mapper from HaloSpots to the Conditionals that contain them.
361428 """
429+ #TODO
362430 cond_mapper = {}
363431 for hs , v in MapHaloSpots ().visit (iet ).items ():
364- conditionals = {i for i in v if i .is_Conditional and
365- not isinstance (i .condition , GuardFactorEq )}
366- cond_mapper [hs ] = conditionals
432+ cond_mapper [hs ] = tuple (i for i in v if i .is_Conditional )
367433
368434 return cond_mapper
369435
370436
437+ def _derive_scope (it , hs0 , hs1 ):
438+ """
439+ Derive a Scope within the Iteration `it` that starts at the HaloSpot `hs0`
440+ and ends at the HaloSpot `hs1`.
441+ """
442+ expressions = FindWithin (Expression , hs0 , stop = hs1 ).visit (it )
443+ assert len (expressions ) > 0 , \
444+ "Expected at least one Expression between %s and %s" % (hs0 , hs1 )
445+
446+ return Scope ([e .expr for e in expressions ])
447+
448+
371449def _check_control_flow (hs0 , hs1 , cond_mapper ):
372450 """
373451 If there are Conditionals involved, both `hs0` and `hs1` must be
@@ -379,22 +457,48 @@ def _check_control_flow(hs0, hs1, cond_mapper):
379457 return cond0 != cond1
380458
381459
382- # Code motion rules -- if the retval is True, then it means the input `dep` is not
383- # a stopper to moving the HaloSpot `hs` around
460+ def _is_iter_carried (hsf , scope ):
461+ """
462+ True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces
463+ a halo exchange that requires values from the previous iteration(s); False
464+ otherwise.
465+ """
466+
467+ def rule0 (dep ):
468+ # E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>`, `d=t` => OK
469+ return not any (dep .distance_mapper [d ] is S .Infinity for d in dep .cause )
470+
471+ def rule1 (dep , loc_indices ):
472+ # E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>`, `loc_indices={t: t0}` => OK
473+ return any (dep .distance_mapper [d ] == 0 and
474+ dep .source [d ] is not v and
475+ dep .sink [d ] is not v
476+ for d , v in loc_indices .items ())
384477
385- def _rule0 ( dep , hs , loc_indices ):
386- # E.g., `dep=W<f,[t1, x]> -> R<f,[t0, x-1]>` => True
387- return not any ( d in hs . dimensions or dep . distance_mapper [ d ] is S . Infinity
388- for d in dep . cause )
478+ for f , v in hsf . fmapper . items ( ):
479+ for dep in scope . d_flow . project ( f ):
480+ if not rule0 ( dep ) and not rule1 ( dep , v . loc_indices ):
481+ return False
389482
483+ return True
484+
485+
486+ def _is_mergeable (hsf0 , hsf1 , scope ):
487+ """
488+ True if `hsf1` can be merged into `hsf0`, i.e., if they are compatible
489+ and the data dependences would be satisfied, False otherwise.
490+ """
491+ # If `hsf1` is empty there's nothing to merge
492+ if hsf1 .is_void :
493+ return False
390494
391- #TODO: MAYBE AVOID PASSING IN LOC_INDICES AND PULL THEM STRAIGHT FROM HS
392- def _rule1 (dep , hs , loc_indices ):
393- # E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>` and `loc_indices={t: t0}` => True
394- return any (dep .distance_mapper [d ] == 0 and
395- dep .source [d ] is not v and
396- dep .sink [d ] is not v
397- for d , v in loc_indices .items ())
495+ # Ensure `hsf0` and `hsf1` are compatible
496+ if hsf0 .dimensions != hsf1 .dimensions or \
497+ not hsf0 .functions & hsf1 .functions :
498+ return False
398499
500+ # Then, check the data dependences would be satisfied
501+ if not _is_iter_carried (hsf1 , scope ):
502+ return False
399503
400- rules = ( _rule0 , _rule1 )
504+ return True
0 commit comments