Skip to content

Commit 7fc4e81

Browse files
committed
compiler: prevent undefined Temp through global init
1 parent 8ad5ec1 commit 7fc4e81

3 files changed

Lines changed: 33 additions & 2 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Interval, IntervalGroup, IterationSpace, LabeledVector, Queue, Vector, extrema,
1313
maximum, minimum, normalize_properties, relax_properties, unbounded, vmax, vmin
1414
)
15+
from devito.ir.support import null_ispace
1516
from devito.passes.clusters.cse import _cse
1617
from devito.symbolics import (
1718
Uxmapper, estimate_cost, retrieve_functions, reuse_if_untouched, search, sympy_dtype,
@@ -860,6 +861,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
860861
make = TempFunction if opt_ftemps else TempArray
861862

862863
clusters = []
864+
inits = []
863865
subs = {}
864866
for pivot, writeto, ispace, aliaseds, indicess in schedule:
865867
name = sregistry.make_name()
@@ -928,8 +930,11 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928930
assert writeto.size == 0
929931

930932
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
933+
const = meta.guards is None
934+
obj = Temp(name=name, dtype=dtype, is_const=const)
932935
expression = Eq(obj, uxreplace(pivot, subs))
936+
if not const:
937+
inits.append(Eq(obj, 0))
933938

934939
callback = lambda idx: obj # noqa: B023
935940

@@ -959,6 +964,13 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
959964
# Finally, build the alias Cluster
960965
clusters.append(Cluster(expression, ispace, meta.guards, properties))
961966

967+
if inits:
968+
# To avoid undefined variables when an constant (Temp) alias is used
969+
# within different guards/loop, we need to initialize it outside of the loops
970+
# so that it's globally defined.
971+
# See tests/test_operators.py
972+
clusters.insert(0, Cluster(inits, null_ispace))
973+
962974
return clusters, subs
963975

964976

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

tests/test_dse.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,25 @@ def test_space_and_time_invariant_together(self):
26852685
'tx0_blk0y0_blk0xyzyz'
26862686
)
26872687

2688+
def test_split_cond(self):
2689+
grid = Grid((11, 11))
2690+
time = grid.time_dim
2691+
2692+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2693+
2694+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2695+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2696+
2697+
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2698+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2699+
eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct)
2700+
2701+
op = Operator([eq0, eq1, eq2])
2702+
cond = FindNodes(Conditional).visit(op)
2703+
assert len(cond) == 3
2704+
assert str(cond[0].args['then_body'][0].exprs[0]) == 'r0 = cos(time);'
2705+
assert str(op.body.body[0].body[0].body[0]) == 'float r0 = 0;'
2706+
26882707

26892708
class TestIsoAcoustic:
26902709

0 commit comments

Comments
 (0)