Skip to content

Commit 1687a21

Browse files
committed
compiler: prevent undefined Temp through global init
1 parent 8ad5ec1 commit 1687a21

5 files changed

Lines changed: 85 additions & 15 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Interval, IntervalGroup, IterationSpace, LabeledVector, Queue, Vector, extrema,
1313
maximum, minimum, normalize_properties, relax_properties, unbounded, vmax, vmin
1414
)
15+
from devito.ir.support import null_ispace
1516
from devito.passes.clusters.cse import _cse
1617
from devito.symbolics import (
1718
Uxmapper, estimate_cost, retrieve_functions, reuse_if_untouched, search, sympy_dtype,
@@ -131,10 +132,11 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
131132
# AliasList -> Schedule
132133
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
133134

134-
variants.append(Variant(schedule, exprs))
135+
if schedule:
136+
variants.append(Variant(schedule, exprs))
135137

136138
if not variants:
137-
return []
139+
return [], []
138140

139141
# [Schedule]_m -> Schedule (s.t. best memory/flops trade-off)
140142
schedule, exprs = self._select(variants)
@@ -144,9 +146,9 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
144146
schedule = optimize_schedule_rotations(schedule, self.sregistry)
145147

146148
# Schedule -> [Clusters]_k
147-
processed, subs = lower_schedule(schedule, meta, self.sregistry,
148-
self.opt_ftemps, self.opt_min_dtype,
149-
self.opt_minmem)
149+
processed, subs, inits = lower_schedule(schedule, meta, self.sregistry,
150+
self.opt_ftemps, self.opt_min_dtype,
151+
self.opt_minmem)
150152

151153
# [Clusters]_k -> [Clusters]_k (optimization)
152154
if self.opt_multisubdomain:
@@ -166,7 +168,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
166168

167169
assert len(exprs) == 0
168170

169-
return processed
171+
return processed, inits
170172

171173
def process(self, clusters):
172174
raise NotImplementedError
@@ -302,12 +304,13 @@ def callback(self, clusters, prefix, xtracted=None):
302304

303305
key = lambda c: self._lookup_key(c, d)
304306
processed = list(clusters)
307+
inits = []
305308
for ak, group in as_mapper(clusters, key=key).items():
306309
g = [c for c in group if c.is_dense and c not in xtracted]
307310
if not g:
308311
continue
309312

310-
made = self._aliases_from_clusters(ClusterGroup(g), exclude, ak)
313+
made, cinits = self._aliases_from_clusters(ClusterGroup(g), exclude, ak)
311314

312315
if made:
313316
idx = processed.index(g[0])
@@ -317,7 +320,10 @@ def callback(self, clusters, prefix, xtracted=None):
317320

318321
xtracted.extend(made)
319322

320-
return processed
323+
if cinits:
324+
inits.extend(cinits)
325+
326+
return inits + processed
321327

322328
def _lookup_key(self, c, d):
323329
ispace = c.ispace.reset()
@@ -390,7 +396,7 @@ def process(self, clusters):
390396
# TODO: to process third- and higher-order derivatives, we could
391397
# extend this by calling `_aliases_from_clusters` repeatedly until
392398
# `made` is empty. To be investigated
393-
made = self._aliases_from_clusters(
399+
made, _ = self._aliases_from_clusters(
394400
ClusterGroup(c), exclude, self._lookup_key(c)
395401
)
396402

@@ -860,6 +866,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
860866
make = TempFunction if opt_ftemps else TempArray
861867

862868
clusters = []
869+
inits = []
863870
subs = {}
864871
for pivot, writeto, ispace, aliaseds, indicess in schedule:
865872
name = sregistry.make_name()
@@ -928,8 +935,11 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928935
assert writeto.size == 0
929936

930937
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
938+
const = not meta.guards
939+
obj = Temp(name=name, dtype=dtype, is_const=const)
932940
expression = Eq(obj, uxreplace(pivot, subs))
941+
if not const:
942+
inits.append(Eq(obj, 0))
933943

934944
callback = lambda idx: obj # noqa: B023
935945

@@ -959,7 +969,14 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
959969
# Finally, build the alias Cluster
960970
clusters.append(Cluster(expression, ispace, meta.guards, properties))
961971

962-
return clusters, subs
972+
if inits:
973+
# To avoid undefined variables when a constant (Temp) alias is used
974+
# within different guards/loops, we need to initialize it outside of the loops
975+
# so that it's globally defined.
976+
# See tests/test_dse.py::TestAliases::test_split_cond
977+
inits = [Cluster(inits, null_ispace)]
978+
979+
return clusters, subs, inits
963980

964981

965982
def optimize_clusters_msds(clusters):

devito/passes/clusters/cse.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def cse_dtype(exprdtype, cdtype):
4040
Return the dtype of a CSE temporary given the dtype of the expression to be
4141
captured and the cluster's dtype.
4242
"""
43+
if np.issubdtype(cdtype, np.floating) and np.issubdtype(exprdtype, np.integer):
44+
# Integer expression and floating-point cluster: use the cluster's floating-point dtype.
45+
# np.promote_types upcasts integers (e.g. int32 -> float64), so we
46+
# need to ensure that the promoted type is not larger than the cluster's dtype
47+
return cdtype
48+
4349
if np.issubdtype(cdtype, np.complexfloating):
4450
return np.promote_types(exprdtype, cdtype(0).real.__class__).type
4551
else:
@@ -97,8 +103,9 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
97103
if cluster.is_fence:
98104
return cluster
99105

100-
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
101-
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
106+
def make(e):
107+
edtype = cse_dtype(e.dtype, dtype)
108+
return CTemp(name=sregistry.make_name(), dtype=edtype)
102109

103110
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
104111

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

examples/mpi/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
461461
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
462462
"\n",
463-
" float r0 = 1.0F/h_x;\n",
463+
" const float r0 = 1.0F/h_x;\n",
464464
"\n",
465465
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
466466
" {\n",

tests/test_dse.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,52 @@ def test_space_and_time_invariant_together(self):
26852685
'tx0_blk0y0_blk0xyzyz'
26862686
)
26872687

2688+
def test_split_cond(self):
2689+
grid = Grid((11, 11))
2690+
time = grid.time_dim
2691+
2692+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2693+
2694+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2695+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2696+
2697+
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2698+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2699+
eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct)
2700+
2701+
op = Operator([eq0, eq1, eq2])
2702+
cond = FindNodes(Conditional).visit(op)
2703+
assert len(cond) == 3
2704+
assert str(cond[0].args['then_body'][0].exprs[0]) == 'r0 = cos(time);'
2705+
assert str(op.body.body[0].body[0].body[0]) == 'float r0 = 0;'
2706+
2707+
def test_split_cond_multi_alias(self):
2708+
grid = Grid((11, 11))
2709+
time = grid.time_dim
2710+
2711+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2712+
2713+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2714+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2715+
2716+
eq0 = Eq(u.forward, u + cos(time) + sin(time), implicit_dims=ct)
2717+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2718+
eq2 = Eq(u.forward, u.forward + cos(time) - sin(time), implicit_dims=ct)
2719+
2720+
op = Operator([eq0, eq1, eq2])
2721+
cond = FindNodes(Conditional).visit(op)
2722+
assert len(cond) == 3
2723+
print(op)
2724+
assert str(cond[0].args['then_body'][0].exprs[0]) == 'float r3 = cos(time);'
2725+
assert str(cond[0].args['then_body'][0].exprs[1]) == 'float r4 = sin(time);'
2726+
assert str(cond[0].args['then_body'][0].exprs[2]) == 'r0 = r3 + r4;'
2727+
assert str(cond[0].args['then_body'][0].exprs[3]) == 'r1 = r3;'
2728+
assert str(cond[0].args['then_body'][0].exprs[4]) == 'r2 = r4;'
2729+
2730+
assert str(op.body.body[0].body[0].body[0]) == 'float r0 = 0;'
2731+
assert str(op.body.body[0].body[0].body[1]) == 'float r1 = 0;'
2732+
assert str(op.body.body[0].body[0].body[2]) == 'float r2 = 0;'
2733+
26882734

26892735
class TestIsoAcoustic:
26902736

0 commit comments

Comments
 (0)