Skip to content

Commit 10bd3a7

Browse files
committed
compiler: prevent undefined Temp through global init
1 parent 8ad5ec1 commit 10bd3a7

6 files changed

Lines changed: 120 additions & 37 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Interval, IntervalGroup, IterationSpace, LabeledVector, Queue, Vector, extrema,
1313
maximum, minimum, normalize_properties, relax_properties, unbounded, vmax, vmin
1414
)
15+
from devito.ir.support import null_ispace
1516
from devito.passes.clusters.cse import _cse
1617
from devito.symbolics import (
1718
Uxmapper, estimate_cost, retrieve_functions, reuse_if_untouched, search, sympy_dtype,
@@ -131,10 +132,11 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
131132
# AliasList -> Schedule
132133
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
133134

134-
variants.append(Variant(schedule, exprs))
135+
if schedule:
136+
variants.append(Variant(schedule, exprs))
135137

136138
if not variants:
137-
return []
139+
return [], []
138140

139141
# [Schedule]_m -> Schedule (s.t. best memory/flops trade-off)
140142
schedule, exprs = self._select(variants)
@@ -144,9 +146,9 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
144146
schedule = optimize_schedule_rotations(schedule, self.sregistry)
145147

146148
# Schedule -> [Clusters]_k
147-
processed, subs = lower_schedule(schedule, meta, self.sregistry,
148-
self.opt_ftemps, self.opt_min_dtype,
149-
self.opt_minmem)
149+
processed, subs, inits = lower_schedule(schedule, meta, self.sregistry,
150+
self.opt_ftemps, self.opt_min_dtype,
151+
self.opt_minmem)
150152

151153
# [Clusters]_k -> [Clusters]_k (optimization)
152154
if self.opt_multisubdomain:
@@ -166,7 +168,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
166168

167169
assert len(exprs) == 0
168170

169-
return processed
171+
return processed, inits
170172

171173
def process(self, clusters):
172174
raise NotImplementedError
@@ -302,12 +304,13 @@ def callback(self, clusters, prefix, xtracted=None):
302304

303305
key = lambda c: self._lookup_key(c, d)
304306
processed = list(clusters)
307+
inits = []
305308
for ak, group in as_mapper(clusters, key=key).items():
306309
g = [c for c in group if c.is_dense and c not in xtracted]
307310
if not g:
308311
continue
309312

310-
made = self._aliases_from_clusters(ClusterGroup(g), exclude, ak)
313+
made, cinits = self._aliases_from_clusters(ClusterGroup(g), exclude, ak)
311314

312315
if made:
313316
idx = processed.index(g[0])
@@ -317,7 +320,10 @@ def callback(self, clusters, prefix, xtracted=None):
317320

318321
xtracted.extend(made)
319322

320-
return processed
323+
if cinits:
324+
inits.extend(cinits)
325+
326+
return inits + processed
321327

322328
def _lookup_key(self, c, d):
323329
ispace = c.ispace.reset()
@@ -390,7 +396,7 @@ def process(self, clusters):
390396
# TODO: to process third- and higher-order derivatives, we could
391397
# extend this by calling `_aliases_from_clusters` repeatedly until
392398
# `made` is empty. To be investigated
393-
made = self._aliases_from_clusters(
399+
made, _ = self._aliases_from_clusters(
394400
ClusterGroup(c), exclude, self._lookup_key(c)
395401
)
396402

@@ -860,6 +866,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
860866
make = TempFunction if opt_ftemps else TempArray
861867

862868
clusters = []
869+
inits = []
863870
subs = {}
864871
for pivot, writeto, ispace, aliaseds, indicess in schedule:
865872
name = sregistry.make_name()
@@ -928,8 +935,11 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928935
assert writeto.size == 0
929936

930937
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
938+
const = not meta.guards
939+
obj = Temp(name=name, dtype=dtype, is_const=const)
932940
expression = Eq(obj, uxreplace(pivot, subs))
941+
if not const:
942+
inits.append(Eq(obj, 0))
933943

934944
callback = lambda idx: obj # noqa: B023
935945

@@ -959,7 +969,14 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
959969
# Finally, build the alias Cluster
960970
clusters.append(Cluster(expression, ispace, meta.guards, properties))
961971

962-
return clusters, subs
972+
if inits:
973+
# To avoid undefined variables when an constant (Temp) alias is used
974+
# within different guards/loop, we need to initialize it outside of the loops
975+
# so that it's globally defined.
976+
# See tests/test_dse.py::TestAliases::test_split_cond
977+
inits = [Cluster(inits, null_ispace)]
978+
979+
return clusters, subs, inits
963980

964981

965982
def optimize_clusters_msds(clusters):

devito/passes/clusters/cse.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def cse_dtype(exprdtype, cdtype):
4040
Return the dtype of a CSE temporary given the dtype of the expression to be
4141
captured and the cluster's dtype.
4242
"""
43+
if np.issubdtype(cdtype, np.floating) and np.issubdtype(exprdtype, np.integer):
44+
# Integer expression and floating-point cluster: promote to the floating point
45+
# np.promote_types upcast integers (e.g int32 -> Float64) so we
46+
# need to ensure that the promoted type is not larger than the cluster's dtype
47+
return cdtype
48+
4349
if np.issubdtype(cdtype, np.complexfloating):
4450
return np.promote_types(exprdtype, cdtype(0).real.__class__).type
4551
else:
@@ -97,8 +103,9 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
97103
if cluster.is_fence:
98104
return cluster
99105

100-
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
101-
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
106+
def make(e):
107+
edtype = cse_dtype(e.dtype, dtype)
108+
return CTemp(name=sregistry.make_name(), dtype=edtype)
102109

103110
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
104111

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

examples/mpi/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
461461
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
462462
"\n",
463-
" float r0 = 1.0F/h_x;\n",
463+
" const float r0 = 1.0F/h_x;\n",
464464
"\n",
465465
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
466466
" {\n",

examples/performance/00_overview.ipynb

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@
508508
"}\n",
509509
"STOP(section0,timers)\n",
510510
"\n",
511-
"float r1 = 1.0F/h_y;\n",
511+
"const float r1 = 1.0F/h_y;\n",
512512
"\n",
513513
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
514514
"{\n",
@@ -572,7 +572,13 @@
572572
"+ u[t1][x + 4][y + 4][z + 4] = (f[x + 1][y + 1][z + 1]*f[x + 1][y + 1][z + 1])*((-6.66666667e-1F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 1][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 2][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 4][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 5][z + 4]) + (-8.33333333e-2F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 4][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 5][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 7][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 8][z + 4]) + (8.33333333e-2F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 1][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 3][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 4][z + 4]) + (6.66666667e-1F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 3][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 4][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 6][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 7][z + 4]))*sinf(f[x + 1][y + 1][z + 1]);\n",
573573
" }\n",
574574
" }\n",
575-
" }\n",
575+
" }\n"
576+
]
577+
},
578+
{
579+
"name": "stdout",
580+
"output_type": "stream",
581+
"text": [
576582
"\n"
577583
]
578584
}
@@ -652,7 +658,7 @@
652658
},
653659
{
654660
"cell_type": "code",
655-
"execution_count": 13,
661+
"execution_count": 12,
656662
"metadata": {},
657663
"outputs": [
658664
{
@@ -712,7 +718,7 @@
712718
},
713719
{
714720
"cell_type": "code",
715-
"execution_count": 14,
721+
"execution_count": 13,
716722
"metadata": {},
717723
"outputs": [
718724
{
@@ -753,7 +759,14 @@
753759
" }\n",
754760
" }\n",
755761
" STOP(section0,timers)\n",
756-
"}\n"
762+
"}"
763+
]
764+
},
765+
{
766+
"name": "stdout",
767+
"output_type": "stream",
768+
"text": [
769+
"\n"
757770
]
758771
}
759772
],
@@ -772,7 +785,7 @@
772785
},
773786
{
774787
"cell_type": "code",
775-
"execution_count": 15,
788+
"execution_count": 14,
776789
"metadata": {},
777790
"outputs": [
778791
{
@@ -863,7 +876,7 @@
863876
},
864877
{
865878
"cell_type": "code",
866-
"execution_count": 16,
879+
"execution_count": 15,
867880
"metadata": {},
868881
"outputs": [
869882
{
@@ -919,7 +932,7 @@
919932
},
920933
{
921934
"cell_type": "code",
922-
"execution_count": 17,
935+
"execution_count": 16,
923936
"metadata": {},
924937
"outputs": [
925938
{
@@ -976,7 +989,7 @@
976989
},
977990
{
978991
"cell_type": "code",
979-
"execution_count": 18,
992+
"execution_count": 17,
980993
"metadata": {},
981994
"outputs": [],
982995
"source": [
@@ -994,7 +1007,7 @@
9941007
},
9951008
{
9961009
"cell_type": "code",
997-
"execution_count": 19,
1010+
"execution_count": 18,
9981011
"metadata": {},
9991012
"outputs": [
10001013
{
@@ -1044,7 +1057,7 @@
10441057
},
10451058
{
10461059
"cell_type": "code",
1047-
"execution_count": 20,
1060+
"execution_count": 19,
10481061
"metadata": {},
10491062
"outputs": [
10501063
{
@@ -1112,7 +1125,7 @@
11121125
},
11131126
{
11141127
"cell_type": "code",
1115-
"execution_count": 21,
1128+
"execution_count": 20,
11161129
"metadata": {},
11171130
"outputs": [
11181131
{
@@ -1192,7 +1205,7 @@
11921205
" }\n",
11931206
" STOP(section0,timers)\n",
11941207
"\n",
1195-
" float r1 = 1.0F/h_y;\n",
1208+
" const float r1 = 1.0F/h_y;\n",
11961209
"\n",
11971210
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
11981211
" {\n",
@@ -1279,7 +1292,7 @@
12791292
},
12801293
{
12811294
"cell_type": "code",
1282-
"execution_count": 22,
1295+
"execution_count": 21,
12831296
"metadata": {},
12841297
"outputs": [
12851298
{
@@ -1304,7 +1317,7 @@
13041317
"}\n",
13051318
"STOP(section0,timers)\n",
13061319
"\n",
1307-
"float r1 = 1.0F/h_y;\n",
1320+
"const float r1 = 1.0F/h_y;\n",
13081321
"\n",
13091322
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
13101323
"{\n",
@@ -1369,7 +1382,7 @@
13691382
},
13701383
{
13711384
"cell_type": "code",
1372-
"execution_count": 23,
1385+
"execution_count": 22,
13731386
"metadata": {},
13741387
"outputs": [
13751388
{
@@ -1404,7 +1417,7 @@
14041417
},
14051418
{
14061419
"cell_type": "code",
1407-
"execution_count": 24,
1420+
"execution_count": 23,
14081421
"metadata": {},
14091422
"outputs": [
14101423
{
@@ -1483,7 +1496,7 @@
14831496
" }\n",
14841497
" STOP(section0,timers)\n",
14851498
"\n",
1486-
" float r1 = 1.0F/h_y;\n",
1499+
" const float r1 = 1.0F/h_y;\n",
14871500
"\n",
14881501
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
14891502
" {\n",
@@ -1546,7 +1559,7 @@
15461559
},
15471560
{
15481561
"cell_type": "code",
1549-
"execution_count": 25,
1562+
"execution_count": 24,
15501563
"metadata": {},
15511564
"outputs": [
15521565
{
@@ -1624,8 +1637,8 @@
16241637
" }\n",
16251638
" STOP(section0,timers)\n",
16261639
"\n",
1627-
" float r1 = 1.0F/h_x;\n",
1628-
" float r2 = 1.0F/h_y;\n",
1640+
" const float r1 = 1.0F/h_x;\n",
1641+
" const float r2 = 1.0F/h_y;\n",
16291642
"\n",
16301643
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
16311644
" {\n",
@@ -1718,7 +1731,7 @@
17181731
"name": "python",
17191732
"nbconvert_exporter": "python",
17201733
"pygments_lexer": "ipython3",
1721-
"version": "3.10.12"
1734+
"version": "3.13.11"
17221735
}
17231736
},
17241737
"nbformat": 4,

0 commit comments

Comments
 (0)