Skip to content

Commit 5e4ca20

Browse files
committed
compiler: prevent undefined Temp through global init
1 parent 8ad5ec1 commit 5e4ca20

5 files changed

Lines changed: 40 additions & 4 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Interval, IntervalGroup, IterationSpace, LabeledVector, Queue, Vector, extrema,
1313
maximum, minimum, normalize_properties, relax_properties, unbounded, vmax, vmin
1414
)
15+
from devito.ir.support import null_ispace
1516
from devito.passes.clusters.cse import _cse
1617
from devito.symbolics import (
1718
Uxmapper, estimate_cost, retrieve_functions, reuse_if_untouched, search, sympy_dtype,
@@ -131,7 +132,8 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
131132
# AliasList -> Schedule
132133
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
133134

134-
variants.append(Variant(schedule, exprs))
135+
if schedule:
136+
variants.append(Variant(schedule, exprs))
135137

136138
if not variants:
137139
return []
@@ -860,6 +862,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
860862
make = TempFunction if opt_ftemps else TempArray
861863

862864
clusters = []
865+
inits = []
863866
subs = {}
864867
for pivot, writeto, ispace, aliaseds, indicess in schedule:
865868
name = sregistry.make_name()
@@ -928,8 +931,11 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928931
assert writeto.size == 0
929932

930933
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
934+
const = not meta.guards
935+
obj = Temp(name=name, dtype=dtype, is_const=const)
932936
expression = Eq(obj, uxreplace(pivot, subs))
937+
if not const:
938+
inits.append(Eq(obj, 0))
933939

934940
callback = lambda idx: obj # noqa: B023
935941

@@ -959,6 +965,13 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
959965
# Finally, build the alias Cluster
960966
clusters.append(Cluster(expression, ispace, meta.guards, properties))
961967

968+
if inits:
969+
# To avoid undefined variables when an constant (Temp) alias is used
970+
# within different guards/loop, we need to initialize it outside of the loops
971+
# so that it's globally defined.
972+
# See tests/test_dse.py::TestAliases::test_split_cond
973+
clusters.insert(0, Cluster(inits, null_ispace))
974+
962975
return clusters, subs
963976

964977

devito/passes/clusters/cse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,10 @@ def _(exprs):
400400

401401
@_catch.register(sympy.Eq)
402402
def _(expr):
403+
if isinstance(expr.lhs, Temp):
404+
# Prevent re-factorization of constant alias
405+
return {}
406+
403407
mapper = _catch(expr.rhs)
404408
try:
405409
cond = expr.conditionals

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

examples/mpi/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
461461
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
462462
"\n",
463-
" float r0 = 1.0F/h_x;\n",
463+
" const float r0 = 1.0F/h_x;\n",
464464
"\n",
465465
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
466466
" {\n",

tests/test_dse.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,25 @@ def test_space_and_time_invariant_together(self):
26852685
'tx0_blk0y0_blk0xyzyz'
26862686
)
26872687

2688+
def test_split_cond(self):
2689+
grid = Grid((11, 11))
2690+
time = grid.time_dim
2691+
2692+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2693+
2694+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2695+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2696+
2697+
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2698+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2699+
eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct)
2700+
2701+
op = Operator([eq0, eq1, eq2])
2702+
cond = FindNodes(Conditional).visit(op)
2703+
assert len(cond) == 3
2704+
assert str(cond[0].args['then_body'][0].exprs[0]) == 'r0 = cos(time);'
2705+
assert str(op.body.body[0].body[0].body[0]) == 'float r0 = 0;'
2706+
26882707

26892708
class TestIsoAcoustic:
26902709

0 commit comments

Comments
 (0)