@@ -56,73 +56,81 @@ def _drop_reduction_halospots(iet):
5656
5757def _hoist_invariant (iet ):
5858 """
59- Hoist HaloSpots from inner to outer Iterations where all data dependencies
60- would be honored. This pass avoids redundant halo exchanges when the same
61- data is redundantly exchanged within the same Iteration tree level.
59+ Hoist redundant HaloSpots out of Iterations.
6260
6361 Example:
64- haloupd v[t0]
65- for time for time
66- W v[t1]- R v[t0] W v[t1]- R v[t0]
67- haloupd v[t1] haloupd v[t1]
68- R v[t1] R v[t1]
69- haloupd v[t0] R v[t0]
70- R v[t0]
7162
63+ haloupd v[t0]
64+ haloupd v[t0]'
65+ for time for time
66+ haloupd v[t0]
67+ W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
68+ haloupd v[t1] haloupd v[t1]
69+ R v[t1] R v[t1]
70+ haloupd v[t0]'
71+ R v[t0] R v[t0]
72+
73+ Where `haloupd v[t0]` and `haloupd v[t0]'` are subsets of `haloupd v[t1]`.
7274 """
73-
74- # Precompute scopes to save time
75- scopes = {i : Scope ([e .expr for e in v ]) for i , v in MapNodes ().visit (iet ).items ()}
76-
77- # Analysis
78- hsmapper = {}
79- imapper = defaultdict (list )
80-
8175 cond_mapper = _make_cond_mapper (iet )
8276 iter_mapper = _filter_iter_mapper (iet )
8377
78+ hsmapper = defaultdict (list )
79+ imapper = defaultdict (list )
8480 for it , halo_spots in iter_mapper .items ():
81+ scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
82+
8583 for hs0 , hs1 in combinations (halo_spots , r = 2 ):
8684 if _check_control_flow (hs0 , hs1 , cond_mapper ):
8785 continue
8886
89- # If there are overlapping loc_indices, skip
90- hs0_mdims = hs0 .halo_scheme .loc_values
91- hs1_mdims = hs1 .halo_scheme .loc_values
92- if hs0_mdims .intersection (hs1_mdims ):
87+ functions = set (hs0 .fmapper ) & set (hs1 .fmapper )
88+ if not functions :
89+ continue
90+ functions = sorted (functions , key = lambda f : f .name ) # Determinism
91+
92+ hsf0 = hs0 .halo_scheme .project (functions )
93+ hsf1 = hs1 .halo_scheme .project (functions )
94+
95+ # NOTE: lexicographic order is important here, as if both `hsf0` and
96+ # `hsf1` are iteration-carried, we must hoist the later one in the
97+ # control flow (that is `hsf1`)
98+ flag0 = hsf1 .issubset (hsf0 ) and _is_iter_carried (hsf1 , scope )
99+ flag1 = hsf0 .issubset (hsf1 ) and _is_iter_carried (hsf0 , scope )
100+ if flag0 and flag1 :
101+ # Special case: `_merge_halospots` will merge `hsf0` and `hsf1`
102+ # anyway, so we can avoid a useless hoisting
103+ continue
104+ elif flag0 :
105+ hs , hsf = hs1 , hsf1
106+ elif flag1 :
107+ hs , hsf = hs0 , hsf0
108+ else :
109+ # No hoisting possible, skip
93110 continue
94111
95- for f , v in hs1 .fmapper .items ():
96- if f not in hs0 .functions :
97- continue
112+ # At this point, we must infer valid loc_indices for `found`
113+ hhs = hsf .drop (functions ) # hhs -> hoisted HaloScheme
114+ for f , hse in hsf .fmapper .items ():
115+ loc_indices = {}
116+ for d , v in hse .loc_indices .items ():
117+ if v in it .uindices :
118+ loc_indices [d ] = v .symbolic_min .subs (it .dim , it .start )
119+ else :
120+ loc_indices [d ] = v
121+
122+ hhs = hhs .add (f , hse ._rebuild (loc_indices = loc_indices ))
123+
124+ hsmapper [hs ].extend (functions )
125+ imapper [it ].append (hhs )
98126
99- for dep in scopes [it ].d_flow .project (f ):
100- if not any (r (dep , hs1 , v .loc_indices ) for r in rules ):
101- break
102- else :
103- # `hs1`` can be hoisted out of `it`, but we need to infer valid
104- # loc_indices
105- hse = hs1 .halo_scheme .fmapper [f ]
106- loc_indices = {}
107-
108- for d , v in hse .loc_indices .items ():
109- if v in it .uindices :
110- loc_indices [d ] = v .symbolic_min .subs (it .dim , it .start )
111- else :
112- loc_indices [d ] = v
113-
114- hse = hse ._rebuild (loc_indices = loc_indices )
115- hs1 .halo_scheme .fmapper [f ] = hse
116-
117- hsmapper [hs1 ] = hsmapper .get (hs1 , hs1 .halo_scheme ).drop (f )
118- imapper [it ].append (hs1 .halo_scheme .project (f ))
119-
120- mapper = {i : HaloSpot (i ._rebuild (), HaloScheme .union (hss ))
121- for i , hss in imapper .items ()}
122- mapper .update ({i : i .body if hs .is_void else i ._rebuild (halo_scheme = hs )
123- for i , hs in hsmapper .items ()})
124127 # Transform the IET hoisting/dropping HaloSpots as according to the analysis
128+ mapper = {hs : hs ._rebuild (halo_scheme = hs .halo_scheme .drop (functions ))
129+ for hs , functions in hsmapper .items ()}
130+ mapper .update ({i : HaloSpot (i ._rebuild (), HaloScheme .union (hss ))
131+ for i , hss in imapper .items ()})
125132 iet = Transformer (mapper , nested = True ).visit (iet )
133+
126134 return iet
127135
128136
@@ -134,18 +142,16 @@ def _merge_halospots(iet):
134142
135143 Example:
136144
137- for time for time
138- haloupd v[t0] haloupd v[t0], h
139- W v[t1]- R v[t0] W v[t1]- R v[t0]
145+ for time for time
146+ haloupd v[t0] haloupd v[t0], h
147+ W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
140148 haloupd v[t0], h
141- W g[t1]- R v[t0], h W g[t1]- R v[t0], h
149+ W g[t1]- R v[t0], h W g[t1]- R v[t0], h
142150 """
143-
144- # Analysis
145- mapper = {}
146151 cond_mapper = _make_cond_mapper (iet )
147152 iter_mapper = _filter_iter_mapper (iet )
148153
154+ mapper = {}
149155 for it , halo_spots in iter_mapper .items ():
150156 scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
151157
@@ -157,19 +163,19 @@ def _merge_halospots(iet):
157163
158164 for f , v in hs1 .fmapper .items ():
159165 for dep in scope .d_flow .project (f ):
160- if not any (rule (dep , hs1 , v .loc_indices ) for rule in rules ):
166+ if not any (r (dep , hs1 , v .loc_indices ) for r in rules ):
161167 break
162168 else :
163- # All good -- `hs1` can be merged with `hs0`
164- hs = hs1 .halo_scheme .project (f )
165- mapper [hs0 ] = HaloScheme .union ([mapper .get (hs0 , hs0 .halo_scheme ), hs ])
169+ # All good -- `hsf1` can be merged within `hs0`
170+ hsf1 = hs1 .halo_scheme .project (f )
171+ mapper [hs0 ] = HaloScheme .union (
172+ [mapper .get (hs0 , hs0 .halo_scheme ), hsf1 ]
173+ )
166174 mapper [hs1 ] = mapper .get (hs1 , hs1 .halo_scheme ).drop (f )
167175
168- # Post-process analysis
176+ # Transform the IET merging/dropping HaloSpots as according to the analysis
169177 mapper = {i : i .body if hs .is_void else i ._rebuild (halo_scheme = hs )
170178 for i , hs in mapper .items ()}
171-
172- # Transform the IET merging/dropping HaloSpots as according to the analysis
173179 iet = Transformer (mapper , nested = True ).visit (iet )
174180
175181 return iet
@@ -375,6 +381,23 @@ def _check_control_flow(hs0, hs1, cond_mapper):
375381 return cond0 != cond1
376382
377383
384+ def _is_iter_carried (hs , scope ):
385+ """
386+ True if the HaloScheme `hs` is iteration-carried, i.e., if the exchanged halo
387+ *only* uses values computed in a past iteration, False otherwise.
388+ """
389+ for f , hse in hs .fmapper .items ():
390+ try :
391+ writes = scope .writes [f ]
392+ except KeyError :
393+ continue
394+
395+ if any (hse .loc_values .intersection (w .aindices ) for w in writes ):
396+ return False
397+
398+ return True
399+
400+
378401# Code motion rules -- if the retval is True, then it means the input `dep` is not
379402# a stopper to moving the HaloSpot `hs` around
380403
0 commit comments