@@ -24,8 +24,8 @@ def optimize_halospots(iet, **kwargs):
2424 merged and moved around in order to improve the halo exchange performance.
2525 """
2626 iet = _drop_reduction_halospots (iet )
27- iet = _hoist_invariant (iet )
2827 iet = _merge_halospots (iet )
28+ iet = _hoist_invariant (iet )
2929 iet = _drop_if_unwritten (iet , ** kwargs )
3030 iet = _mark_overlappable (iet )
3131
@@ -34,8 +34,8 @@ def optimize_halospots(iet, **kwargs):
3434
3535def _drop_reduction_halospots (iet ):
3636 """
37- Remove HaloSpots that are used to compute Increments
38- (in which case, a halo exchange is actually unnecessary)
37+ Remove HaloSpots that are used to compute Increments (in which case, a halo
38+ exchange is actually unnecessary)
3939 """
4040 mapper = defaultdict (set )
4141
@@ -54,32 +54,86 @@ def _drop_reduction_halospots(iet):
5454 return iet
5555
5656
57- def _hoist_invariant (iet ):
57+ def _merge_halospots (iet ):
5858 """
59- Hoist redundant HaloSpots out of Iterations.
59+ Merge HaloSpots on the same IET level when data dependences allow it. This
60+ has two effects: anticipating communication over computation, and (potentially)
61+ avoiding redundant halo exchanges.
6062
61- Example:
63+ Examples
64+ --------
65+ In the following example, we have two HaloSpots that both require a halo
66+ exchange for the same Function `v` at `t0`. Since `v[t0]` is not written
67+ to in the Iteration nest, we can merge the second HaloSpot into the first
68+ one, thus avoiding a redundant halo exchange.
6269
63- haloupd v[t0]
64- haloupd v[t0]'
6570 for time for time
66- haloupd v[t0]
67- W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
68- haloupd v[t1] haloupd v[t1]
69- R v[t1] R v[t1]
70- haloupd v[t0]'
71- R v[t0] R v[t0]
71+ haloupd v[t0] haloupd v[t0], h
72+ W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
73+ haloupd v[t0], h
74+ W g[t1]- R v[t0], h W g[t1]- R v[t0], h
75+ """
76+ cond_mapper = _make_cond_mapper (iet )
77+ iter_mapper = _filter_iter_mapper (iet )
78+
79+ mapper = {}
80+ for it , halo_spots in iter_mapper .items ():
81+ scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
82+
83+ hs0 = halo_spots [0 ]
84+
85+ for hs1 in halo_spots [1 :]:
86+ if _check_control_flow (hs0 , hs1 , cond_mapper ):
87+ continue
88+
89+ for f , v in hs1 .fmapper .items ():
90+ hsf0 = hs0 .halo_scheme .project (f )
91+ hsf1 = hs1 .halo_scheme .project (f )
92+ if hsf0 .loc_values != hsf1 .loc_values :
93+ continue
94+
95+ for dep in scope .d_flow .project (f ):
96+ if not any (r (dep , hs1 , v .loc_indices ) for r in rules ):
97+ break
98+ else :
99+ # All good -- `hsf1` can be merged within `hs0`
100+ mapper [hs0 ] = HaloScheme .union (
101+ [mapper .get (hs0 , hs0 .halo_scheme ), hsf1 ]
102+ )
103+ mapper [hs1 ] = mapper .get (hs1 , hs1 .halo_scheme ).drop (f )
104+
105+ # Transform the IET merging/dropping HaloSpots as according to the analysis
106+ mapper = {i : i .body if hs .is_void else i ._rebuild (halo_scheme = hs )
107+ for i , hs in mapper .items ()}
108+ iet = Transformer (mapper , nested = True ).visit (iet )
109+
110+ return iet
111+
112+
113+ def _hoist_invariant (iet ):
114+ """
115+ Hoist iteration-carried HaloSpots out of Iterations.
116+
117+ Examples
118+ --------
119+ There is one typical case in which hoisting is possible, i.e., when a
120+ HaloSpot is iteration-carried, and it is a subset of another HaloSpot
121+ that is not iteration-carried. In this case, we can hoist the former
122+ out of the Iteration containing the latter, as follows:
72123
73- Where `haloupd v[t0]` and `haloupd v[t0]'` are subsets of `haloupd v[t1]`.
124+ haloupd v[t0]
125+ for time for time
126+ haloupd v[t0]
127+ W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
128+ haloupd v[t1] haloupd v[t1]
129+ R v[t1] R v[t1]
74130 """
75131 cond_mapper = _make_cond_mapper (iet )
76132 iter_mapper = _filter_iter_mapper (iet )
77133
78134 hsmapper = defaultdict (list )
79135 imapper = defaultdict (list )
80136 for it , halo_spots in iter_mapper .items ():
81- scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
82-
83137 for hs0 , hs1 in combinations (halo_spots , r = 2 ):
84138 if _check_control_flow (hs0 , hs1 , cond_mapper ):
85139 continue
@@ -92,24 +146,15 @@ def _hoist_invariant(iet):
92146 hsf0 = hs0 .halo_scheme .project (functions )
93147 hsf1 = hs1 .halo_scheme .project (functions )
94148
95- # NOTE: lexicographic order is important here, as if both `hsf0` and
96- # `hsf1` are iteration-carried, we must hoist the later one in the
97- # control flow (that is `hsf1`)
98- flag0 = hsf1 .issubset (hsf0 ) and _is_iter_carried (hsf1 , scope )
99- flag1 = hsf0 .issubset (hsf1 ) and _is_iter_carried (hsf0 , scope )
100- if flag0 and flag1 :
101- # Special case: `_merge_halospots` will merge `hsf0` and `hsf1`
102- # anyway, so we can avoid a useless hoisting
103- continue
104- elif flag0 :
149+ if hsf1 .issubset (hsf0 ):
105150 hs , hsf = hs1 , hsf1
106- elif flag1 :
151+ elif hsf0 . issubset ( hsf1 ) :
107152 hs , hsf = hs0 , hsf0
108153 else :
109154 # No hoisting possible, skip
110155 continue
111156
112- # At this point, we must infer valid loc_indices for `found`
157+ # At this point, we must infer valid loc_indices
113158 hhs = hsf .drop (functions ) # hhs -> hoisted HaloScheme
114159 for f , hse in hsf .fmapper .items ():
115160 loc_indices = {}
@@ -134,53 +179,6 @@ def _hoist_invariant(iet):
134179 return iet
135180
136181
137- def _merge_halospots (iet ):
138- """
139- Using data dependence analysis, merge HaloSpots on the same IET level. This
140- has two effects: anticipating communication over computation, and (potentially)
141- avoiding redundant halo exchanges.
142-
143- Example:
144-
145- for time for time
146- haloupd v[t0] haloupd v[t0], h
147- W v[t1]- R v[t0] ---> W v[t1]- R v[t0]
148- haloupd v[t0], h
149- W g[t1]- R v[t0], h W g[t1]- R v[t0], h
150- """
151- cond_mapper = _make_cond_mapper (iet )
152- iter_mapper = _filter_iter_mapper (iet )
153-
154- mapper = {}
155- for it , halo_spots in iter_mapper .items ():
156- scope = Scope ([e .expr for e in FindNodes (Expression ).visit (it )])
157-
158- hs0 = halo_spots [0 ]
159-
160- for hs1 in halo_spots [1 :]:
161- if _check_control_flow (hs0 , hs1 , cond_mapper ):
162- continue
163-
164- for f , v in hs1 .fmapper .items ():
165- for dep in scope .d_flow .project (f ):
166- if not any (r (dep , hs1 , v .loc_indices ) for r in rules ):
167- break
168- else :
169- # All good -- `hsf1` can be merged within `hs0`
170- hsf1 = hs1 .halo_scheme .project (f )
171- mapper [hs0 ] = HaloScheme .union (
172- [mapper .get (hs0 , hs0 .halo_scheme ), hsf1 ]
173- )
174- mapper [hs1 ] = mapper .get (hs1 , hs1 .halo_scheme ).drop (f )
175-
176- # Transform the IET merging/dropping HaloSpots as according to the analysis
177- mapper = {i : i .body if hs .is_void else i ._rebuild (halo_scheme = hs )
178- for i , hs in mapper .items ()}
179- iet = Transformer (mapper , nested = True ).visit (iet )
180-
181- return iet
182-
183-
184182def _drop_if_unwritten (iet , options = None , ** kwargs ):
185183 """
186184 Drop HaloSpots for unwritten Functions.
@@ -381,23 +379,6 @@ def _check_control_flow(hs0, hs1, cond_mapper):
381379 return cond0 != cond1
382380
383381
384- def _is_iter_carried (hs , scope ):
385- """
386- True if the HaloScheme `hs` is iteration-carried, i.e., if the exchanged halo
387- *only* uses values computed in a past iteration, False otherwise.
388- """
389- for f , hse in hs .fmapper .items ():
390- try :
391- writes = scope .writes [f ]
392- except KeyError :
393- continue
394-
395- if any (hse .loc_values .intersection (w .aindices ) for w in writes ):
396- return False
397-
398- return True
399-
400-
401382# Code motion rules -- if the retval is True, then it means the input `dep` is not
402383# a stopper to moving the HaloSpot `hs` around
403384
@@ -407,6 +388,7 @@ def _rule0(dep, hs, loc_indices):
407388 for d in dep .cause )
408389
409390
391+ #TODO: MAYBE AVOID PASSING IN LOC_INDICES AND PULL THEM STRAIGHT FROM HS
410392def _rule1 (dep , hs , loc_indices ):
411393 # E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>` and `loc_indices={t: t0}` => True
412394 return any (dep .distance_mapper [d ] == 0 and
0 commit comments