Skip to content

Commit e2045e5

Browse files
committed
compiler: prevent undefined Temp through global init
1 parent 8ad5ec1 commit e2045e5

5 files changed

Lines changed: 45 additions & 6 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Interval, IntervalGroup, IterationSpace, LabeledVector, Queue, Vector, extrema,
1313
maximum, minimum, normalize_properties, relax_properties, unbounded, vmax, vmin
1414
)
15+
from devito.ir.support import null_ispace
1516
from devito.passes.clusters.cse import _cse
1617
from devito.symbolics import (
1718
Uxmapper, estimate_cost, retrieve_functions, reuse_if_untouched, search, sympy_dtype,
@@ -131,7 +132,8 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
131132
# AliasList -> Schedule
132133
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
133134

134-
variants.append(Variant(schedule, exprs))
135+
if schedule:
136+
variants.append(Variant(schedule, exprs))
135137

136138
if not variants:
137139
return []
@@ -860,6 +862,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
860862
make = TempFunction if opt_ftemps else TempArray
861863

862864
clusters = []
865+
inits = []
863866
subs = {}
864867
for pivot, writeto, ispace, aliaseds, indicess in schedule:
865868
name = sregistry.make_name()
@@ -928,8 +931,11 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928931
assert writeto.size == 0
929932

930933
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
934+
const = not meta.guards
935+
obj = Temp(name=name, dtype=dtype, is_const=const)
932936
expression = Eq(obj, uxreplace(pivot, subs))
937+
if not const:
938+
inits.append(Eq(obj, 0))
933939

934940
callback = lambda idx: obj # noqa: B023
935941

@@ -959,6 +965,13 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
959965
# Finally, build the alias Cluster
960966
clusters.append(Cluster(expression, ispace, meta.guards, properties))
961967

968+
if inits:
969+
# To avoid undefined variables when an constant (Temp) alias is used
970+
# within different guards/loop, we need to initialize it outside of the loops
971+
# so that it's globally defined.
972+
# See tests/test_dse.py::TestAliases::test_split_cond
973+
clusters.insert(0, Cluster(inits, null_ispace))
974+
962975
return clusters, subs
963976

964977

devito/passes/clusters/cse.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def cse_dtype(exprdtype, cdtype):
4040
Return the dtype of a CSE temporary given the dtype of the expression to be
4141
captured and the cluster's dtype.
4242
"""
43+
if np.issubdtype(cdtype, np.floating) and np.issubdtype(exprdtype, np.integer):
44+
# Integer expression and floating-point cluster: promote to the floating point
45+
# np.promote_types upcast integers (e.g int32 -> Float64) so we
46+
# need to ensure that the promoted type is not larger than the cluster's dtype
47+
return cdtype
48+
4349
if np.issubdtype(cdtype, np.complexfloating):
4450
return np.promote_types(exprdtype, cdtype(0).real.__class__).type
4551
else:
@@ -97,8 +103,9 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
97103
if cluster.is_fence:
98104
return cluster
99105

100-
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
101-
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
106+
def make(e):
107+
edtype = cse_dtype(e.dtype, dtype)
108+
return CTemp(name=sregistry.make_name(), dtype=edtype)
102109

103110
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
104111

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

examples/mpi/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
461461
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
462462
"\n",
463-
" float r0 = 1.0F/h_x;\n",
463+
" const float r0 = 1.0F/h_x;\n",
464464
"\n",
465465
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
466466
" {\n",

tests/test_dse.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,25 @@ def test_space_and_time_invariant_together(self):
26852685
'tx0_blk0y0_blk0xyzyz'
26862686
)
26872687

2688+
def test_split_cond(self):
2689+
grid = Grid((11, 11))
2690+
time = grid.time_dim
2691+
2692+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2693+
2694+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2695+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2696+
2697+
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2698+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2699+
eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct)
2700+
2701+
op = Operator([eq0, eq1, eq2])
2702+
cond = FindNodes(Conditional).visit(op)
2703+
assert len(cond) == 3
2704+
assert str(cond[0].args['then_body'][0].exprs[0]) == 'r0 = cos(time);'
2705+
assert str(op.body.body[0].body[0].body[0]) == 'float r0 = 0;'
2706+
26882707

26892708
class TestIsoAcoustic:
26902709

0 commit comments

Comments
 (0)