1212 Interval , IntervalGroup , IterationSpace , LabeledVector , Queue , Vector , extrema ,
1313 maximum , minimum , normalize_properties , relax_properties , unbounded , vmax , vmin
1414)
15+ from devito .ir .support import null_ispace
1516from devito .passes .clusters .cse import _cse
1617from devito .symbolics import (
1718 Uxmapper , estimate_cost , retrieve_functions , reuse_if_untouched , search , sympy_dtype ,
@@ -131,10 +132,11 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
131132 # AliasList -> Schedule
132133 schedule = lower_aliases (aliases , meta , self .opt_maxpar )
133134
134- variants .append (Variant (schedule , exprs ))
135+ if schedule :
136+ variants .append (Variant (schedule , exprs ))
135137
136138 if not variants :
137- return []
139+ return [], []
138140
139141 # [Schedule]_m -> Schedule (s.t. best memory/flops trade-off)
140142 schedule , exprs = self ._select (variants )
@@ -144,9 +146,9 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
144146 schedule = optimize_schedule_rotations (schedule , self .sregistry )
145147
146148 # Schedule -> [Clusters]_k
147- processed , subs = lower_schedule (schedule , meta , self .sregistry ,
148- self .opt_ftemps , self .opt_min_dtype ,
149- self .opt_minmem )
149+ processed , subs , inits = lower_schedule (schedule , meta , self .sregistry ,
150+ self .opt_ftemps , self .opt_min_dtype ,
151+ self .opt_minmem )
150152
151153 # [Clusters]_k -> [Clusters]_k (optimization)
152154 if self .opt_multisubdomain :
@@ -166,7 +168,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
166168
167169 assert len (exprs ) == 0
168170
169- return processed
171+ return processed , inits
170172
171173 def process (self , clusters ):
172174 raise NotImplementedError
@@ -302,12 +304,13 @@ def callback(self, clusters, prefix, xtracted=None):
302304
303305 key = lambda c : self ._lookup_key (c , d )
304306 processed = list (clusters )
307+ inits = []
305308 for ak , group in as_mapper (clusters , key = key ).items ():
306309 g = [c for c in group if c .is_dense and c not in xtracted ]
307310 if not g :
308311 continue
309312
310- made = self ._aliases_from_clusters (ClusterGroup (g ), exclude , ak )
313+ made , cinits = self ._aliases_from_clusters (ClusterGroup (g ), exclude , ak )
311314
312315 if made :
313316 idx = processed .index (g [0 ])
@@ -317,7 +320,10 @@ def callback(self, clusters, prefix, xtracted=None):
317320
318321 xtracted .extend (made )
319322
320- return processed
323+ if cinits :
324+ inits .extend (cinits )
325+
326+ return inits + processed
321327
322328 def _lookup_key (self , c , d ):
323329 ispace = c .ispace .reset ()
@@ -390,7 +396,7 @@ def process(self, clusters):
390396 # TODO: to process third- and higher-order derivatives, we could
391397 # extend this by calling `_aliases_from_clusters` repeatedly until
392398 # `made` is empty. To be investigated
393- made = self ._aliases_from_clusters (
399+ made , _ = self ._aliases_from_clusters (
394400 ClusterGroup (c ), exclude , self ._lookup_key (c )
395401 )
396402
@@ -860,6 +866,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
860866 make = TempFunction if opt_ftemps else TempArray
861867
862868 clusters = []
869+ inits = []
863870 subs = {}
864871 for pivot , writeto , ispace , aliaseds , indicess in schedule :
865872 name = sregistry .make_name ()
@@ -928,8 +935,11 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928935 assert writeto .size == 0
929936
930937 dtype = sympy_dtype (pivot , base = meta .dtype , smin = opt_min_dtype )
931- obj = Temp (name = name , dtype = dtype )
938+ const = not meta .guards
939+ obj = Temp (name = name , dtype = dtype , is_const = const )
932940 expression = Eq (obj , uxreplace (pivot , subs ))
941+ if not const :
942+ inits .append (Eq (obj , 0 ))
933943
934944 callback = lambda idx : obj # noqa: B023
935945
@@ -959,7 +969,14 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
959969 # Finally, build the alias Cluster
960970 clusters .append (Cluster (expression , ispace , meta .guards , properties ))
961971
962- return clusters , subs
972+ if inits :
973+ # To avoid undefined variables when a constant (Temp) alias is used
974+ # within different guards/loops, we need to initialize it outside of the loops
975+ # so that it's globally defined.
976+ # See tests/test_dse.py::TestAliases::test_split_cond
977+ inits = [Cluster (inits , null_ispace )]
978+
979+ return clusters , subs , inits
963980
964981
965982def optimize_clusters_msds (clusters ):
0 commit comments